Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| import os | |
| import random | |
| import torchaudio | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import WavLMModel | |
| from speechbrain.lobes.features import Fbank | |
| from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN | |
| import torch.optim as optim | |
| import itertools | |
# Progress marker: reached only after every import statement above has run.
| print("All libraries imported successfully.") | |
class AFM(nn.Module):
    """Gated fusion of two feature maps that share the same shape.

    A per-element gate is predicted from the channel-wise concatenation of
    both inputs (1x1 conv -> batch norm -> SiLU -> 1x1 conv -> batch norm ->
    tanh) and then used to blend the inputs.
    """

    def __init__(self, channels):
        """Create the gate network.

        Args:
            channels: Number of channels of *each* input; the first
                convolution therefore receives ``2 * channels`` channels.
        """
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.conv1 = nn.Conv1d(2 * channels, channels, 1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.silu = nn.SiLU()
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x, y):
        """Blend ``x`` and ``y`` with a learned gate.

        Args:
            x: First input; concatenated with ``y`` along dim 1.
            y: Second input, same shape as ``x``.

        Returns:
            ``x * (1 - g) + y * g`` where ``g`` is the tanh output rescaled
            from [-1, 1] to [0, 1].
        """
        stacked = torch.cat((x, y), dim=1)
        hidden = self.silu(self.bn1(self.conv1(stacked)))
        gate_raw = torch.tanh(self.bn2(self.conv2(hidden)))
        # Map tanh's [-1, 1] range onto [0, 1] so it acts as a mixing weight.
        weight_y = (gate_raw + 1) / 2
        return x * (1 - weight_y) + y * weight_y
class FFPTM(nn.Module):
    """Front-end that fuses the hidden layers of a pre-trained WavLM model.

    The hidden states returned by the model are combined with learnable,
    softmax-normalised weights into a "shallow" and a "deep" summary, which
    are then merged by an ``AFM`` module. An auxiliary linear head operates on
    the time-pooled last hidden state.
    """

    def __init__(self, ptm_name="microsoft/wavlm-base-plus"):
        """Load the pre-trained model and build the fusion/auxiliary heads.

        Args:
            ptm_name: Identifier passed to ``WavLMModel.from_pretrained``.
        """
        super().__init__()
        self.ptm = WavLMModel.from_pretrained(ptm_name)
        config = self.ptm.config
        self.num_layers = config.num_hidden_layers
        self.feature_dim = config.hidden_size
        # One weight per hidden state that forward() reads (indices
        # 0 .. num_layers inclusive).
        self.layer_weights = nn.Parameter(torch.ones(self.num_layers + 1))
        self.fusion_module = AFM(self.feature_dim)
        self.aux_pooling = nn.AdaptiveAvgPool1d(1)
        # The auxiliary head has a fixed output size of 40.
        self.aux_fc = nn.Linear(self.feature_dim, 40)

    def forward(self, waveform):
        """Return the fused hidden-state features and the auxiliary logits.

        Args:
            waveform: Input forwarded unchanged to the pre-trained model.

        Returns:
            Tuple ``(fused_ptm_features, aux_logits)``.
        """
        hidden_states = self.ptm(waveform, output_hidden_states=True).hidden_states
        weights = F.softmax(self.layer_weights, dim=-1)
        total = self.num_layers + 1
        boundary = total // 2

        def blend(first, last):
            # Weighted sum of hidden states first..last-1, then swap dims 1
            # and 2 (presumably (batch, frames, features) ->
            # (batch, features, frames) for the Conv1d layers -- TODO confirm).
            mixed = sum(weights[idx] * hidden_states[idx] for idx in range(first, last))
            return mixed.transpose(1, 2)

        shallow = blend(0, boundary)
        deep = blend(boundary, total)
        fused = self.fusion_module(shallow, deep)

        # Auxiliary branch: average the last hidden state over its final
        # axis, then classify.
        final_layer = hidden_states[-1].transpose(1, 2)
        pooled = self.aux_pooling(final_layer).squeeze(-1)
        return fused, self.aux_fc(pooled)
