Randinu002 committed on
Commit
b65a3c4
·
1 Parent(s): c553417

Fix model loading on CPU

Browse files
Files changed (1) hide show
  1. app.py +7 -15
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (More Stable Version)
2
 
3
  import gradio as gr
4
  import torch
@@ -7,10 +7,9 @@ import torchaudio
7
  import os
8
  import time
9
 
10
- # --- Ensure model.py with your class definitions is in the same folder ---
11
  from model import FullModel
12
 
13
- # --- 1. Global Setup ---
14
  if not os.path.exists("user_data"): os.makedirs("user_data")
15
  if not os.path.exists("user_data/enrollments"): os.makedirs("user_data/enrollments")
16
  if not os.path.exists("user_data/verifications"): os.makedirs("user_data/verifications")
@@ -18,13 +17,12 @@ if not os.path.exists("user_data/verifications"): os.makedirs("user_data/verific
18
  print("Loading model...")
19
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  MODEL = FullModel().to(DEVICE)
21
- MODEL.load_state_dict(torch.load("speaker_verification_model.pth"))
22
  MODEL.eval()
23
  THRESHOLD = 0.5216
24
  print("Model loaded successfully.")
25
  ENROLLED_USERS = {}
26
 
27
- # --- 2. Helper and Core Functions ---
28
 
29
  def get_embedding(waveform):
30
  if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -37,11 +35,9 @@ def enroll_speaker(audio_filepath, username):
37
  return "Error: No audio recorded. Please record your voice first.", gr.update()
38
  if not username:
39
  return "Please enter a username.", None
40
-
41
- # Load the audio from the temporary file path provided by Gradio
42
  waveform, sr = torchaudio.load(audio_filepath)
43
-
44
- # Save a permanent copy of the enrollment audio
45
  timestamp = int(time.time())
46
  filename = f"user_data/enrollments/{username}_{timestamp}.wav"
47
  torchaudio.save(filename, waveform, sr)
@@ -60,10 +56,9 @@ def verify_speaker(audio_filepath, username_to_verify):
60
  if username_to_verify not in ENROLLED_USERS:
61
  return f"User '{username_to_verify}' is not enrolled. Please enroll first.", None
62
 
63
- # Load the audio from the temporary file path provided by Gradio
64
  waveform, sr = torchaudio.load(audio_filepath)
65
 
66
- # Save a permanent copy of the verification attempt
67
  timestamp = int(time.time())
68
  filename = f"user_data/verifications/{username_to_verify}_attempt_{timestamp}.wav"
69
  torchaudio.save(filename, waveform, sr)
@@ -77,7 +72,7 @@ def verify_speaker(audio_filepath, username_to_verify):
77
 
78
  return f"Similarity Score: {score:.4f}", decision
79
 
80
- # --- 3. Gradio Interface ---
81
 
82
  with gr.Blocks() as demo:
83
  gr.Markdown("# Voice Authentication System")
@@ -85,13 +80,11 @@ with gr.Blocks() as demo:
85
  with gr.Tabs():
86
  with gr.TabItem("Enrollment"):
87
  enroll_username = gr.Textbox(label="Enter a unique Username")
88
- # <<< --- FIX: Changed type="numpy" to type="filepath" --- >>>
89
  enroll_audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your enrollment phrase (3-5 seconds)")
90
  enroll_button = gr.Button("Enroll Voiceprint")
91
  enroll_output = gr.Textbox(label="Enrollment Status")
92
  with gr.TabItem("Verification"):
93
  verify_username = gr.Textbox(label="Enter your Username to verify")
94
- # <<< --- FIX: Changed type="numpy" to type="filepath" --- >>>
95
  verify_audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your verification phrase (must be different!)")
96
  verify_button = gr.Button("Verify My Voice")
97
  verify_score = gr.Textbox(label="Result Score")
@@ -100,6 +93,5 @@ with gr.Blocks() as demo:
100
  enroll_button.click(fn=enroll_speaker, inputs=[enroll_audio, enroll_username], outputs=[enroll_output, verify_username])
101
  verify_button.click(fn=verify_speaker, inputs=[verify_audio, verify_username], outputs=[verify_score, verify_decision])
102
 
103
- # --- 4. Launch the App ---
104
  if __name__ == "__main__":
105
  demo.queue().launch(share=True)
 
1
+
2
 
3
  import gradio as gr
4
  import torch
 
7
  import os
8
  import time
9
 
 
10
  from model import FullModel
11
 
12
+
13
  if not os.path.exists("user_data"): os.makedirs("user_data")
14
  if not os.path.exists("user_data/enrollments"): os.makedirs("user_data/enrollments")
15
  if not os.path.exists("user_data/verifications"): os.makedirs("user_data/verifications")
 
17
  print("Loading model...")
18
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  MODEL = FullModel().to(DEVICE)
20
+ MODEL.load_state_dict(torch.load("speaker_verification_model.pth", map_location=torch.device('cpu')))
21
  MODEL.eval()
22
  THRESHOLD = 0.5216
23
  print("Model loaded successfully.")
24
  ENROLLED_USERS = {}
25
 
 
26
 
27
  def get_embedding(waveform):
28
  if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
 
35
  return "Error: No audio recorded. Please record your voice first.", gr.update()
36
  if not username:
37
  return "Please enter a username.", None
38
+
 
39
  waveform, sr = torchaudio.load(audio_filepath)
40
+
 
41
  timestamp = int(time.time())
42
  filename = f"user_data/enrollments/{username}_{timestamp}.wav"
43
  torchaudio.save(filename, waveform, sr)
 
56
  if username_to_verify not in ENROLLED_USERS:
57
  return f"User '{username_to_verify}' is not enrolled. Please enroll first.", None
58
 
 
59
  waveform, sr = torchaudio.load(audio_filepath)
60
 
61
+
62
  timestamp = int(time.time())
63
  filename = f"user_data/verifications/{username_to_verify}_attempt_{timestamp}.wav"
64
  torchaudio.save(filename, waveform, sr)
 
72
 
73
  return f"Similarity Score: {score:.4f}", decision
74
 
75
+
76
 
77
  with gr.Blocks() as demo:
78
  gr.Markdown("# Voice Authentication System")
 
80
  with gr.Tabs():
81
  with gr.TabItem("Enrollment"):
82
  enroll_username = gr.Textbox(label="Enter a unique Username")
 
83
  enroll_audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your enrollment phrase (3-5 seconds)")
84
  enroll_button = gr.Button("Enroll Voiceprint")
85
  enroll_output = gr.Textbox(label="Enrollment Status")
86
  with gr.TabItem("Verification"):
87
  verify_username = gr.Textbox(label="Enter your Username to verify")
 
88
  verify_audio = gr.Audio(sources=["microphone"], type="filepath", label="Record your verification phrase (must be different!)")
89
  verify_button = gr.Button("Verify My Voice")
90
  verify_score = gr.Textbox(label="Result Score")
 
93
  enroll_button.click(fn=enroll_speaker, inputs=[enroll_audio, enroll_username], outputs=[enroll_output, verify_username])
94
  verify_button.click(fn=verify_speaker, inputs=[verify_audio, verify_username], outputs=[verify_score, verify_decision])
95
 
 
96
  if __name__ == "__main__":
97
  demo.queue().launch(share=True)