Bajiyo committed on
Commit ab46033 · verified · 1 Parent(s): f3d6706

Update app.py

Files changed (1)
  1. app.py +42 -29
app.py CHANGED
@@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
 # from modules.v2.vc_wrapper import VoiceConversionWrapper

 # --- CONFIGURATION (UPDATE YOUR_USERNAME HERE) ---
-# Your correct model repository ID
+# Your correct model repository ID for automatic download in the Space
 MODEL_REPO_ID = "Bajiyo/dhanush_seedvc"
 CFM_FILE = "CFM_epoch_00651_step_21500.pth"
 AR_FILE = "AR_epoch_00651_step_21500.pth"
@@ -25,32 +25,44 @@ dtype = torch.float16

 def load_models(args):
     """
-    Loads models, handling checkpoint download from Hugging Face Hub.
+    Loads models, prioritizing command-line arguments for local paths,
+    and falling back to Hugging Face Hub download for the Space environment.
     """
-    # 1. Setup local directory and download checkpoints
-    LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
-    os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)
-    print(f"Downloading checkpoints from {MODEL_REPO_ID}...")
-
-    # Download CFM
-    cfm_local_path = hf_hub_download(
-        repo_id=MODEL_REPO_ID,
-        filename=CFM_FILE,
-        local_dir=LOCAL_CHECKPOINTS_DIR,
-        local_dir_use_symlinks=False
-    )
-    print(f"CFM checkpoint downloaded to: {cfm_local_path}")

-    # Download AR
-    ar_local_path = hf_hub_download(
-        repo_id=MODEL_REPO_ID,
-        filename=AR_FILE,
-        local_dir=LOCAL_CHECKPOINTS_DIR,
-        local_dir_use_symlinks=False
-    )
-    print(f"AR checkpoint downloaded to: {ar_local_path}")
+    # --- 1. Determine Checkpoint Paths ---
+    if args.cfm_checkpoint_path:
+        cfm_local_path = args.cfm_checkpoint_path
+        print(f"Using local CFM checkpoint path from arguments: {cfm_local_path}")
+    else:
+        # Default behavior for Space: download from HF
+        LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
+        os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)
+        print(f"Arguments not provided. Downloading CFM checkpoint from {MODEL_REPO_ID}...")
+        cfm_local_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID,
+            filename=CFM_FILE,
+            local_dir=LOCAL_CHECKPOINTS_DIR,
+            local_dir_use_symlinks=False
+        )
+        print(f"CFM checkpoint downloaded to: {cfm_local_path}")
+
+    if args.ar_checkpoint_path:
+        ar_local_path = args.ar_checkpoint_path
+        print(f"Using local AR checkpoint path from arguments: {ar_local_path}")
+    else:
+        # Default behavior for Space: download from HF
+        LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"
+        os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)  # Ensure dir exists
+        print(f"Arguments not provided. Downloading AR checkpoint from {MODEL_REPO_ID}...")
+        ar_local_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID,
+            filename=AR_FILE,
+            local_dir=LOCAL_CHECKPOINTS_DIR,
+            local_dir_use_symlinks=False
+        )
+        print(f"AR checkpoint downloaded to: {ar_local_path}")

-    # 2. Instantiate and load models
+    # --- 2. Instantiate and load models ---
     from hydra.utils import instantiate
     from omegaconf import DictConfig

@@ -58,7 +70,7 @@ def load_models(args):
     cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
     vc_wrapper = instantiate(cfg)

-    # Load the downloaded checkpoints
+    # Load the determined checkpoints (either local paths or downloaded HF paths)
     vc_wrapper.load_checkpoints(
         ar_checkpoint_path=ar_local_path,
         cfm_checkpoint_path=cfm_local_path
@@ -85,6 +97,7 @@ def main(args):
     vc_wrapper = load_models(args)

     # Define wrapper function for Gradio. NO DECORATORS HERE.
+    # This wrapper ensures the streaming output works correctly in the Gradio Interface.
     def convert_voice_wrapper(source_audio_path, target_audio_path, diffusion_steps,
                               length_adjust, intelligibility_cfg_rate, similarity_cfg_rate,
                               top_p, temperature, repetition_penalty, convert_style,
@@ -149,7 +162,7 @@ def main(args):

     # Launch the Gradio interface
     gr.Interface(
-        fn=convert_voice_wrapper,
+        fn=convert_voice_wrapper,  # Using the wrapper for reliable streaming
         description=description,
         inputs=inputs,
         outputs=outputs,
@@ -162,10 +175,10 @@ if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
-    # These arguments are now effectively ignored/not needed since we download the models
+    # These are the arguments that allow you to run the script locally with specific paths
     parser.add_argument("--ar-checkpoint-path", type=str, default=None,
-                        help="Path to custom checkpoint file (overridden by HF download in Space)")
+                        help="Path to custom AR checkpoint file. Defaults to HF download in Space.")
     parser.add_argument("--cfm-checkpoint-path", type=str, default=None,
-                        help="Path to custom checkpoint file (overridden by HF download in Space)")
+                        help="Path to custom CFM checkpoint file. Defaults to HF download in Space.")
     args = parser.parse_args()
     main(args)
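
A note on the load_models() change above: the CFM and AR branches repeat the same prefer-a-local-path-else-download logic. Below is a minimal sketch of that pattern factored into one helper; the helper name resolve_checkpoint is hypothetical and not part of the commit, while the repository ID, filenames, and hf_hub_download arguments mirror app.py.

import os
from huggingface_hub import hf_hub_download

MODEL_REPO_ID = "Bajiyo/dhanush_seedvc"
LOCAL_CHECKPOINTS_DIR = "downloaded_checkpoints"

def resolve_checkpoint(local_path, filename):
    # Prefer an explicit local path (e.g. from argparse); otherwise download from the Hub.
    if local_path:
        print(f"Using local checkpoint: {local_path}")
        return local_path
    os.makedirs(LOCAL_CHECKPOINTS_DIR, exist_ok=True)
    print(f"Downloading {filename} from {MODEL_REPO_ID}...")
    return hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=filename,
        local_dir=LOCAL_CHECKPOINTS_DIR,
        local_dir_use_symlinks=False,  # mirrors app.py; recent huggingface_hub releases deprecate this flag
    )

# Usage mirroring the commit's logic:
# cfm_local_path = resolve_checkpoint(args.cfm_checkpoint_path, CFM_FILE)
# ar_local_path = resolve_checkpoint(args.ar_checkpoint_path, AR_FILE)

In practice this means: run the script locally with explicit paths (python app.py --cfm-checkpoint-path ... --ar-checkpoint-path ..., using your own file locations), or omit both flags, as the Space does, and the checkpoints are fetched from MODEL_REPO_ID.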
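
The model construction step itself is unchanged by this commit: load_models() reads configs/v2/vc_wrapper.yaml into an OmegaConf DictConfig and passes it to hydra.utils.instantiate, which builds the object named by the config's _target_ key. The self-contained example below illustrates only that mechanism; the target and keys are stand-ins, not the Space's real config.

import yaml
from hydra.utils import instantiate
from omegaconf import DictConfig

# Stand-in config: app.py's YAML targets its voice-conversion wrapper class instead.
yaml_text = """
_target_: collections.OrderedDict
alpha: 1
beta: 2
"""

cfg = DictConfig(yaml.safe_load(yaml_text))
obj = instantiate(cfg)                 # calls collections.OrderedDict(alpha=1, beta=2)
print(type(obj).__name__, dict(obj))   # -> OrderedDict {'alpha': 1, 'beta': 2}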
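
The new comments around convert_voice_wrapper and fn=convert_voice_wrapper concern streaming: gr.Interface streams intermediate results to its outputs when fn is a plain generator function (one that yields), with no decorators required. The toy example below is a hedged illustration of that Gradio behaviour only; count_up and its components are hypothetical and unrelated to voice conversion.

import time
import gradio as gr

def count_up(n):
    # A generator fn: every yield is pushed to the output, so the UI updates while work is in progress.
    total = 0
    for i in range(1, int(n) + 1):
        total += i
        time.sleep(0.2)
        yield f"partial sum after {i} terms: {total}"

demo = gr.Interface(fn=count_up, inputs=gr.Number(value=5), outputs=gr.Textbox())

if __name__ == "__main__":
    demo.launch()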