class DBE(nn.Module):
    """Dual-branch back-end that turns PTM features plus filterbanks into an embedding.

    The auxiliary branch runs filterbank features through a stack of blocks
    borrowed from an ``ECAPA_TDNN`` instance. The main branch runs the
    projected PTM features through a second such stack, fusing in the
    corresponding auxiliary output (via ``AFM``) before every block. The last
    main-branch outputs are concatenated, reduced by a 1x1 convolution,
    mean/std pooled over dim 2 and projected to a 256-dimensional vector.
    """

    def __init__(self, ptm_feature_dim=768):
        """Build both branches.

        Args:
            ptm_feature_dim: Channel count of the PTM features given to
                ``forward``.
        """
        super().__init__()
        width = 512
        kernel_sizes = [5, 3, 3, 3]
        dilations = [1, 2, 3, 4]
        channel_plan = [width] * 4

        # Main branch: project PTM features to the working width, then reuse
        # only the ``blocks`` of a temporary ECAPA_TDNN.
        self.main_branch_input_conv = nn.Conv1d(ptm_feature_dim, width, 1)
        main_ecapa = ECAPA_TDNN(
            input_size=width,
            channels=channel_plan,
            kernel_sizes=kernel_sizes,
            dilations=dilations,
        )
        self.main_blocks = main_ecapa.blocks

        # Auxiliary branch: 80-bin filterbanks, aligned with a stride-2 1x1
        # convolution (presumably to match the PTM frame rate -- TODO confirm).
        self.fbank_extractor = Fbank(n_mels=80)
        self.fbank_align_conv = nn.Conv1d(80, width, 1, stride=2)
        aux_ecapa = ECAPA_TDNN(
            input_size=width,
            channels=channel_plan,
            kernel_sizes=kernel_sizes,
            dilations=dilations,
        )
        self.aux_branch_blocks = aux_ecapa.blocks

        # One fusion module per main-branch block.
        self.fusion_modules = nn.ModuleList(
            [AFM(width) for _ in range(len(self.main_blocks))]
        )
        # Aggregation conv takes three concatenated block outputs.
        self.mfa = nn.Conv1d(3 * width, width, 1)
        stats_dim = width * 2  # mean and std concatenated
        self.bn = nn.BatchNorm1d(stats_dim)
        self.fc = nn.Linear(stats_dim, 256)

    def forward(self, ptm_features, waveform):
        """Compute the embedding.

        Args:
            ptm_features: Features fed to the main branch; dim 1 must have
                ``ptm_feature_dim`` channels.
            waveform: Raw audio handed to the filterbank extractor.

        Returns:
            Output of the final linear layer (256 features per item).
        """
        # Auxiliary branch: keep the output of every block for later fusion.
        mel = self.fbank_extractor(waveform).permute(0, 2, 1)
        branch_state = self.fbank_align_conv(mel)
        aux_outs = []
        for aux_block in self.aux_branch_blocks:
            branch_state = aux_block(branch_state)
            aux_outs.append(branch_state)

        # Main branch: crop both branches to the shorter length along dim 2.
        projected = self.main_branch_input_conv(ptm_features)
        target_len = min(projected.shape[2], aux_outs[0].shape[2])
        current = projected[:, :, :target_len]
        main_outs = []
        for idx, (main_block, fuse) in enumerate(
            zip(self.main_blocks, self.fusion_modules)
        ):
            aux_feat = aux_outs[idx][:, :, :target_len]
            current = main_block(fuse(current, aux_feat))
            main_outs.append(current)

        # Aggregate the last three block outputs, then statistics pooling.
        aggregated = self.mfa(torch.cat(main_outs[-3:], dim=1))
        stats = torch.cat((aggregated.mean(dim=2), aggregated.std(dim=2)), dim=1)
        return self.fc(self.bn(stats))
class FullModel(nn.Module):
    """End-to-end model: ``FFPTM`` front-end followed by the ``DBE`` back-end."""

    def __init__(self):
        """Instantiate both stages with their default arguments."""
        super().__init__()
        self.frontend = FFPTM()
        self.backend = DBE()

    def forward(self, waveform):
        """Run the front-end, then the back-end on its fused features.

        Args:
            waveform: Input passed to the front-end and, unchanged, to the
                back-end as its second argument.

        Returns:
            Tuple ``(embedding, aux_logits)``.
        """
        fused_ptm, aux_logits = self.frontend(waveform)
        return self.backend(fused_ptm, waveform), aux_logits
