|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from argparse import Namespace | 
					
						
						|  | from typing import Dict, Any | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from .adaptor_generic import GenericAdaptor, AdaptorBase | 
					
						
						|  |  | 
					
						
						|  | dict_t = Dict[str, Any] | 
					
						
						|  | state_t = Dict[str, torch.Tensor] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AdaptorRegistry: | 
					
						
						|  | def __init__(self): | 
					
						
						|  | self._registry = {} | 
					
						
						|  |  | 
					
						
						|  | def register_adaptor(self, name): | 
					
						
						|  | def decorator(factory_function): | 
					
						
						|  | if name in self._registry: | 
					
						
						|  | raise ValueError(f"Model '{name}' already registered") | 
					
						
						|  | self._registry[name] = factory_function | 
					
						
						|  | return factory_function | 
					
						
						|  | return decorator | 
					
						
						|  |  | 
					
						
						|  | def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase: | 
					
						
						|  | if name not in self._registry: | 
					
						
						|  | return GenericAdaptor(main_config, adaptor_config, state) | 
					
						
						|  | return self._registry[name](main_config, adaptor_config, state) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module-level registry instance; adaptor factories register into it with
# ``@adaptor_registry.register_adaptor(name)`` and are built via
# ``adaptor_registry.create_adaptor(...)``.
adaptor_registry: AdaptorRegistry = AdaptorRegistry()
					
						
						|  |  |