badrex commited on
Commit
bfe1e2c
ยท
verified ยท
1 Parent(s): aca2627

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -48
app.py CHANGED
@@ -1,54 +1,68 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- import numpy as np
4
  import os
5
- from huggingface_hub import login
6
- import librosa
7
  import spaces
 
 
8
 
9
- HF_TOKEN = os.environ.get("HF_TOKEN")
10
- if HF_TOKEN:
11
- login(token=HF_TOKEN)
12
-
13
- MODEL_ID = "badrex/w2v-bert-2.0-kinyarwanda-asr-1000h"
14
- transcriber = pipeline("automatic-speech-recognition", model=MODEL_ID)
15
-
16
-
17
- @spaces.GPU
18
- def transcribe(audio):
19
- sr, y = audio
20
-
21
- # convert to mono if stereo
22
- if y.ndim > 1:
23
- y = y.mean(axis=1)
24
-
25
- # resample to 16kHz if needed
26
- #if sr != 16000:
27
- # y = librosa.resample(y, orig_sr=sr, target_sr=16000)
28
-
29
- y = y.astype(np.float32)
30
- y /= np.max(np.abs(y))
31
-
32
- return transcriber({"sampling_rate": sr, "raw": y})["text"]
33
 
 
34
  examples = []
35
  examples_dir = "examples"
36
  if os.path.exists(examples_dir):
37
  for filename in os.listdir(examples_dir):
38
  if filename.endswith((".wav", ".mp3", ".ogg")):
39
  examples.append([os.path.join(examples_dir, filename)])
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- print(f"Found {len(examples)} example files")
42
- else:
43
- print("Examples directory not found")
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- demo = gr.Interface(
47
- fn=transcribe,
48
- inputs=gr.Audio(),
49
- outputs="text",
50
- title="<div>ASRwanda ๐ŸŽ™๏ธ <br>Speech Recognition for Kinyarwanda</div>",
51
- description="""
 
 
 
 
 
52
  <div class="centered-content">
53
  <div>
54
  <p>
@@ -57,23 +71,37 @@ demo = gr.Interface(
57
  <br>
58
  <p style="font-size: 15px; line-height: 1.8;">
59
  Muraho ๐Ÿ‘‹๐Ÿผ
60
- <br>
61
- <br>
62
  This is a demo for ASRwanda, a Transformer-based automatic speech recognition (ASR) system for Kinyarwanda language.
63
  The underlying ASR model was trained on 1000 hours of transcribed speech provided by
64
  <a href="https://digitalumuganda.com/" style="color: #2563eb;">Digital Umuganda</a> as part of the Kinyarwanda
65
  <a href="https://www.kaggle.com/competitions/kinyarwanda-automatic-speech-recognition-track-b" style="color: #2563eb;"> ASR hackathon</a> on Kaggle.
66
- <br>
67
- <p style="font-size: 15px; line-height: 1.8;">
68
  Simply <strong>upload an audio file</strong> ๐Ÿ“ค or <strong>record yourself speaking</strong> ๐ŸŽ™๏ธโบ๏ธ to try out the model!
69
  </p>
70
  </div>
71
  </div>
72
- """,
73
- examples=examples if examples else None,
74
- cache_examples=False,
75
- flagging_mode=None,
76
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
78
  if __name__ == "__main__":
79
- demo.launch()
 
 
 
 
1
  import os
2
+ import torchaudio
3
+ import gradio as gr
4
  import spaces
5
+ import torch
6
+ from transformers import AutoProcessor, AutoModelForCTC
7
 
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"Using device: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # load examples
12
  examples = []
13
  examples_dir = "examples"
14
  if os.path.exists(examples_dir):
15
  for filename in os.listdir(examples_dir):
16
  if filename.endswith((".wav", ".mp3", ".ogg")):
17
  examples.append([os.path.join(examples_dir, filename)])
18
+
19
+ # Load model and processor
20
+ MODEL_PATH = "badrex/w2v-bert-2.0-kinyarwanda-asr"
21
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
22
+ model = AutoModelForCTC.from_pretrained(MODEL_PATH)
23
+
24
+ # move model and processor to device
25
+ model = model.to(device)
26
+
27
+ @spaces.GPU()
28
+ def process_audio(audio_path):
29
+ """Process audio with return the generated response.
30
 
31
+ Args:
32
+ audio_path: Path to the audio file to be transcribed.
33
+ Returns:
34
+ String containing the transcribed text from the audio file, or an error message
35
+ if the audio file is missing.
36
+ """
37
+ if not audio_path:
38
+ return "Please upload an audio file."
39
+
40
+ # get audio array
41
+ audio_array, sample_rate = torchaudio.load(audio_path)
42
 
43
+ # if sample rate is not 16000, resample to 16000
44
+ if sample_rate != 16000:
45
+ audio_array = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio_array)
46
+
47
+ inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
48
+ inputs = {k: v.to(device) for k, v in inputs.items()}
49
+
50
+ with torch.no_grad():
51
+ logits = model(**inputs).logits
52
+
53
+ outputs = torch.argmax(logits, dim=-1)
54
 
55
+ decoded_outputs = processor.batch_decode(
56
+ outputs,
57
+ skip_special_tokens=True
58
+ )
59
+
60
+ return decoded_outputs[0].strip()
61
+
62
+
63
+ # Define Gradio interface
64
+ with gr.Blocks(title="<div>ASRwanda ๐ŸŽ™๏ธ <br>Speech Recognition for Kinyarwanda</div>") as demo:
65
+ gr.Markdown("""
66
  <div class="centered-content">
67
  <div>
68
  <p>
 
71
  <br>
72
  <p style="font-size: 15px; line-height: 1.8;">
73
  Muraho ๐Ÿ‘‹๐Ÿผ
74
+ <br><br>
 
75
  This is a demo for ASRwanda, a Transformer-based automatic speech recognition (ASR) system for Kinyarwanda language.
76
  The underlying ASR model was trained on 1000 hours of transcribed speech provided by
77
  <a href="https://digitalumuganda.com/" style="color: #2563eb;">Digital Umuganda</a> as part of the Kinyarwanda
78
  <a href="https://www.kaggle.com/competitions/kinyarwanda-automatic-speech-recognition-track-b" style="color: #2563eb;"> ASR hackathon</a> on Kaggle.
79
+ <br><br>
 
80
  Simply <strong>upload an audio file</strong> ๐Ÿ“ค or <strong>record yourself speaking</strong> ๐ŸŽ™๏ธโบ๏ธ to try out the model!
81
  </p>
82
  </div>
83
  </div>
84
+ """)
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ audio_input = gr.Audio(type="filepath", label="Upload Audio")
89
+ submit_btn = gr.Button("Transcribe Audio", variant="primary")
90
+
91
+ with gr.Column():
92
+ output_text = gr.Textbox(label="Text Transcription", lines=10)
93
+
94
+ submit_btn.click(
95
+ fn=process_audio,
96
+ inputs=[audio_input],
97
+ outputs=output_text
98
+ )
99
+
100
+ gr.Examples(
101
+ examples=examples if examples else None,
102
+ inputs=[audio_input],
103
+ )
104
 
105
+ # Launch the app
106
  if __name__ == "__main__":
107
+ demo.queue().launch()