# Provenance (Hugging Face Hub file view): uploaded by lambertxiao,
# commit 746c807 (verified) "Upload folder using huggingface_hub", 1.97 kB.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities, named after this
# module. It is not referenced elsewhere in this file as shown.
logger = logging.get_logger(__name__)
class EmptyClass(PretrainedConfig):
    """Blank attribute namespace that is still a ``PretrainedConfig`` subclass.

    Instances are used as containers (e.g. ``SDConfig.model``) onto which
    fields are assigned after construction.
    """

    def __init__(self):
        """Create the namespace without running ``PretrainedConfig.__init__``.

        The base initializer is not called, so none of the base-class default
        attributes are set on the instance; this constructor sets nothing.
        """
class SDConfig(PretrainedConfig):
    """``PretrainedConfig`` whose options are grouped into nested namespaces.

    After construction the configuration is laid out as:

    * ``self.model`` (an ``EmptyClass``): ``override_total_steps``,
      ``freeze_vae``, ``use_flash``.
    * ``self.tta`` (an ``EmptyClass``): ``adapt_topk``, ``loss``,
      ``use_same_noise_among_timesteps``, ``random_timestep_per_iteration``,
      ``rand_timestep_equal_int``, plus an empty ``gradient_descent``
      ``EmptyClass``.
    * ``self.input`` (an ``EmptyClass``): ``mean``, ``std``.
    * top level: ``output_dir``, ``do_center_crop_size``, ``architectures``,
      and every extra keyword argument.
    """

    def __init__(self,
                 override_total_steps=-1,
                 freeze_vae=True,
                 use_flash=False,
                 adapt_topk=-1,
                 loss='mse',
                 mean=None,
                 std=None,
                 use_same_noise_among_timesteps=False,
                 random_timestep_per_iteration=True,
                 rand_timestep_equal_int=False,
                 output_dir='./outputs/First_Start',
                 do_center_crop_size=384,
                 architectures=None,
                 input=None,
                 model=None,
                 tta=None,
                 **kwargs
                 ):
        """Build the configuration.

        Args:
            override_total_steps: Stored as ``self.model.override_total_steps``.
            freeze_vae: Stored as ``self.model.freeze_vae``.
            use_flash: Stored as ``self.model.use_flash``.
            adapt_topk: Stored as ``self.tta.adapt_topk``.
            loss: Stored as ``self.tta.loss``.
            mean: Stored as ``self.input.mean``. ``None`` (the default) means
                ``[0.485, 0.456, 0.406]``, a fresh list on every call.
            std: Stored as ``self.input.std``. ``None`` (the default) means
                ``[0.229, 0.224, 0.225]``, a fresh list on every call.
            use_same_noise_among_timesteps: Stored on ``self.tta``.
            random_timestep_per_iteration: Stored on ``self.tta``.
            rand_timestep_equal_int: Stored on ``self.tta``.
            output_dir: Stored as ``self.output_dir``.
            do_center_crop_size: Stored as ``self.do_center_crop_size``.
            architectures: Stored as ``self.architectures``.
            input, model, tta: Accepted but ignored (see note below).
            **kwargs: Each pair is set as an attribute on the config, last.
        """
        # A list literal used as a default argument is created once and shared
        # by every call; since it is stored on self.input, mutating one
        # config's mean/std would change the defaults of all later configs.
        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        # Only the base-class defaults are initialised here; extra kwargs are
        # applied via setattr at the end, so they take precedence over any
        # base-class default attribute of the same name.
        super().__init__()

        self.model = EmptyClass()
        self.model.override_total_steps = override_total_steps
        self.model.freeze_vae = freeze_vae
        self.model.use_flash = use_flash

        self.tta = EmptyClass()
        self.tta.gradient_descent = EmptyClass()
        self.tta.adapt_topk = adapt_topk
        self.tta.loss = loss
        self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps
        self.tta.random_timestep_per_iteration = random_timestep_per_iteration
        self.tta.rand_timestep_equal_int = rand_timestep_equal_int

        self.input = EmptyClass()
        self.input.mean = mean
        self.input.std = std

        self.output_dir = output_dir
        self.do_center_crop_size = do_center_crop_size
        self.architectures = architectures

        # `input`, `model` and `tta` are named parameters so they never reach
        # **kwargs; otherwise the setattr loop below would overwrite the
        # namespaces built above. Their values are otherwise unused.
        # NOTE(review): if these arrive as nested dicts from a saved config
        # (e.g. via from_pretrained), the saved values are dropped and the
        # flat-argument defaults are used instead. Confirm this is intended.
        for key, value in kwargs.items():
            setattr(self, key, value)
if __name__ == '__main__':
    # Smoke check when run as a script: building the config with every
    # default argument should complete without raising.
    default_config = SDConfig()