Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright 2025 Xiaomi Corp. (authors: Han Zhu | |
| # Wei Kang) | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Computes speaker similarity (SIM-o) using a WavLM-based | |
| ECAPA-TDNN speaker verification model. | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| import warnings | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM | |
| from zipvoice.eval.utils import load_waveform | |
| warnings.filterwarnings("ignore") | |
| def get_parser() -> argparse.ArgumentParser: | |
| parser = argparse.ArgumentParser( | |
| description="Calculate speaker similarity (SIM-o) score." | |
| ) | |
| parser.add_argument( | |
| "--wav-path", | |
| type=str, | |
| required=True, | |
| help="Path to the directory containing evaluated speech files.", | |
| ) | |
| parser.add_argument( | |
| "--test-list", | |
| type=str, | |
| required=True, | |
| help="Path to the file list that contains the correspondence between prompts " | |
| "and evaluated speech. Each line contains (audio_name, prompt_text_1, " | |
| "prompt_text_2, prompt_audio_1, prompt_audio_2, text) separated by tabs.", | |
| ) | |
| parser.add_argument( | |
| "--model-dir", | |
| type=str, | |
| required=True, | |
| help="Local path of our evaluatioin model repository." | |
| "Download from https://huggingface.co/k2-fsa/TTS_eval_models." | |
| "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'" | |
| "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script", | |
| ) | |
| parser.add_argument( | |
| "--extension", | |
| type=str, | |
| default="wav", | |
| help="Extension of the speech files. Default: wav", | |
| ) | |
| return parser | |
| class SpeakerSimilarity: | |
| """ | |
| Computes speaker similarity (SIM-o) using a WavLM-based | |
| ECAPA-TDNN speaker verification model. | |
| """ | |
| def __init__( | |
| self, | |
| sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth", | |
| ssl_model_path: str = "speaker_similarity/wavlm_large/", | |
| ): | |
| """ | |
| Initializes the speaker similarity evaluator with the specified models. | |
| Args: | |
| sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint. | |
| ssl_model_path (str): Path of the wavlm SSL model directory. | |
| """ | |
| self.sample_rate = 16000 | |
| self.device = ( | |
| torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
| ) | |
| logging.info(f"Using device: {self.device}") | |
| self.model = ECAPA_TDNN_WAVLM( | |
| feat_dim=1024, | |
| channels=512, | |
| emb_dim=256, | |
| sr=self.sample_rate, | |
| ssl_model_path=ssl_model_path, | |
| ) | |
| state_dict = torch.load( | |
| sv_model_path, map_location=lambda storage, loc: storage | |
| ) | |
| self.model.load_state_dict(state_dict["model"], strict=False) | |
| self.model.to(self.device) | |
| self.model.eval() | |
| def get_embeddings(self, wav_paths: List[str]) -> List[torch.Tensor]: | |
| """ | |
| Extracts speaker embeddings from a list of audio files. | |
| Args: | |
| wav_paths (List[str]): List of paths to audio files. | |
| Returns: | |
| List[torch.Tensor]: List of speaker embeddings. | |
| """ | |
| embeddings = [] | |
| for wav_path in tqdm(wav_paths, desc="Extracting speaker embeddings"): | |
| # Load and preprocess waveform | |
| speech = load_waveform( | |
| wav_path, self.sample_rate, device=self.device, max_seconds=120 | |
| ) | |
| # Extract embedding | |
| embedding = self.model([speech]) | |
| embeddings.append(embedding) | |
| return embeddings | |
| def score(self, wav_path: str, extension: str, test_list: str) -> float: | |
| """ | |
| Computes the Speaker Similarity (SIM-o) score between reference and | |
| evaluated speech. | |
| Args: | |
| wav_path (str): Path to the directory containing evaluated speech files. | |
| test_list (str): Path to the test list file mapping evaluated files | |
| to reference prompts. | |
| Returns: | |
| float: Average similarity score between reference and evaluated embeddings. | |
| """ | |
| logging.info(f"Calculating Speaker Similarity (SIM-o) score for {wav_path}") | |
| # Read test pairs | |
| try: | |
| with open(test_list, "r", encoding="utf-8") as f: | |
| lines = [line.strip().split("\t") for line in f if line.strip()] | |
| except Exception as e: | |
| logging.error(f"Failed to read test list: {e}") | |
| raise | |
| if not lines: | |
| raise ValueError(f"Test list {test_list} is empty or malformed") | |
| # Parse test pairs | |
| prompt_wavs = [] | |
| eval_wavs = [] | |
| for line in lines: | |
| if len(line) != 4: | |
| raise ValueError(f"Invalid line: {line}") | |
| wav_name, prompt_text, prompt_wav, text = line | |
| eval_wav_path = os.path.join(wav_path, f"{wav_name}.{extension}") | |
| # Validate file existence | |
| if not os.path.exists(prompt_wav): | |
| raise FileNotFoundError(f"Prompt file not found: {prompt_wav}") | |
| if not os.path.exists(eval_wav_path): | |
| raise FileNotFoundError(f"Evaluated file not found: {eval_wav_path}") | |
| prompt_wavs.append(prompt_wav) | |
| eval_wavs.append(eval_wav_path) | |
| logging.info(f"Found {len(prompt_wavs)} valid test pairs") | |
| # Extract embeddings | |
| prompt_embeddings = self.get_embeddings(prompt_wavs) | |
| eval_embeddings = self.get_embeddings(eval_wavs) | |
| if len(prompt_embeddings) != len(eval_embeddings): | |
| raise RuntimeError( | |
| f"Mismatch: {len(prompt_embeddings)} prompt vs " | |
| f" {len(eval_embeddings)} eval embeddings" | |
| ) | |
| # Calculate similarity scores | |
| scores = [] | |
| for prompt_emb, eval_emb in zip(prompt_embeddings, eval_embeddings): | |
| # Compute cosine similarity | |
| similarity = torch.nn.functional.cosine_similarity( | |
| prompt_emb, eval_emb, dim=-1 | |
| ) | |
| scores.append(similarity.item()) | |
| return float(np.mean(scores)) | |
| if __name__ == "__main__": | |
| torch.set_num_threads(1) | |
| torch.set_num_interop_threads(1) | |
| formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | |
| logging.basicConfig(format=formatter, level=logging.INFO, force=True) | |
| parser = get_parser() | |
| args = parser.parse_args() | |
| # Initialize evaluator | |
| sv_model_path = os.path.join( | |
| args.model_dir, "speaker_similarity/wavlm_large_finetune.pth" | |
| ) | |
| ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/") | |
| if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path): | |
| logging.error( | |
| "Please download evaluation models from " | |
| "https://huggingface.co/k2-fsa/TTS_eval_models" | |
| " and pass this dir with --model-dir" | |
| ) | |
| exit(1) | |
| sim_evaluator = SpeakerSimilarity( | |
| sv_model_path=sv_model_path, ssl_model_path=ssl_model_path | |
| ) | |
| # Compute similarity score | |
| score = sim_evaluator.score(args.wav_path, args.extension, args.test_list) | |
| print("-" * 50) | |
| logging.info(f"SIM-o score: {score:.3f}") | |
| print("-" * 50) | |