from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Default per-channel normalization statistics used by SDConfig. The values
# match the widely used ImageNet RGB mean/std; presumably that is the intent
# (NOTE(review): confirm against the data pipeline). They are kept as tuples
# so the module-level defaults themselves can never be mutated.
_DEFAULT_MEAN = (0.485, 0.456, 0.406)
_DEFAULT_STD = (0.229, 0.224, 0.225)


class EmptyClass(PretrainedConfig):
    """Bare attribute namespace used to group related settings.

    ``PretrainedConfig.__init__`` is deliberately not called, so an instance
    carries none of the base class's default attributes, only whatever
    callers assign onto it afterwards.
    """

    def __init__(self):
        pass


class SDConfig(PretrainedConfig):
    """Configuration that groups settings into nested namespaces.

    Attributes created by ``__init__``:

    * ``model``: ``override_total_steps``, ``freeze_vae``, ``use_flash``
    * ``tta``: ``adapt_topk``, ``loss``, ``use_same_noise_among_timesteps``,
      ``random_timestep_per_iteration``, ``rand_timestep_equal_int`` and an
      empty ``gradient_descent`` sub-namespace
    * ``input``: ``mean``, ``std``
    * top level: ``output_dir``, ``do_center_crop_size``, ``architectures``

    Any extra keyword arguments are set verbatim as top-level attributes.

    Args:
        override_total_steps: Stored on ``model``. ``-1`` is the default;
            presumably it means "no override" (TODO confirm with consumers).
        freeze_vae: Stored on ``model``.
        use_flash: Stored on ``model``.
        adapt_topk: Stored on ``tta``. Defaults to ``-1`` (meaning not
            established in this file).
        loss: Stored on ``tta``.
        mean: Per-channel mean stored on ``input``. ``None`` (the default)
            yields a fresh list copy of ``_DEFAULT_MEAN``.
        std: Per-channel std stored on ``input``. ``None`` (the default)
            yields a fresh list copy of ``_DEFAULT_STD``.
        use_same_noise_among_timesteps: Stored on ``tta``.
        random_timestep_per_iteration: Stored on ``tta``.
        rand_timestep_equal_int: Stored on ``tta``.
        output_dir: Stored at top level.
        do_center_crop_size: Stored at top level.
        architectures: Stored at top level.
        input: Accepted and ignored.
        model: Accepted and ignored.
        tta: Accepted and ignored. Capturing ``input``, ``model`` and
            ``tta`` by name keeps them out of ``kwargs``, so the trailing
            ``setattr`` loop cannot overwrite the namespaces built here.
            NOTE(review): presumably these keys show up when a serialized
            config is reloaded; verify.
        **kwargs: Extra attributes, applied last with ``setattr``.
    """

    def __init__(
        self,
        override_total_steps=-1,
        freeze_vae=True,
        use_flash=False,
        adapt_topk=-1,
        loss='mse',
        mean=None,
        std=None,
        use_same_noise_among_timesteps=False,
        random_timestep_per_iteration=True,
        rand_timestep_equal_int=False,
        output_dir='./outputs/First_Start',
        do_center_crop_size=384,
        architectures=None,
        input=None,  # noqa: A002 - name kept for caller/serialization compatibility
        model=None,
        tta=None,
        **kwargs,
    ):
        # kwargs are intentionally NOT forwarded to the base initializer;
        # they are applied verbatim at the end instead.
        super().__init__()

        self.model = EmptyClass()
        self.model.override_total_steps = override_total_steps
        self.model.freeze_vae = freeze_vae
        self.model.use_flash = use_flash

        self.tta = EmptyClass()
        # Empty placeholder namespace; nothing is assigned to it in this file.
        self.tta.gradient_descent = EmptyClass()
        self.tta.adapt_topk = adapt_topk
        self.tta.loss = loss
        self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps
        self.tta.random_timestep_per_iteration = random_timestep_per_iteration
        self.tta.rand_timestep_equal_int = rand_timestep_equal_int

        self.input = EmptyClass()
        # Build a fresh list per instance. The original shared list defaults
        # let an in-place edit on one config leak into every later config.
        self.input.mean = list(_DEFAULT_MEAN) if mean is None else mean
        self.input.std = list(_DEFAULT_STD) if std is None else std

        self.output_dir = output_dir
        self.do_center_crop_size = do_center_crop_size
        self.architectures = architectures

        # Applied last, so any extra attribute name given here wins over
        # attributes of the same name set above.
        for key, value in kwargs.items():
            setattr(self, key, value)


if __name__ == '__main__':
    SDConfig()