Make HF interface compatible
#4
by
mranzinger
- opened
- hf_model.py +30 -6
hf_model.py
CHANGED
|
@@ -124,14 +124,38 @@ class RADIOModel(PreTrainedModel):
|
|
| 124 |
def input_conditioner(self) -> InputConditioner:
|
| 125 |
return self.radio_model.input_conditioner
|
| 126 |
|
| 127 |
-
@
|
| 128 |
-
def
|
| 129 |
-
self.radio_model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
def forward(self, x: torch.Tensor):
|
| 137 |
return self.radio_model.forward(x)
|
|
|
|
| 124 |
def input_conditioner(self) -> InputConditioner:
|
| 125 |
return self.radio_model.input_conditioner
|
| 126 |
|
| 127 |
+
@property
|
| 128 |
+
def num_summary_tokens(self) -> int:
|
| 129 |
+
return self.radio_model.num_summary_tokens
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def patch_size(self) -> int:
|
| 133 |
+
return self.radio_model.patch_size
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def max_resolution(self) -> int:
|
| 137 |
+
return self.radio_model.max_resolution
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def preferred_resolution(self) -> Resolution:
|
| 141 |
+
return self.radio_model.preferred_resolution
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def window_size(self) -> int:
|
| 145 |
+
return self.radio_model.window_size
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def min_resolution_step(self) -> int:
|
| 149 |
+
return self.radio_model.min_resolution_step
|
| 150 |
|
| 151 |
def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 152 |
+
return self.radio_model.make_preprocessor_external()
|
| 153 |
+
|
| 154 |
+
def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
|
| 155 |
+
return self.radio_model.get_nearest_supported_resolution(height, width)
|
| 156 |
+
|
| 157 |
+
def switch_to_deploy(self):
|
| 158 |
+
return self.radio_model.switch_to_deploy()
|
| 159 |
|
| 160 |
def forward(self, x: torch.Tensor):
|
| 161 |
return self.radio_model.forward(x)
|