|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from argparse import Namespace | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  | from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput | 
					
						
						|  | from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						class GenericAdaptor(AdaptorBase):
						    """Adaptor that maps an AdaptorInput's summary and patch features through two MLPs.

						    ``head_mlp`` is applied to ``input.summary`` and ``feat_mlp`` to
						    ``input.features``; the results are returned as a ``RadioOutput``.
						    """

						    def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
						        """Build the summary and feature MLPs, from saved state or from a config.

						        Args:
						            main_config: Model configuration; only ``main_config.mlp_version``
						                is read here.
						            adaptor_config: Accepted for interface compatibility; not used
						                by this class.
						            state: When not ``None``, both MLPs are created via
						                ``create_mlp_from_state`` using the key prefixes ``'summary.'``
						                and ``'feature.'`` (presumably a state dict — confirm with
						                ``create_mlp_from_state``).
						            mlp_config: Required when ``state`` is ``None``. A mapping with
						                ``"summary"`` and ``"feature"`` sections, each providing
						                ``input_dim``, ``hidden_dim``, ``output_dim`` and ``num_inner``.

						        Raises:
						            ValueError: If both ``state`` and ``mlp_config`` are ``None``.
						        """
						        super().__init__()

						        if state is not None:
						            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
						            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
						        else:
						            # Explicit raise rather than `assert`: asserts are stripped under `python -O`,
						            # which would turn this into an opaque TypeError on the subscript below.
						            if mlp_config is None:
						                raise ValueError("Config must not be None if state is None")

						            self.head_mlp = self._mlp_from_config(main_config.mlp_version, mlp_config["summary"])
						            self.feat_mlp = self._mlp_from_config(main_config.mlp_version, mlp_config["feature"])

						    @staticmethod
						    def _mlp_from_config(mlp_version, section):
						        """Create one MLP from a config section holding its dimensions and depth."""
						        return create_mlp_from_config(
						            mlp_version,
						            section["input_dim"],
						            section["hidden_dim"],
						            section["output_dim"],
						            section["num_inner"],
						        )

						    def forward(self, input: AdaptorInput) -> RadioOutput:
						        """Project the summary and patch features of ``input``.

						        Each tensor is cast to the dtype of this module's first parameter
						        before its MLP runs, and the result is cast back to that tensor's
						        original dtype.

						        When ``input.feature_fmt == 'NCHW'`` the projected features are
						        reshaped to ``(B, H // patch_size, W // patch_size, C)`` and then
						        permuted to ``(B, C, H // patch_size, W // patch_size)``, where H and W
						        are the last two dims of ``input.images``. This assumes the features
						        are ``(B, num_patches, C)`` with
						        ``num_patches == (H // patch_size) * (W // patch_size)`` — TODO confirm
						        against callers.

						        Returns:
						            RadioOutput(summary, features).
						        """
						        # Run the MLPs in their own parameter dtype, then restore each input's dtype.
						        param_dtype = next(self.parameters()).dtype
						        summary = self.head_mlp(input.summary.to(dtype=param_dtype)).to(dtype=input.summary.dtype)
						        feat = self.feat_mlp(input.features.to(dtype=param_dtype)).to(dtype=input.features.dtype)

						        if input.feature_fmt == 'NCHW':
						            rows = input.images.shape[-2] // input.patch_size
						            cols = input.images.shape[-1] // input.patch_size
						            feat = feat.reshape(feat.shape[0], rows, cols, feat.shape[2]).permute(0, 3, 1, 2)

						        return RadioOutput(summary, feat)
					
						
						|  |  |