# Header copied from the hosting site's file view:
# Randinu002
# Remove notebook_login and use secrets for auth
# c553417
# raw
# history blame
# 6.78 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import random
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WavLMModel
from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
import torch.optim as optim
import itertools
# Notebook-style progress marker: reaching this line means every import above succeeded.
print("All libraries imported successfully.")
class AFM(nn.Module):
    """Adaptive fusion module: a learned, per-element convex blend of two maps.

    Both inputs are concatenated on the channel axis and sent through
    conv(1x1) -> BN -> SiLU -> conv(1x1) -> BN -> tanh. The tanh output is
    mapped from (-1, 1) to a gate ``g`` in (0, 1), and the result is
    ``x * (1 - g) + y * g``.

    Args:
        channels: channel count C of each input. Inputs are channel-first
            (batch, C, length), the layout ``nn.Conv1d`` consumes.
    """

    def __init__(self, channels):
        super(AFM, self).__init__()
        # Attribute names are kept stable so existing checkpoints still load.
        self.conv1 = nn.Conv1d(2 * channels, channels, 1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.silu = nn.SiLU()
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x, y):
        hidden = self.silu(self.bn1(self.conv1(torch.cat((x, y), dim=1))))
        # Rescale tanh's (-1, 1) range into a (0, 1) mixing weight.
        gate = (torch.tanh(self.bn2(self.conv2(hidden))) + 1) / 2
        return x * (1 - gate) + y * gate
