MagicLuke commited on
Commit
eec5d10
·
verified ·
1 Parent(s): 9e1ee6d

Upload SpeakerEncoder

Browse files
Files changed (3) hide show
  1. config.json +3 -3
  2. model.safetensors +2 -2
  3. modeling_ecapa_tdnn.py +6 -94
config.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
  "C": 1024,
3
  "architectures": [
4
- "HFECAPATDNN"
5
  ],
6
  "auto_map": {
7
- "AutoConfig": "configuration_ecapa_tdnn.ECAPAConfig",
8
- "AutoModel": "modeling_ecapa_tdnn.HFECAPATDNN"
9
  },
10
  "model_type": "ecapa_tdnn",
11
  "torch_dtype": "float32",
 
1
  {
2
  "C": 1024,
3
  "architectures": [
4
+ "SpeakerEncoder"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "modeling_ecapa_tdnn.ECAPAConfig",
8
+ "AutoModel": "modeling_ecapa_tdnn.SpeakerEncoder"
9
  },
10
  "model_type": "ecapa_tdnn",
11
  "torch_dtype": "float32",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5e8babea99d09e708dadae8623a538b406825d7b5af527584cee204b957785d0
3
- size 66667584
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99a87fdb4f4b9608940134f211d1d61f64107667bfad2003948da449a1902197
3
+ size 65020192
modeling_ecapa_tdnn.py CHANGED
@@ -78,85 +78,11 @@ class Bottle2neck(nn.Module):
78
  out += residual
79
  return out
80
 
81
- class PreEmphasis(torch.nn.Module):
82
-
83
- def __init__(self, coef: float = 0.97):
84
- super().__init__()
85
- self.coef = coef
86
- self.register_buffer(
87
- 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
88
- )
89
-
90
- def forward(self, input: torch.tensor) -> torch.tensor:
91
- input = input.unsqueeze(1)
92
- input = F.pad(input, (1, 0), 'reflect')
93
- return F.conv1d(input, self.flipped_filter).squeeze(1)
94
-
95
- class FbankAug(nn.Module):
96
-
97
- def __init__(self, freq_mask_width = (0, 8), time_mask_width = (0, 10)):
98
- self.time_mask_width = time_mask_width
99
- self.freq_mask_width = freq_mask_width
100
- super().__init__()
101
-
102
- def mask_along_axis(self, x, dim):
103
- original_size = x.shape
104
- batch, fea, time = x.shape
105
- if dim == 1:
106
- D = fea
107
- width_range = self.freq_mask_width
108
- else:
109
- D = time
110
- width_range = self.time_mask_width
111
-
112
- mask_len = torch.randint(width_range[0], width_range[1], (batch, 1), device=x.device).unsqueeze(2)
113
- mask_pos = torch.randint(0, max(1, D - mask_len.max()), (batch, 1), device=x.device).unsqueeze(2)
114
- arange = torch.arange(D, device=x.device).view(1, 1, -1)
115
- mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
116
- mask = mask.any(dim=1)
117
-
118
- if dim == 1:
119
- mask = mask.unsqueeze(2)
120
- else:
121
- mask = mask.unsqueeze(1)
122
-
123
- x = x.masked_fill_(mask, 0.0)
124
- return x.view(*original_size)
125
-
126
- def forward(self, x):
127
- x = self.mask_along_axis(x, dim=2)
128
- x = self.mask_along_axis(x, dim=1)
129
- return x
130
-
131
- class ECAPA_TDNN(nn.Module):
132
 
133
  def __init__(self, C):
134
 
135
- super(ECAPA_TDNN, self).__init__()
136
-
137
- self.torchfbank = torch.nn.Sequential(
138
- PreEmphasis(),
139
- # torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, \
140
- # f_min = 20, f_max = 7600, window_fn=torch.hamming_window, n_mels=80),
141
- torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050),
142
- torchaudio.transforms.MelSpectrogram(
143
- sample_rate = 22050,
144
- n_fft = 2048,
145
- hop_length = 512,
146
- win_length = 2048,
147
- # window_fn = lambda *_: window,
148
- center = False,
149
- power = 2.0,
150
- n_mels = 256,
151
- norm = "slaney",
152
- mel_scale = "htk",
153
- ),
154
- torchaudio.transforms.AmplitudeToDB(
155
- stype="power", top_db=80
156
- )
157
- )
158
-
159
- self.specaug = FbankAug() # Spec augmentation
160
 
161
  # self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
162
  # self.conv1 = nn.Conv1d(256, C, kernel_size=5, stride=1, padding=2)
@@ -181,19 +107,7 @@ class ECAPA_TDNN(nn.Module):
181
  self.bn6 = nn.BatchNorm1d(192)
182
 
183
 
184
- def forward(self, x, aug):
185
- with torch.no_grad():
186
- x = self.torchfbank(x)
187
- # x = self.torchfbank(x)+1e-6
188
- # x = x.log()
189
- x = x - torch.mean(x, dim=-1, keepdim=True) # mean normalization
190
- if aug == True:
191
- x = self.specaug(x)
192
- # only take the first 232 mel bins
193
- if x.dim() == 3:
194
- x = x[:, :232, :]
195
- else:
196
- x = x[:232]
197
 
198
  x = self.conv1(x)
199
  x = self.relu(x)
@@ -224,9 +138,7 @@ class ECAPA_TDNN(nn.Module):
224
 
225
 
226
  import torch
227
- from transformers import PreTrainedModel
228
- # from configuration_ecapa_tdnn import ECAPAConfig
229
- from transformers import PretrainedConfig
230
 
231
 
232
  class ECAPAConfig(PretrainedConfig):
@@ -238,11 +150,11 @@ class ECAPAConfig(PretrainedConfig):
238
 
239
 
240
 
241
- class HFECAPATDNN(PreTrainedModel):
242
  config_class = ECAPAConfig
243
  base_model_prefix = "ecapa_tdnn"
244
  def __init__(self, config):
245
  super().__init__(config)
246
- self.model = ECAPA_TDNN(C=config.C)
247
  def forward(self, *args, **kwargs):
248
  return self.model(*args, **kwargs)
 
78
  out += residual
79
  return out
80
 
81
+ class EcapaTdnnEncoder(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def __init__(self, C):
84
 
85
+ super(EcapaTdnnEncoder, self).__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
88
  # self.conv1 = nn.Conv1d(256, C, kernel_size=5, stride=1, padding=2)
 
107
  self.bn6 = nn.BatchNorm1d(192)
108
 
109
 
110
+ def forward(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  x = self.conv1(x)
113
  x = self.relu(x)
 
138
 
139
 
140
  import torch
141
+ from transformers import PreTrainedModel, PretrainedConfig
 
 
142
 
143
 
144
  class ECAPAConfig(PretrainedConfig):
 
150
 
151
 
152
 
153
+ class SpeakerEncoder(PreTrainedModel):
154
  config_class = ECAPAConfig
155
  base_model_prefix = "ecapa_tdnn"
156
  def __init__(self, config):
157
  super().__init__(config)
158
+ self.model = EcapaTdnnEncoder(C=config.C)
159
  def forward(self, *args, **kwargs):
160
  return self.model(*args, **kwargs)