import torch
import torch.nn as nn
import yaml
from diffusers import UNet2DModel
from einops import repeat
from typing import Any
from torch import Tensor


def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Return a boolean tensor of `shape` whose entries are True with probability `proba`."""
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


class DiffVC(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # 2D UNet backbone operating on mel-spectrograms; its input channels carry
        # the noisy target plus the frame-level conditioning concatenated in forward().
        self.unet = UNet2DModel(**self.config['unet'])
        self.unet.set_use_memory_efficient_attention_xformers(True)  # requires xformers

        # Projects the speaker vector to the width expected by the UNet's class conditioning.
        self.speaker_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['speaker_dim'], self.config['cls_embedding']['feature_dim']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['feature_dim'], self.config['cls_embedding']['feature_dim']))
        # Learned "unconditional" speaker vector used when the speaker condition is dropped
        # for classifier-free guidance.
        self.uncond = nn.Parameter(torch.randn(self.config['cls_embedding']['speaker_dim']) /
                                   self.config['cls_embedding']['speaker_dim'] ** 0.5)

        # Projects frame-level content features before concatenation with the noisy target.
        self.content_embedding = nn.Sequential(
            nn.Linear(self.config['cls_embedding']['content_dim'], self.config['cls_embedding']['content_hidden']),
            nn.SiLU(),
            nn.Linear(self.config['cls_embedding']['content_hidden'], self.config['cls_embedding']['content_hidden']))

        if self.config['cls_embedding']['use_pitch']:
            self.pitch_control = True
            # Projects frame-level pitch features; `pitch_uncond` is the learned
            # unconditional counterpart used when pitch conditioning is dropped.
            self.pitch_embedding = nn.Sequential(
                nn.Linear(self.config['cls_embedding']['pitch_dim'], self.config['cls_embedding']['pitch_hidden']),
                nn.SiLU(),
                nn.Linear(self.config['cls_embedding']['pitch_hidden'],
                          self.config['cls_embedding']['pitch_hidden']))
            self.pitch_uncond = nn.Parameter(torch.randn(self.config['cls_embedding']['pitch_hidden']) /
                                             self.config['cls_embedding']['pitch_hidden'] ** 0.5)
        else:
            print('no pitch module')
            self.pitch_control = False
    def forward(self, target, t, content, speaker, pitch,
                train_cfg=False, speaker_cfg=0.0, pitch_cfg=0.0):
        # target: (B, C, M, L) noisy mel, content: (B, L, content_dim),
        # speaker: (B, speaker_dim), pitch: (B, L, pitch_dim) or None.
        B, C, M, L = target.shape

        # Embed content and broadcast it over the mel-bin axis so it can be
        # concatenated with the target along the channel dimension.
        content = self.content_embedding(content)
        content = repeat(content, "b t c-> b c m t", m=M)
        target = target.to(content.dtype)
        x = torch.cat([target, content], dim=1)

        if self.pitch_control:
            if pitch is not None:
                pitch = self.pitch_embedding(pitch)
            else:
                # No pitch provided: fall back to the learned unconditional pitch vector.
                pitch = repeat(self.pitch_uncond, "c-> b t c", b=B, t=L).to(target.dtype)

        if train_cfg:
            # Classifier-free guidance training: with probability `speaker_cfg`
            # (and `pitch_cfg`), replace the condition with its unconditional embedding.
            uncond = repeat(self.uncond, "c-> b c", b=B).to(target.dtype)
            batch_mask = rand_bool(shape=(B, 1), proba=speaker_cfg, device=target.device)
            speaker = torch.where(batch_mask, uncond, speaker)
            if self.pitch_control:
                batch_mask = rand_bool(shape=(B, 1, 1), proba=pitch_cfg, device=target.device)
                pitch_uncond = repeat(self.pitch_uncond, "c-> b t c", b=B, t=L).to(target.dtype)
                pitch = torch.where(batch_mask, pitch_uncond, pitch)

        speaker = self.speaker_embedding(speaker)

        if self.pitch_control:
            pitch = repeat(pitch, "b t c-> b c m t", m=M)
            x = torch.cat([x, pitch], dim=1)

        # The speaker embedding enters the UNet through its class-conditioning path.
        output = self.unet(sample=x, timestep=t, class_labels=speaker)['sample']
        return output
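

# Not part of the original model: a minimal sketch of how the unconditional
# branches above could be combined for classifier-free guidance at sampling time.
# `cfg_predict` and `guidance_scale` are illustrative assumptions; the actual
# sampling loop (noise schedule, denoising steps) lives outside this file.
@torch.no_grad()
def cfg_predict(model, x_t, t, content, speaker, pitch=None, guidance_scale=2.0):
    B = x_t.shape[0]
    # Conditional pass with the real speaker (and optional pitch) conditioning.
    cond = model(x_t, t, content, speaker, pitch)
    # Unconditional pass: reuse the learned "no speaker" vector; passing pitch=None
    # falls back to pitch_uncond inside forward() when pitch control is enabled.
    uncond_speaker = repeat(model.uncond, "c -> b c", b=B).to(x_t.dtype)
    uncond = model(x_t, t, content, uncond_speaker, None)
    # Standard CFG mix of the two UNet predictions.
    return uncond + guidance_scale * (cond - uncond)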


if __name__ == "__main__":
    with open('diffvc_base_pitch.yaml', 'r') as fp:
        config = yaml.safe_load(fp)

    device = 'cuda'
    model = DiffVC(config['diffwrap']).to(device)

    # Smoke test: batch of 2, 1-channel mels with 100 bins and 256 frames,
    # 768-dim content features, one pitch value per frame, 256-dim speaker vector.
    x = torch.rand((2, 1, 100, 256)).to(device)
    y = torch.rand((2, 256, 768)).to(device)
    p = torch.rand(2, 256, 1).to(device)
    t = torch.randint(0, 1000, (2,)).long().to(device)
    spk = torch.rand(2, 256).to(device)

    output = model(x, t, y, spk, pitch=p, train_cfg=True, speaker_cfg=0.25, pitch_cfg=0.25)
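

# A hypothetical example of the config structure this module reads. It is an
# assumption reconstructed from the tensor shapes in the smoke test above, not
# the actual contents of diffvc_base_pitch.yaml. The 'unet' block is forwarded
# verbatim to diffusers.UNet2DModel, whose in_channels must equal
# 1 + content_hidden + pitch_hidden to match the channel concatenation in forward().
#
# diffwrap:
#   unet:
#     ...                  # UNet2DModel kwargs (sample size, block widths, ...), omitted here
#   cls_embedding:
#     speaker_dim: 256     # matches the (2, 256) speaker vector above
#     feature_dim: 512     # assumed; must match the UNet's class-embedding width
#     content_dim: 768     # matches the (2, 256, 768) content features above
#     content_hidden: 256  # assumed
#     use_pitch: true
#     pitch_dim: 1         # matches the (2, 256, 1) pitch track above
#     pitch_hidden: 128    # assumed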