```python
from transformers import PretrainedConfig

class STDiTConfig(PretrainedConfig):
    model_type = "stdit"

    def __init__(
        self,
        input_size=(1, 32, 32),        # latent size as (num_frames, height, width)
        in_channels=4,                 # channels of the VAE latent input
        patch_size=(1, 2, 2),          # patch size along (temporal, height, width)
        hidden_size=1152,              # transformer embedding dimension
        depth=28,                      # number of transformer blocks
        num_heads=16,                  # attention heads per block
        mlp_ratio=4.0,                 # MLP hidden dim as a multiple of hidden_size
        class_dropout_prob=0.1,        # condition dropout for classifier-free guidance
        pred_sigma=True,               # also predict sigma (doubles output channels)
        drop_path=0.0,                 # stochastic depth rate
        no_temporal_pos_emb=False,     # disable the temporal positional embedding
        caption_channels=4096,         # text-encoder feature dim (T5-XXL uses 4096)
        model_max_length=120,          # maximum caption token length
        space_scale=1.0,               # scale for the spatial positional embedding
        time_scale=1.0,                # scale for the temporal positional embedding
        freeze=None,                   # optional spec of parameters to freeze
        enable_flash_attn=False,       # use FlashAttention in attention layers
        enable_layernorm_kernel=False, # use a fused LayerNorm kernel
        enable_sequence_parallelism=False,  # shard the sequence across GPUs
        **kwargs,
    ):
        self.input_size = input_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.space_scale = space_scale
        self.time_scale = time_scale
        self.freeze = freeze
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
```
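
Because `STDiTConfig` subclasses `PretrainedConfig`, it inherits the standard Hugging Face serialization machinery. A minimal sketch of the round trip (the directory name and overridden fields below are illustrative, not part of the original):

```python
# Minimal usage sketch: instantiate, serialize, and reload the config.
# "stdit-base" is a hypothetical local directory, not an official checkpoint.
config = STDiTConfig(depth=28, enable_flash_attn=True)
config.save_pretrained("stdit-base")                # writes stdit-base/config.json
reloaded = STDiTConfig.from_pretrained("stdit-base")

assert reloaded.model_type == "stdit"
assert reloaded.enable_flash_attn is True
print(reloaded.hidden_size)  # 1152 (the default)
```

Every attribute assigned in `__init__` is captured in `config.json`, so non-default values such as `enable_flash_attn=True` survive the save/load cycle.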