import os
import torchaudio
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForCTC

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# load example audio clips bundled with the app
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])
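
# os.listdir order is platform-dependent, so sort for a stable,
# reproducible ordering in the examples panel
examples.sort()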

# Load model and processor
MODEL_PATH = "badrex/w2v-bert-2.0-kinyarwanda-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# move model to the target device
model = model.to(device)
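
# switch to inference mode; disables dropout and other train-time behavior
model.eval()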

@spaces.GPU()
def process_audio(audio_path):
    """Process audio with return the generated response.
    
    Args:
        audio_path: Path to the audio file to be transcribed.

    Returns:
        String containing the transcribed text from the audio file, or an error message
        if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    # load the waveform as a (channels, num_samples) float tensor
    audio_array, sample_rate = torchaudio.load(audio_path)

    # the model expects 16 kHz input, so resample if necessary
    if sample_rate != 16000:
        audio_array = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio_array)

    # downmix multi-channel audio to mono: the processor expects a single 1-D waveform
    audio_array = audio_array.mean(dim=0).numpy()

    # convert the raw waveform into model-ready features
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
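    # note: for Wav2Vec2-BERT checkpoints the processor typically returns
    # `input_features` and `attention_mask`; both were moved to the device above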

    with torch.no_grad():
        logits = model(**inputs).logits

    # greedy CTC decoding: pick the most likely token at each frame; the
    # tokenizer collapses repeats and strips blanks during batch_decode
    predicted_ids = torch.argmax(logits, dim=-1)

    decoded_outputs = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )
    
    return decoded_outputs[0].strip()
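
# quick local sanity check (hypothetical clip path, shown for illustration):
#   print(process_audio("examples/sample.wav"))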


# Define Gradio interface
# `title` sets the browser tab title and must be plain text;
# the on-page heading is rendered as its own component instead
with gr.Blocks(title="ASRwanda - Speech Recognition for Kinyarwanda") as demo:
    gr.HTML("<h1 style='text-align: center;'>ASRwanda 🎙️<br>Speech Recognition for Kinyarwanda</h1>")
    gr.Markdown("""
        <div style="text-align: center;">
            <div>
                <p>
                Developed with ❤ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> ☕
                </p>
                <br>
                <p style="font-size: 15px; line-height: 1.8;">
                 Muraho 👋🏼
                <br><br>
                 This is a demo of ASRwanda, a Transformer-based automatic speech recognition (ASR) system for the Kinyarwanda language.
                 The underlying ASR model was trained on 1,000 hours of transcribed speech provided by 
                 <a href="https://digitalumuganda.com/" style="color: #2563eb;">Digital Umuganda</a> as part of the Kinyarwanda
                 <a href="https://www.kaggle.com/competitions/kinyarwanda-automatic-speech-recognition-track-b" style="color: #2563eb;"> ASR hackathon</a> on Kaggle.
                <br><br>                   
                Simply <strong>upload an audio file</strong> 📤 or <strong>record yourself speaking</strong> 🎙️⏺️ to try out the model!
                </p>
            </div>
        </div>
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")
        
        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)
    
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text
    )

    # gr.Examples raises an error when given None or an empty list,
    # so only add the component if example files were found
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )

# Launch the app
if __name__ == "__main__":
    demo.queue().launch()