class FFPTM(nn.Module):
    """WavLM front-end that fuses shallow and deep hidden-state mixtures.

    Every hidden state returned by the pre-trained model with
    ``output_hidden_states=True`` (``num_layers + 1`` tensors) gets a learnable
    weight, and the weights are softmax-normalised together. The first half
    of the states is summed into a "shallow" mixture and the remainder into a
    "deep" mixture. An AFM gate blends the two. A separate auxiliary
    classifier runs on the time-averaged final hidden state.

    Args:
        ptm_name: model id or path given to ``WavLMModel.from_pretrained``.
        num_aux_classes: output size of the auxiliary classifier. The default
            of 40 keeps the previously hard-coded value (presumably the
            speaker count of the LibriSpeech dev-clean split -- TODO confirm).
    """

    def __init__(self, ptm_name="microsoft/wavlm-base-plus", num_aux_classes=40):
        super(FFPTM, self).__init__()
        self.ptm = WavLMModel.from_pretrained(ptm_name)
        self.num_layers = self.ptm.config.num_hidden_layers
        self.feature_dim = self.ptm.config.hidden_size
        # One weight per returned hidden state (num_layers + 1 of them).
        self.layer_weights = nn.Parameter(torch.ones(self.num_layers + 1))
        self.fusion_module = AFM(self.feature_dim)
        self.aux_pooling = nn.AdaptiveAvgPool1d(1)
        self.aux_fc = nn.Linear(self.feature_dim, num_aux_classes)

    def forward(self, waveform):
        """Run WavLM and return fused features plus auxiliary logits.

        Args:
            waveform: raw audio batch passed straight to the WavLM model
                (presumably (batch, samples) -- TODO confirm against callers).

        Returns:
            ``(fused_ptm_features, aux_logits)``: the AFM-fused mixture in
            channel-first layout (batch, feature_dim, frames), and logits of
            shape (batch, num_aux_classes) computed from the last hidden state.
        """
        outputs = self.ptm(waveform, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        normalized_weights = F.softmax(self.layer_weights, dim=-1)
        total = self.num_layers + 1
        n_split = total // 2
        # transpose(1, 2) puts the feature axis on dim 1, the channel axis the
        # AFM's Conv1d layers expect.
        shallow = sum(normalized_weights[i] * hidden_states[i] for i in range(n_split)).transpose(1, 2)
        deep = sum(normalized_weights[i] * hidden_states[i] for i in range(n_split, total)).transpose(1, 2)
        fused_ptm_features = self.fusion_module(shallow, deep)
        last_layer = hidden_states[-1].transpose(1, 2)
        # Average over time, drop the singleton length axis, then classify.
        aux_logits = self.aux_fc(self.aux_pooling(last_layer).squeeze(-1))
        return fused_ptm_features, aux_logits
class DBE(nn.Module):
    """Dual-branch ECAPA-style back-end producing a 256-d embedding.

    Main branch: pre-trained-model features are projected to 512 channels by
    a 1x1 conv, then passed through SpeechBrain ECAPA_TDNN blocks.

    Auxiliary branch: an 80-band filterbank of the raw waveform is projected
    to 512 channels by a 1x1 conv with stride 2, then passed through its own
    copy of the ECAPA blocks.

    Before each main-branch block, an AFM gate blends the current main input
    with the output of the auxiliary block at the same index. The last three
    main-branch outputs are concatenated and mixed by a 1x1 conv (``mfa``).
    The result is pooled to per-channel mean and std over time,
    batch-normalised, and projected to 256 dimensions.

    Args:
        ptm_feature_dim: channel count of ``ptm_features``.
    """

    def __init__(self, ptm_feature_dim=768):
        super(DBE, self).__init__()
        width = 512
        # Submodules are created and registered in a fixed order. Keep it so
        # seeded parameter initialisation and state_dict keys stay stable.
        self.main_branch_input_conv = nn.Conv1d(ptm_feature_dim, width, 1)
        self.main_blocks = self._ecapa_blocks(width)
        self.fbank_extractor = Fbank(n_mels=80)
        self.fbank_align_conv = nn.Conv1d(80, width, 1, stride=2)
        self.aux_branch_blocks = self._ecapa_blocks(width)
        self.fusion_modules = nn.ModuleList(AFM(width) for _ in self.main_blocks)
        # Sized for exactly three width-channel stage outputs concatenated in
        # forward(); presumes main_blocks yields at least three stages for this
        # channel config -- TODO confirm against SpeechBrain's ECAPA_TDNN.
        self.mfa = nn.Conv1d(3 * width, width, 1)
        # Mean and std are concatenated, doubling the width.
        self.bn = nn.BatchNorm1d(2 * width)
        self.fc = nn.Linear(2 * width, 256)

    @staticmethod
    def _ecapa_blocks(width):
        """Build a SpeechBrain ECAPA_TDNN and return only its ``blocks``."""
        ecapa = ECAPA_TDNN(
            input_size=width,
            channels=[width] * 4,
            kernel_sizes=[5, 3, 3, 3],
            dilations=[1, 2, 3, 4],
        )
        return ecapa.blocks

    def forward(self, ptm_features, waveform):
        """Compute the embedding from PTM features and the raw waveform.

        Args:
            ptm_features: channel-first features (batch, ptm_feature_dim, frames).
            waveform: raw audio given to the filterbank extractor
                (presumably (batch, samples) -- TODO confirm).

        Returns:
            Tensor of shape (batch, 256).
        """
        # Filterbank output is moved to channel-first layout before projection.
        aux = self.fbank_align_conv(self.fbank_extractor(waveform).permute(0, 2, 1))
        aux_stages = []
        for aux_block in self.aux_branch_blocks:
            aux = aux_block(aux)
            aux_stages.append(aux)

        main = self.main_branch_input_conv(ptm_features)
        # The two branches may differ in frame count. Cut the main features
        # (and each aux stage below) to the shorter length.
        n_frames = min(main.shape[2], aux_stages[0].shape[2])
        main = main[:, :, :n_frames]
        main_stages = []
        for stage, (block, fuse) in enumerate(zip(self.main_blocks, self.fusion_modules)):
            main = block(fuse(main, aux_stages[stage][:, :, :n_frames]))
            main_stages.append(main)

        mixed = self.mfa(torch.cat(main_stages[-3:], dim=1))
        pooled = torch.cat((mixed.mean(dim=2), mixed.std(dim=2)), dim=1)
        return self.fc(self.bn(pooled))
class FullModel(nn.Module):
    """FFPTM front-end followed by the DBE back-end.

    ``forward`` returns ``(embedding, aux_logits)``: the back-end embedding
    and the front-end's auxiliary logits.

    Args:
        ptm_name: pre-trained model id forwarded to FFPTM. The default is the
            same one FFPTM uses on its own.
    """

    def __init__(self, ptm_name="microsoft/wavlm-base-plus"):
        super(FullModel, self).__init__()
        self.frontend = FFPTM(ptm_name)
        # Size the back-end's input projection from the loaded model rather
        # than DBE's hard-coded 768 default. A checkpoint whose hidden size is
        # not 768 then still lines up with the back-end.
        self.backend = DBE(ptm_feature_dim=self.frontend.feature_dim)

    def forward(self, waveform):
        fused_ptm, aux_logits = self.frontend(waveform)
        # The back-end reads the raw waveform too (for its filterbank branch).
        embedding = self.backend(fused_ptm, waveform)
        return embedding, aux_logits
class SpeakerDataset(Dataset):
    """LibriSpeech ``dev-clean`` utterances as fixed-length, speaker-labelled crops.

    On construction, the split is downloaded (if needed) into ``data_folder``
    through ``torchaudio.datasets.LIBRISPEECH`` and then indexed. Each item is
    a 1-D waveform of exactly ``seg_len`` samples plus an integer label in
    ``[0, len(speaker_ids))``.

    Args:
        data_folder: root directory for the download; created if missing.
        sample_rate: target rate; audio at any other rate is resampled.
        seg_len_sec: crop length in seconds (may be fractional). Shorter clips
            are zero-padded at the end; longer ones are randomly cropped.
    """

    def __init__(self, data_folder, sample_rate=16000, seg_len_sec=3):
        super().__init__()
        self.data_folder = data_folder
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.data_folder, exist_ok=True)
        self.dataset = torchaudio.datasets.LIBRISPEECH(self.data_folder, "dev-clean", download=True)
        # Index from the files on disk. Iterating self.dataset would decode
        # every utterance's audio merely to read its speaker id.
        split_dir = os.path.join(self.data_folder, "LibriSpeech", "dev-clean")
        entries = self._scan_split(split_dir)
        s_ids = sorted({sid for _, sid in entries})
        self.s_map = {sid: i for i, sid in enumerate(s_ids)}
        self.speaker_ids = s_ids
        self.files = [(path, self.s_map[sid]) for path, sid in entries]
        self.sr = sample_rate
        # Always an int, so fractional seconds still give a valid sample count
        # for F.pad, random.randint and slicing.
        self.seg_len = int(round(seg_len_sec * self.sr))

    @staticmethod
    def _scan_split(split_dir):
        """Return ``(path, speaker_id)`` for each ``<spk>/<chapter>/<utt>.flac`` under *split_dir*.

        Entries are sorted by file stem (``"<spk>-<chapter>-<utt>"``). This is
        intended to reproduce the index order of torchaudio's LIBRISPEECH
        dataset, which sorts its file list by stem in current releases
        (verify on upgrade). The speaker id is the integer before the first
        ``-`` of the stem. A missing directory yields an empty list.
        """
        import glob

        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        # escape() so glob metacharacters in the root path are taken literally.
        pattern = os.path.join(glob.escape(split_dir), "*", "*", "*.flac")
        paths = sorted(glob.glob(pattern), key=stem)
        return [(path, int(stem(path).split("-", 1)[0])) for path in paths]

    def __len__(self):
        return len(self.files)

    def _load_and_process(self, file_path):
        """Load one file as a mono 1-D tensor of exactly ``self.seg_len`` samples.

        Resamples to ``self.sr`` when needed, averages channels to mono,
        zero-pads short clips at the end, and takes a uniformly random crop
        of long ones.
        """
        waveform, sr = torchaudio.load(file_path)
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(sr, self.sr)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if waveform.shape[1] < self.seg_len:
            waveform = F.pad(waveform, (0, self.seg_len - waveform.shape[1]))
        else:
            start = random.randint(0, waveform.shape[1] - self.seg_len)
            waveform = waveform[:, start:start + self.seg_len]
        return waveform.squeeze(0)

    def __getitem__(self, index):
        """Return ``(waveform, label)`` for item *index*."""
        file_path, label = self.files[index]
        return self._load_and_process(file_path), label