class SpeakerDataset(Dataset):
    """LibriSpeech ``dev-clean`` utterances as fixed-length (waveform, label) pairs.

    Attributes (all set in ``__init__``):
        dataset: The underlying ``torchaudio.datasets.LIBRISPEECH`` object.
        speaker_ids: Sorted list of the distinct speaker ids found.
        s_map: Maps a speaker id to its index in ``speaker_ids``.
        files: List of ``(flac_path, label)`` tuples, one per utterance.
        sr: Target sample rate.
        seg_len: Segment length in samples (always an ``int``).
    """

    def __init__(self, data_folder, sample_rate=16000, seg_len_sec=3):
        """Download (if needed) and index the dataset.

        Args:
            data_folder: Root directory for the download; created if missing.
            sample_rate: Sample rate every returned waveform is resampled to.
            seg_len_sec: Segment duration in seconds; may be fractional.
        """
        super().__init__()
        self.data_folder = data_folder
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.data_folder, exist_ok=True)
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            self.data_folder, "dev-clean", download=True
        )
        self.sr = sample_rate
        # Coerce to int so slicing, padding and randint work even when
        # seg_len_sec is a float such as 2.5.
        self.seg_len = int(round(seg_len_sec * self.sr))

        # Walk the dataset exactly once: every item access loads the audio,
        # so a second pass just to rebuild the file list doubles start-up
        # cost. Fields 4-6 of each item are used as speaker, chapter and
        # utterance ids.
        utterances = [(sid, cid, uid) for _, _, _, sid, cid, uid in self.dataset]
        self.speaker_ids = sorted({sid for sid, _, _ in utterances})
        self.s_map = {sid: idx for idx, sid in enumerate(self.speaker_ids)}

        split_root = os.path.join(self.data_folder, "LibriSpeech", "dev-clean")
        self.files = [
            (
                os.path.join(
                    split_root, str(sid), str(cid), f"{sid}-{cid}-{uid:04d}.flac"
                ),
                self.s_map[sid],
            )
            for sid, cid, uid in utterances
        ]

    def __len__(self):
        """Return the number of indexed utterances."""
        return len(self.files)

    def _fit_to_segment(self, waveform):
        """Zero-pad (right) or randomly crop ``waveform`` along dim 1 to ``seg_len``."""
        n_samples = waveform.shape[1]
        if n_samples < self.seg_len:
            return F.pad(waveform, (0, self.seg_len - n_samples))
        start = random.randint(0, n_samples - self.seg_len)
        return waveform[:, start:start + self.seg_len]

    def _load_and_process(self, file_path):
        """Load one file and return a 1-D tensor of exactly ``seg_len`` samples.

        The audio is resampled to ``self.sr`` when its rate differs, averaged
        over dim 0 when it has more than one channel, then padded or cropped.
        """
        waveform, sr = torchaudio.load(file_path)
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(sr, self.sr)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return self._fit_to_segment(waveform).squeeze(0)

    def __getitem__(self, index):
        """Return ``(waveform, label)`` for the utterance at ``index``."""
        file_path, label = self.files[index]
        return self._load_and_process(file_path), label
class AAMSoftmax(nn.Module):
    """Additive angular margin head that produces scaled cosine logits.

    ``forward`` returns logits, not a loss: the target-class cosine is
    replaced by ``cos(theta + m)`` (with a linear fallback beyond the
    threshold ``th``), and every logit is multiplied by ``s``.
    """

    def __init__(self, idim, n_cl, m=0.2, s=32):
        """Create the class-weight matrix and pre-compute margin constants.

        Args:
            idim: Dimensionality of the input embeddings.
            n_cl: Number of classes.
            m: Angular margin in radians.
            s: Scale applied to the output logits.
        """
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor(n, d) constructor;
        # the values are overwritten by Xavier initialisation right away.
        self.w = nn.Parameter(torch.empty(n_cl, idim))
        nn.init.xavier_uniform_(self.w)
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, lbl):
        """Compute margin-adjusted, scaled logits.

        Args:
            x: Embeddings of shape ``(batch, idim)``.
            lbl: Integer class labels, one per embedding.

        Returns:
            Tensor of shape ``(batch, n_cl)``.
        """
        cos = F.linear(F.normalize(x), F.normalize(self.w))
        # Clamp BEFORE the square root: rounding can push cos**2 slightly
        # above 1, and sqrt of a negative number is NaN, which a later clamp
        # would not remove.
        sin = torch.sqrt((1.0 - torch.pow(cos, 2)).clamp(0, 1))
        phi = cos * self.cos_m - sin * self.sin_m
        # Beyond the threshold, fall back to a linear penalty.
        phi = torch.where(cos > self.th, phi, cos - self.mm)
        hot = torch.zeros(cos.size(), device=x.device)
        hot.scatter_(1, lbl.view(-1, 1).long(), 1)
        # Use the margin-adjusted value only in the target-class column.
        return ((hot * phi) + ((1.0 - hot) * cos)) * self.s
# Progress marker: reached only after every class definition above has executed.
| print("All classes defined successfully.") | |