|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from collections import namedtuple | 
					
						
						|  | from typing import Callable, Dict, Optional, List, Union | 
					
						
						|  |  | 
					
						
						|  | from timm.models import VisionTransformer | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  | from transformers import PretrainedConfig, PreTrainedModel | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from .common import RESOURCE_MAP, DEFAULT_VERSION | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput | 
					
						
						|  | from .adaptor_generic import GenericAdaptor, AdaptorBase | 
					
						
						|  | from .adaptor_mlp import create_mlp_from_config | 
					
						
						|  | from .adaptor_registry import adaptor_registry | 
					
						
						|  | from .cls_token import ClsToken | 
					
						
						|  | from .enable_cpe_support import enable_cpe | 
					
						
						|  | from .enable_spectral_reparam import configure_spectral_reparam_from_args | 
					
						
						|  | from .eradio_model import eradio | 
					
						
						|  | from .radio_model import create_model_from_args | 
					
						
						|  | from .radio_model import RADIOModel as RADIOModelBase, Resolution | 
					
						
						|  | from .input_conditioner import get_default_conditioner, InputConditioner | 
					
						
						|  | from .open_clip_adaptor import OpenCLIP_RADIO | 
					
						
						|  | from .vit_patch_generator import ViTPatchGenerator | 
					
						
						|  | from .vitdet import apply_vitdet_arch, VitDetArgs | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from .extra_timm_models import * | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class RADIOConfig(PretrainedConfig): | 
					
						
						|  | """Pretrained Hugging Face configuration for RADIO models.""" | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | args: Optional[dict] = None, | 
					
						
						|  | version: Optional[str] = DEFAULT_VERSION, | 
					
						
						|  | patch_size: Optional[int] = None, | 
					
						
						|  | max_resolution: Optional[int] = None, | 
					
						
						|  | preferred_resolution: Optional[Resolution] = None, | 
					
						
						|  | adaptor_names: Union[str, List[str]] = None, | 
					
						
						|  | adaptor_configs: Dict[str, Dict[str, int]] = None, | 
					
						
						|  | vitdet_window_size: Optional[int] = None, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | self.args = args | 
					
						
						|  | for field in ["dtype", "amp_dtype"]: | 
					
						
						|  | if self.args is not None and field in self.args: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.args[field] = str(args[field]).split(".")[-1] | 
					
						
						|  | self.version = version | 
					
						
						|  | resource = RESOURCE_MAP[version] | 
					
						
						|  | self.patch_size = patch_size or resource.patch_size | 
					
						
						|  | self.max_resolution = max_resolution or resource.max_resolution | 
					
						
						|  | self.preferred_resolution = ( | 
					
						
						|  | preferred_resolution or resource.preferred_resolution | 
					
						
						|  | ) | 
					
						
						|  | self.adaptor_names = adaptor_names | 
					
						
						|  | self.adaptor_configs = adaptor_configs | 
					
						
						|  | self.vitdet_window_size = vitdet_window_size | 
					
						
						|  | super().__init__(**kwargs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class RADIOModel(PreTrainedModel): | 
					
						
						|  | """Pretrained Hugging Face model for RADIO. | 
					
						
						|  |  | 
					
						
						|  | This class inherits from PreTrainedModel, which provides | 
					
						
						|  | HuggingFace's functionality for loading and saving models. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | config_class = RADIOConfig | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, config: RADIOConfig): | 
					
						
						|  | super().__init__(config) | 
					
						
						|  |  | 
					
						
						|  | RADIOArgs = namedtuple("RADIOArgs", config.args.keys()) | 
					
						
						|  | args = RADIOArgs(**config.args) | 
					
						
						|  | self.config = config | 
					
						
						|  |  | 
					
						
						|  | model = create_model_from_args(args) | 
					
						
						|  | input_conditioner: InputConditioner = get_default_conditioner() | 
					
						
						|  |  | 
					
						
						|  | dtype = getattr(args, "dtype", torch.float32) | 
					
						
						|  | if isinstance(dtype, str): | 
					
						
						|  |  | 
					
						
						|  | dtype = getattr(torch, dtype) | 
					
						
						|  | model.to(dtype=dtype) | 
					
						
						|  | input_conditioner.dtype = dtype | 
					
						
						|  |  | 
					
						
						|  | summary_idxs = torch.tensor( | 
					
						
						|  | [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)], | 
					
						
						|  | dtype=torch.int64, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | adaptor_configs = config.adaptor_configs | 
					
						
						|  | adaptor_names = config.adaptor_names or [] | 
					
						
						|  |  | 
					
						
						|  | adaptors = dict() | 
					
						
						|  | for adaptor_name in adaptor_names: | 
					
						
						|  | mlp_config = adaptor_configs[adaptor_name] | 
					
						
						|  | adaptor = GenericAdaptor(args, None, None, mlp_config) | 
					
						
						|  | adaptor.head_idx = mlp_config["head_idx"] | 
					
						
						|  | adaptors[adaptor_name] = adaptor | 
					
						
						|  |  | 
					
						
						|  | self.radio_model = RADIOModelBase( | 
					
						
						|  | model, | 
					
						
						|  | input_conditioner, | 
					
						
						|  | summary_idxs=summary_idxs, | 
					
						
						|  | patch_size=config.patch_size, | 
					
						
						|  | max_resolution=config.max_resolution, | 
					
						
						|  | window_size=config.vitdet_window_size, | 
					
						
						|  | preferred_resolution=config.preferred_resolution, | 
					
						
						|  | adaptors=adaptors, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def adaptors(self) -> nn.ModuleDict: | 
					
						
						|  | return self.radio_model.adaptors | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def model(self) -> VisionTransformer: | 
					
						
						|  | return self.radio_model.model | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def input_conditioner(self) -> InputConditioner: | 
					
						
						|  | return self.radio_model.input_conditioner | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def num_summary_tokens(self) -> int: | 
					
						
						|  | return self.radio_model.num_summary_tokens | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def patch_size(self) -> int: | 
					
						
						|  | return self.radio_model.patch_size | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def max_resolution(self) -> int: | 
					
						
						|  | return self.radio_model.max_resolution | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def preferred_resolution(self) -> Resolution: | 
					
						
						|  | return self.radio_model.preferred_resolution | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def window_size(self) -> int: | 
					
						
						|  | return self.radio_model.window_size | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def min_resolution_step(self) -> int: | 
					
						
						|  | return self.radio_model.min_resolution_step | 
					
						
						|  |  | 
					
						
						|  | def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]: | 
					
						
						|  | return self.radio_model.make_preprocessor_external() | 
					
						
						|  |  | 
					
						
						|  | def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution: | 
					
						
						|  | return self.radio_model.get_nearest_supported_resolution(height, width) | 
					
						
						|  |  | 
					
						
						|  | def switch_to_deploy(self): | 
					
						
						|  | return self.radio_model.switch_to_deploy() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor): | 
					
						
						|  | return self.radio_model.forward(x) | 
					
						
						|  |  |