| from transformers import PreTrainedModel | |
| from package.audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck | |
| from .autoencoder_config import AutoEncoder1dConfig | |
# Registry mapping the config's bottleneck name to the class that implements it.
bottleneck = dict(tanh=TanhBottleneck)
class AutoEncoder1d(PreTrainedModel):
    """HuggingFace-compatible wrapper around the ``audio_encoders_pytorch``
    1D autoencoder, built from an :class:`AutoEncoder1dConfig`.

    All heavy lifting is delegated to the wrapped ``AE1d`` instance; this
    class only adapts it to the ``transformers`` model interface.
    """

    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        super().__init__(config)
        # Resolve the configured bottleneck name (e.g. 'tanh') to an instance.
        chosen_bottleneck = bottleneck[config.bottleneck]()
        self.autoencoder = AE1d(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            bottleneck=chosen_bottleneck,
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped autoencoder end to end (encode + decode)."""
        return self.autoencoder(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's encoder."""
        return self.autoencoder.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's decoder."""
        return self.autoencoder.decode(*args, **kwargs)