class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace-style) logit layer.

    Computes cosine similarities between L2-normalised inputs and
    L2-normalised class weights. For the target class of each row, the
    cosine of angle theta is replaced by cos(theta + m). When theta + m would
    pass pi (cos below ``th``), it falls back to the linear penalty
    ``cos - mm``. All logits are multiplied by ``s``.

    Args:
        idim: input feature size.
        n_cl: number of classes.
        m: angular margin in radians.
        s: logit scale.
    """

    def __init__(self, idim, n_cl, m=0.2, s=32):
        super(AAMSoftmax, self).__init__()
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # contents are overwritten by Xavier init right away.
        self.w = nn.Parameter(torch.empty(n_cl, idim))
        nn.init.xavier_uniform_(self.w)
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine, theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, lbl):
        """Return scaled margin logits of shape (batch, n_cl).

        Args:
            x: input features, one row per sample.
            lbl: target class index per row (reshaped to a column internally).
        """
        cos = F.linear(F.normalize(x), F.normalize(self.w))
        # Clamp BEFORE sqrt: rounding can push |cos| slightly above 1, and
        # sqrt of the resulting negative number is NaN, which no later clamp
        # can undo.
        sin = torch.sqrt((1.0 - cos.pow(2)).clamp(0, 1))
        phi = cos * self.cos_m - sin * self.sin_m  # cos(theta + m)
        phi = torch.where(cos > self.th, phi, cos - self.mm)
        # One-hot target mask in the logits' own dtype and device.
        hot = torch.zeros_like(cos)
        hot.scatter_(1, lbl.view(-1, 1).long(), 1)
        return (hot * phi + (1.0 - hot) * cos) * self.s
# Notebook-style progress marker: reaching this line means every class definition above executed without error.
print("All classes defined successfully.")