File size: 3,458 Bytes
fda6e40
27414b4
 
fda6e40
27414b4
 
fda6e40
27414b4
 
fda6e40
27414b4
fda6e40
 
 
 
 
 
27414b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fda6e40
27414b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fda6e40
27414b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fda6e40
27414b4
 
 
 
 
 
 
 
 
 
fda6e40
27414b4
fda6e40
27414b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import torchaudio
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForCTC

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load example audio clips for the Gradio examples gallery.
# - isdir (not exists): a stray *file* named "examples" would make os.listdir raise.
# - sorted: os.listdir order is arbitrary, so sort for a stable gallery order.
# - lower(): match extensions case-insensitively (e.g. "clip.WAV").
AUDIO_EXTENSIONS = (".wav", ".mp3", ".ogg")
examples_dir = "examples"
examples = (
    [
        [os.path.join(examples_dir, filename)]
        for filename in sorted(os.listdir(examples_dir))
        if filename.lower().endswith(AUDIO_EXTENSIONS)
    ]
    if os.path.isdir(examples_dir)
    else []
)

# Load model and processor
MODEL_PATH = "badrex/w2v-bert-2.0-zulu-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# Only the model is moved to the device: the processor is not an nn.Module.
# The tensors it produces are moved to `device` at inference time.
model = model.to(device)
# from_pretrained already returns the model in eval mode; made explicit here.
model.eval()

@spaces.GPU()
def process_audio(audio_path):
    """Transcribe an audio file with the isiZulu CTC model.

    The audio goes through four steps:
    1. Load it and down-mix it to a 1-D mono signal.
    2. Resample it to 16 kHz if needed.
    3. Turn it into model features with ``processor``.
    4. Decode greedily: take the arg-max token at every frame of the CTC
       logits, then pass the ids to ``processor.batch_decode``.

    Args:
        audio_path: Path to the audio file to be transcribed (as produced by
            ``gr.Audio(type="filepath")``). May be ``None`` or empty when
            nothing was uploaded.

    Returns:
        The transcribed text as a string. If no file was provided, or the
        file contains no samples, a human-readable message is returned instead.
    """
    if not audio_path:
        return "Please upload an audio file."

    target_sr = 16000

    # torchaudio.load returns a (channels, frames) tensor plus the file's sample rate.
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.numel() == 0:
        return "The uploaded audio file is empty."

    # Down-mix to a 1-D mono signal. Otherwise the feature extractor decides
    # how to treat extra channels; averaging keeps every channel of a stereo file.
    waveform = waveform.mean(dim=0)

    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sr
        )

    inputs = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.inference_mode():
        logits = model(**inputs).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    decoded_outputs = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    return decoded_outputs[0].strip()


# Define Gradio interface
with gr.Blocks(title="isiZulu ASR Demo") as demo:
    gr.Markdown("# isiZulu ASR 🎙️ Robust Speech Recognition for Zulu Language 🍋\u200d🟩")
    gr.Markdown(
        'Developed with <span style="color:red;">❤</span> by <a href="https://badrex.github.io/">Badr al-Absi</a>'
    )
    gr.Markdown(
        """### Hi there 👋🏼

This is a demo for [badrex/w2v-bert-2.0-zulu-asr](https://huggingface.co/badrex/w2v-bert-2.0-zulu-asr), 
a robust Transformer-based automatic speech recognition (ASR) system for the Zulu language that was trained on 250+ hours of 
high-quality human-transcribed speech based on the [ZA-African Next Voices](https://huggingface.co/datasets/dsfsi-anv/za-african-next-voices) dataset.
    """
    )

    gr.Markdown("Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text,
    )

    # gr.Examples rejects examples=None, so only build the gallery when
    # example files were actually found.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )

# Launch the app
if __name__ == "__main__":
    demo.queue().launch()