3v324v23's picture
up
af11ce4
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Computes speaker similarity (SIM-o) using a WavLM-based
ECAPA-TDNN speaker verification model.
"""
import argparse
import logging
import os
import warnings
from typing import List
import numpy as np
import torch
from tqdm import tqdm
from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
from zipvoice.eval.utils import load_waveform
warnings.filterwarnings("ignore")
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Calculate speaker similarity (SIM-o) score."
)
parser.add_argument(
"--wav-path",
type=str,
required=True,
help="Path to the directory containing evaluated speech files.",
)
parser.add_argument(
"--test-list",
type=str,
required=True,
help="Path to the file list that contains the correspondence between prompts "
"and evaluated speech. Each line contains (audio_name, prompt_text_1, "
"prompt_text_2, prompt_audio_1, prompt_audio_2, text) separated by tabs.",
)
parser.add_argument(
"--model-dir",
type=str,
required=True,
help="Local path of our evaluatioin model repository."
"Download from https://huggingface.co/k2-fsa/TTS_eval_models."
"Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
"and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script",
)
parser.add_argument(
"--extension",
type=str,
default="wav",
help="Extension of the speech files. Default: wav",
)
return parser
class SpeakerSimilarity:
"""
Computes speaker similarity (SIM-o) using a WavLM-based
ECAPA-TDNN speaker verification model.
"""
def __init__(
self,
sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
ssl_model_path: str = "speaker_similarity/wavlm_large/",
):
"""
Initializes the speaker similarity evaluator with the specified models.
Args:
sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
ssl_model_path (str): Path of the wavlm SSL model directory.
"""
self.sample_rate = 16000
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info(f"Using device: {self.device}")
self.model = ECAPA_TDNN_WAVLM(
feat_dim=1024,
channels=512,
emb_dim=256,
sr=self.sample_rate,
ssl_model_path=ssl_model_path,
)
state_dict = torch.load(
sv_model_path, map_location=lambda storage, loc: storage
)
self.model.load_state_dict(state_dict["model"], strict=False)
self.model.to(self.device)
self.model.eval()
@torch.no_grad()
def get_embeddings(self, wav_paths: List[str]) -> List[torch.Tensor]:
"""
Extracts speaker embeddings from a list of audio files.
Args:
wav_paths (List[str]): List of paths to audio files.
Returns:
List[torch.Tensor]: List of speaker embeddings.
"""
embeddings = []
for wav_path in tqdm(wav_paths, desc="Extracting speaker embeddings"):
# Load and preprocess waveform
speech = load_waveform(
wav_path, self.sample_rate, device=self.device, max_seconds=120
)
# Extract embedding
embedding = self.model([speech])
embeddings.append(embedding)
return embeddings
def score(self, wav_path: str, extension: str, test_list: str) -> float:
"""
Computes the Speaker Similarity (SIM-o) score between reference and
evaluated speech.
Args:
wav_path (str): Path to the directory containing evaluated speech files.
test_list (str): Path to the test list file mapping evaluated files
to reference prompts.
Returns:
float: Average similarity score between reference and evaluated embeddings.
"""
logging.info(f"Calculating Speaker Similarity (SIM-o) score for {wav_path}")
# Read test pairs
try:
with open(test_list, "r", encoding="utf-8") as f:
lines = [line.strip().split("\t") for line in f if line.strip()]
except Exception as e:
logging.error(f"Failed to read test list: {e}")
raise
if not lines:
raise ValueError(f"Test list {test_list} is empty or malformed")
# Parse test pairs
prompt_wavs = []
eval_wavs = []
for line in lines:
if len(line) != 4:
raise ValueError(f"Invalid line: {line}")
wav_name, prompt_text, prompt_wav, text = line
eval_wav_path = os.path.join(wav_path, f"{wav_name}.{extension}")
# Validate file existence
if not os.path.exists(prompt_wav):
raise FileNotFoundError(f"Prompt file not found: {prompt_wav}")
if not os.path.exists(eval_wav_path):
raise FileNotFoundError(f"Evaluated file not found: {eval_wav_path}")
prompt_wavs.append(prompt_wav)
eval_wavs.append(eval_wav_path)
logging.info(f"Found {len(prompt_wavs)} valid test pairs")
# Extract embeddings
prompt_embeddings = self.get_embeddings(prompt_wavs)
eval_embeddings = self.get_embeddings(eval_wavs)
if len(prompt_embeddings) != len(eval_embeddings):
raise RuntimeError(
f"Mismatch: {len(prompt_embeddings)} prompt vs "
f" {len(eval_embeddings)} eval embeddings"
)
# Calculate similarity scores
scores = []
for prompt_emb, eval_emb in zip(prompt_embeddings, eval_embeddings):
# Compute cosine similarity
similarity = torch.nn.functional.cosine_similarity(
prompt_emb, eval_emb, dim=-1
)
scores.append(similarity.item())
return float(np.mean(scores))
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO, force=True)
parser = get_parser()
args = parser.parse_args()
# Initialize evaluator
sv_model_path = os.path.join(
args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
)
ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
logging.error(
"Please download evaluation models from "
"https://huggingface.co/k2-fsa/TTS_eval_models"
" and pass this dir with --model-dir"
)
exit(1)
sim_evaluator = SpeakerSimilarity(
sv_model_path=sv_model_path, ssl_model_path=ssl_model_path
)
# Compute similarity score
score = sim_evaluator.score(args.wav_path, args.extension, args.test_list)
print("-" * 50)
logging.info(f"SIM-o score: {score:.3f}")
print("-" * 50)