from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class EmptyClass(PretrainedConfig):
    """Bare attribute container for the nested model/tta/input namespaces."""

    def __init__(self):
        # Deliberately skips PretrainedConfig.__init__; instances are used
        # only as dot-accessible option containers.
        pass
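
# A minimal sketch of the same nested-namespace pattern using only the
# standard library, assuming the option groups never need any
# PretrainedConfig machinery (hypothetical alternative, not what this
# file uses):
#
#   from types import SimpleNamespace
#   tta = SimpleNamespace(adapt_topk=-1, loss='mse')
#   assert tta.loss == 'mse'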


class SDConfig(PretrainedConfig):
    """Configuration holding nested `model`, `tta`, and `input` option groups."""

    def __init__(self,
                 override_total_steps=-1,
                 freeze_vae=True,
                 use_flash=False,
                 adapt_topk=-1,
                 loss='mse',
                 mean=(0.485, 0.456, 0.406),  # ImageNet channel means
                 std=(0.229, 0.224, 0.225),   # ImageNet channel stds
                 use_same_noise_among_timesteps=False,
                 random_timestep_per_iteration=True,
                 rand_timestep_equal_int=False,
                 output_dir='./outputs/First_Start',
                 do_center_crop_size=384,
                 architectures=None,
                 # `input`, `model`, and `tta` are accepted (e.g. when a saved
                 # config is passed back in) but rebuilt as fresh namespaces below.
                 input=None,
                 model=None,
                 tta=None,
                 **kwargs):
        super().__init__(**kwargs)

        # Generative (diffusion) model options.
        self.model = EmptyClass()
        self.model.override_total_steps = override_total_steps
        self.model.freeze_vae = freeze_vae
        self.model.use_flash = use_flash

        # Test-time-adaptation options.
        self.tta = EmptyClass()
        self.tta.gradient_descent = EmptyClass()  # empty sub-namespace for optimizer settings
        self.tta.adapt_topk = adapt_topk
        self.tta.loss = loss
        self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps
        self.tta.random_timestep_per_iteration = random_timestep_per_iteration
        self.tta.rand_timestep_equal_int = rand_timestep_equal_int

        # Input preprocessing options; list() copies the immutable defaults.
        self.input = EmptyClass()
        self.input.mean = list(mean)
        self.input.std = list(std)

        self.output_dir = output_dir
        self.do_center_crop_size = do_center_crop_size
        self.architectures = architectures
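
# Sketch of how downstream code might consume the nested groups; build_model
# is a hypothetical helper, shown only to illustrate the dot-access layout:
#
#   def build_model(cfg: SDConfig):
#       if cfg.model.freeze_vae:
#           ...  # leave the VAE parameters out of the optimizer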


if __name__ == '__main__':
    # Smoke test: build a default config and inspect a nested option.
    config = SDConfig()
    print(config.tta.loss)  # mse
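    # Usage sketch (illustrative values, not recommended settings): defaults
    # can be overridden at construction time, and any extra keyword argument
    # becomes a top-level attribute via **kwargs.
    cfg = SDConfig(adapt_topk=5, seed=0)
    print(cfg.tta.adapt_topk)  # 5
    print(cfg.seed)            # 0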