import gradio as gr
import torch
import yaml
import os
from huggingface_hub import hf_hub_download
# Assuming these are available in your Space's environment
# from seed_vc_wrapper import SeedVCWrapper
# from modules.v2.vc_wrapper import VoiceConversionWrapper
# --- CONFIGURATION ---
# Model repository ID for automatic checkpoint download in the Space
MODEL_REPO_ID = "Bajiyo/dhanush_seedvc"
CFM_FILE = "CFM_epoch_00651_step_21500.pth"
AR_FILE = "AR_epoch_00651_step_21500.pth"
# -----------------------------------------------
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
dtype = torch.float16  # assumed for GPU/MPS memory savings; prefer torch.float32 on CPU-only runs
def load_models(args):
"""
Loads models, prioritizing command-line arguments for local paths,
and falling back to Hugging Face Hub download for the Space environment.
"""
# --- 1. Determine Checkpoint Paths ---
if args.cfm_checkpoint_path:
cfm_local_path = args.cfm_checkpoint_path
print(f"Using local CFM checkpoint path from arguments: {cfm_local_path}")
else:
# Default behavior for Space: download from HF
LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)
print(f"Arguments not provided. Downloading CFM checkpoint from {MODEL_REPO_ID}...")
cfm_local_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=CFM_FILE,
local_dir=LOCAL_CHECKPOINTS_DIR,
local_dir_use_symlinks=False
)
print(f"CFM checkpoint downloaded to: {cfm_local_path}")
if args.ar_checkpoint_path:
ar_local_path = args.ar_checkpoint_path
print(f"Using local AR checkpoint path from arguments: {ar_local_path}")
else:
# Default behavior for Space: download from HF
LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True) # Ensure dir exists
print(f"Arguments not provided. Downloading AR checkpoint from {MODEL_REPO_ID}...")
ar_local_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=AR_FILE,
local_dir=LOCAL_CHECKPOINTS_DIR,
local_dir_use_symlinks=False
)
print(f"AR checkpoint downloaded to: {ar_local_path}")
# --- 2. Instantiate and load models ---
from hydra.utils import instantiate
from omegaconf import DictConfig
# Assuming 'configs/v2/vc_wrapper.yaml' is present in the Space repo
    with open("configs/v2/vc_wrapper.yaml", "r") as f:
        cfg = DictConfig(yaml.safe_load(f))
vc_wrapper = instantiate(cfg)
# Load the determined checkpoints (either local paths or downloaded HF paths)
vc_wrapper.load_checkpoints(
ar_checkpoint_path=ar_local_path,
cfm_checkpoint_path=cfm_local_path
)
vc_wrapper.to(device)
vc_wrapper.eval()
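    # NOTE: max_seq_len bounds the AR KV cache (and so the longest token
    # sequence handled per call); 4096 with batch size 1 is assumed to suit the
    # single-request Gradio flow and the 30s source-chunking noted in the UI.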
vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)
if args.compile:
# Standard torch compile settings
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
        # fx_graph_cache only exists in newer torch versions; guard for older ones
        if hasattr(torch._inductor.config, "fx_graph_cache"):
            torch._inductor.config.fx_graph_cache = True
vc_wrapper.compile_ar()
# vc_wrapper.compile_cfm()
return vc_wrapper
def main(args):
# load_models handles the download and initialization now
vc_wrapper = load_models(args)
    # Plain wrapper function for Gradio (deliberately undecorated); it forwards
    # to the streaming API so streamed output works correctly in the Interface.
def convert_voice_wrapper(source_audio_path, target_audio_path, diffusion_steps,
length_adjust, intelligibility_cfg_rate, similarity_cfg_rate,
top_p, temperature, repetition_penalty, convert_style,
anonymization_only, stream_output=True):
"""
Wrapper function for vc_wrapper.convert_voice_with_streaming.
"""
yield from vc_wrapper.convert_voice_with_streaming(
source_audio_path=source_audio_path,
target_audio_path=target_audio_path,
diffusion_steps=diffusion_steps,
length_adjust=length_adjust,
            intelligebility_cfg_rate=intelligibility_cfg_rate,  # keyword spelling matches the wrapper's parameter name
similarity_cfg_rate=similarity_cfg_rate,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
convert_style=convert_style,
anonymization_only=anonymization_only,
device=device,
dtype=dtype,
stream_output=stream_output
)
# Set up Gradio interface
description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
"for details and updates.
Note that any reference audio will be forcefully clipped to 25s if beyond this length.
"
"If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.
"
"无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc]
"
"请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。
若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
inputs = [
gr.Audio(type="filepath", label="Source Audio / 源音频"),
gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Intelligibility CFG Rate",
info="controls pronunciation intelligibility / 控制发音清晰度"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Similarity CFG Rate",
info="controls similarity to reference audio / 控制与参考音频的相似度"),
gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
info="Controls diversity of generated audio / 控制生成音频的多样性"),
gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
info="Controls randomness of generated audio / 控制生成音频的随机性"),
gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
info="Penalizes repetition in generated audio / 惩罚生成音频中的重复"),
gr.Checkbox(label="convert style", value=False),
gr.Checkbox(label="anonymization only", value=False),
]
examples = [
["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.5, 0.5, 0.9, 1.0, 1.0, False, False],
["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.5, 0.5, 0.9, 1.0, 1.0, False, False],
]
outputs = [
gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
]
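    # The wrapper function is assumed to yield (streaming_chunk, full_audio)
    # pairs that map onto the two Audio outputs above; this pairing follows
    # from the streaming/full component setup, not from the wrapper's source.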
# Launch the Gradio interface
gr.Interface(
fn=convert_voice_wrapper, # Using the wrapper for reliable streaming
description=description,
inputs=inputs,
outputs=outputs,
title="Seed Voice Conversion V2",
examples=examples,
cache_examples=False,
).queue().launch(share=False)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
# These are the arguments that allow you to run the script locally with specific paths
parser.add_argument("--ar-checkpoint-path", type=str, default=None,
help="Path to custom AR checkpoint file. Defaults to HF download in Space.")
parser.add_argument("--cfm-checkpoint-path", type=str, default=None,
help="Path to custom CFM checkpoint file. Defaults to HF download in Space.")
args = parser.parse_args()
main(args)
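# Example local invocation (assuming this file is saved as app.py, the usual
# Space entry point; the checkpoint paths below are illustrative):
#   python app.py \
#       --cfm-checkpoint-path checkpoints/CFM_epoch_00651_step_21500.pth \
#       --ar-checkpoint-path checkpoints/AR_epoch_00651_step_21500.pth \
#       --compile
# Without the path flags, both checkpoints are downloaded from MODEL_REPO_ID.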