import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import random
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WavLMModel
from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
import torch.optim as optim
import itertools

print("All libraries imported successfully.")


class AFM(nn.Module):
    """Attentive fusion module: learns a per-channel, per-frame gate that
    blends two feature streams of equal shape (batch, channels, time)."""

    def __init__(self, channels):
        super(AFM, self).__init__()
        self.conv1 = nn.Conv1d(2 * channels, channels, 1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.silu = nn.SiLU()
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x, y):
        out = torch.cat((x, y), dim=1)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.silu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = torch.tanh(out)
        # Map the tanh output from [-1, 1] to a gate in [0, 1], then blend.
        gate = (out + 1) / 2
        return x * (1 - gate) + y * gate


class FFPTM(nn.Module):
    """Feature-fusion front-end over a pre-trained model (WavLM): fuses a
    weighted sum of the shallow layers with a weighted sum of the deep layers,
    and exposes auxiliary logits computed from the last hidden layer."""

    def __init__(self, ptm_name="microsoft/wavlm-base-plus"):
        super(FFPTM, self).__init__()
        self.ptm = WavLMModel.from_pretrained(ptm_name)
        self.num_layers = self.ptm.config.num_hidden_layers
        self.feature_dim = self.ptm.config.hidden_size
        # One learnable weight per hidden state (embedding output + each transformer layer).
        self.layer_weights = nn.Parameter(torch.ones(self.num_layers + 1))
        self.fusion_module = AFM(self.feature_dim)
        self.aux_pooling = nn.AdaptiveAvgPool1d(1)
        # 40-way auxiliary head (LibriSpeech dev-clean, used below, has 40 speakers).
        self.aux_fc = nn.Linear(self.feature_dim, 40)

    def forward(self, waveform):
        outputs = self.ptm(waveform, output_hidden_states=True)
        hidden_states = outputs.hidden_states  # tuple of num_layers + 1 tensors
        normalized_weights = F.softmax(self.layer_weights, dim=-1)
        n_split = (self.num_layers + 1) // 2
        # Weighted sums over the shallow and deep halves of the layer stack,
        # transposed to (batch, channels, time) for the Conv1d-based AFM.
        shallow = sum(normalized_weights[i] * hidden_states[i] for i in range(n_split)).transpose(1, 2)
        deep = sum(
            normalized_weights[i] * hidden_states[i] for i in range(n_split, self.num_layers + 1)
        ).transpose(1, 2)
        fused_ptm_features = self.fusion_module(shallow, deep)
        last_layer = hidden_states[-1].transpose(1, 2)
        aux_logits = self.aux_fc(self.aux_pooling(last_layer).squeeze(-1))
        return fused_ptm_features, aux_logits


class DBE(nn.Module):
    """Dual-branch ECAPA back-end: a main branch over the fused PTM features
    and an auxiliary branch over Fbank features, fused block by block."""

    def __init__(self, ptm_feature_dim=768):
        super(DBE, self).__init__()
        ecapa_channels = 512
        k = [5, 3, 3, 3]
        d = [1, 2, 3, 4]
        c = [ecapa_channels] * 4
        self.main_branch_input_conv = nn.Conv1d(ptm_feature_dim, ecapa_channels, 1)
        # Instantiate ECAPA-TDNNs only to reuse their TDNN/SE-Res2Net blocks.
        temp_ecapa = ECAPA_TDNN(input_size=ecapa_channels, channels=c, kernel_sizes=k, dilations=d)
        self.main_blocks = temp_ecapa.blocks
        self.fbank_extractor = Fbank(n_mels=80)
        # stride=2 halves the 10 ms Fbank frame rate to roughly match WavLM's 20 ms frames.
        self.fbank_align_conv = nn.Conv1d(80, ecapa_channels, 1, stride=2)
        temp_ecapa_aux = ECAPA_TDNN(input_size=ecapa_channels, channels=c, kernel_sizes=k, dilations=d)
        self.aux_branch_blocks = temp_ecapa_aux.blocks
        self.fusion_modules = nn.ModuleList([AFM(ecapa_channels) for _ in range(len(self.main_blocks))])
        # Multi-layer feature aggregation over the last three main-branch outputs.
        self.mfa = nn.Conv1d(3 * ecapa_channels, ecapa_channels, 1)
        pooled_dim = ecapa_channels * 2  # mean + std statistics pooling
        self.bn = nn.BatchNorm1d(pooled_dim)
        self.fc = nn.Linear(pooled_dim, 256)

    def forward(self, ptm_features, waveform):
        # Auxiliary branch: handcrafted Fbank features through ECAPA blocks.
        fbank = self.fbank_extractor(waveform).permute(0, 2, 1)
        aligned_fbank = self.fbank_align_conv(fbank)
        aux_outs = []
        x_aux = aligned_fbank
        for block in self.aux_branch_blocks:
            x_aux = block(x_aux)
            aux_outs.append(x_aux)
        # Main branch: PTM features, fused with the aux branch before each block.
        x_main = self.main_branch_input_conv(ptm_features)
        # Crop both branches to a common number of frames before fusing.
        target_len = min(x_main.shape[2], aux_outs[0].shape[2])
        main_in = x_main[:, :, :target_len]
        main_outs = []
        for i, (block, fusion) in enumerate(zip(self.main_blocks, self.fusion_modules)):
            fused = fusion(main_in, aux_outs[i][:, :, :target_len])
            main_in = block(fused)
            main_outs.append(main_in)
        cat_feats = torch.cat(main_outs[-3:], dim=1)
        mfa_feats = self.mfa(cat_feats)
        # Statistics pooling: per-channel mean and std over time.
        mean = mfa_feats.mean(dim=2)
        std = mfa_feats.std(dim=2)
        pooled_feats = torch.cat((mean, std), dim=1)
        norm_feats = self.bn(pooled_feats)
        embedding = self.fc(norm_feats)
        return embedding


class FullModel(nn.Module):
    """FFPTM front-end + DBE back-end; returns the speaker embedding and the
    front-end's auxiliary logits."""

    def __init__(self):
        super(FullModel, self).__init__()
        self.frontend = FFPTM()
        self.backend = DBE()

    def forward(self, waveform):
        fused_ptm, aux_logits = self.frontend(waveform)
        embedding = self.backend(fused_ptm, waveform)
        return embedding, aux_logits


class SpeakerDataset(Dataset):
    """Fixed-length, speaker-labelled segments from LibriSpeech dev-clean."""

    def __init__(self, data_folder, sample_rate=16000, seg_len_sec=3):
        super().__init__()
        self.data_folder = data_folder
        os.makedirs(self.data_folder, exist_ok=True)
        self.dataset = torchaudio.datasets.LIBRISPEECH(self.data_folder, "dev-clean", download=True)
        # NOTE: iterating the dataset decodes every file, so make a single pass
        # to collect the (speaker, chapter, utterance) ids.
        metadata = [(sid, cid, uid) for _, _, _, sid, cid, uid in self.dataset]
        # Map LibriSpeech speaker ids to contiguous class indices.
        s_ids = sorted({sid for sid, _, _ in metadata})
        self.s_map = {sid: i for i, sid in enumerate(s_ids)}
        self.speaker_ids = s_ids
        self.sr = sample_rate
        self.seg_len = seg_len_sec * self.sr
        self.files = []
        for sid, cid, uid in metadata:
            fname = f"{sid}-{cid}-{uid:04d}.flac"
            path = os.path.join(self.data_folder, "LibriSpeech", "dev-clean", str(sid), str(cid), fname)
            self.files.append((path, self.s_map[sid]))

    def __len__(self):
        return len(self.files)

    def _load_and_process(self, file_path):
        waveform, sr = torchaudio.load(file_path)
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(sr, self.sr)(waveform)
        if waveform.shape[0] > 1:
            # Down-mix multi-channel audio to mono.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if waveform.shape[1] < self.seg_len:
            # Zero-pad short utterances up to the fixed segment length.
            pad_amount = self.seg_len - waveform.shape[1]
            waveform = F.pad(waveform, (0, pad_amount))
        else:
            # Randomly crop a fixed-length segment from longer utterances.
            start_point = random.randint(0, waveform.shape[1] - self.seg_len)
            waveform = waveform[:, start_point:start_point + self.seg_len]
        return waveform.squeeze(0)

    def __getitem__(self, index):
        file_path, label = self.files[index]
        waveform = self._load_and_process(file_path)
        return waveform, label


class AAMSoftmax(nn.Module):
    """Additive angular margin softmax (ArcFace-style) head that returns
    scaled logits; apply cross-entropy to the output."""

    def __init__(self, idim, n_cl, m=0.2, s=32):
        super(AAMSoftmax, self).__init__()
        self.w = nn.Parameter(torch.FloatTensor(n_cl, idim))
        nn.init.xavier_uniform_(self.w)
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, lbl):
        # Cosine similarity between normalized embeddings and class weights.
        cos = F.linear(F.normalize(x), F.normalize(self.w))
        # Clamp *before* the square root to avoid NaNs from negative values.
        sin = torch.sqrt((1.0 - torch.pow(cos, 2)).clamp(0, 1))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cos * self.cos_m - sin * self.sin_m
        # Fall back to a linear penalty when theta + m would exceed pi.
        phi = torch.where(cos > self.th, phi, cos - self.mm)
        hot = torch.zeros(cos.size(), device=x.device)
        hot.scatter_(1, lbl.view(-1, 1).long(), 1)
        # Apply the margin only to the target class, then scale.
        return ((hot * phi) + ((1.0 - hot) * cos)) * self.s


print("All classes defined successfully.")
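# --- Minimal training sketch (illustrative only; not part of the model code
# above). It shows one way to wire FullModel, SpeakerDataset, and AAMSoftmax
# together for a single optimization step, which also explains the otherwise
# unused optim and itertools imports. Assumptions: the 40-way auxiliary head
# is treated as a speaker classifier (dev-clean has exactly 40 speakers), and
# the learning rate, batch size, and 0.1 auxiliary-loss weight are placeholder
# choices, not tuned values.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset = SpeakerDataset("./data")
    loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
    model = FullModel().to(device)
    aam_head = AAMSoftmax(idim=256, n_cl=len(dataset.speaker_ids)).to(device)
    criterion = nn.CrossEntropyLoss()
    # itertools.chain lets one optimizer update both the model and the AAM head.
    optimizer = optim.Adam(itertools.chain(model.parameters(), aam_head.parameters()), lr=1e-4)
    model.train()
    for waveforms, labels in loader:
        waveforms, labels = waveforms.to(device), labels.to(device)
        embeddings, aux_logits = model(waveforms)
        # Main AAM-softmax loss on the embedding plus a weighted auxiliary loss.
        loss = criterion(aam_head(embeddings, labels), labels) + 0.1 * criterion(aux_logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Smoke-test step complete, loss: {loss.item():.4f}")
        break  # one batch only; remove to train for real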