gpaasch committed on
Commit
2d164e9
·
1 Parent(s): 05e7351

added a streaming Audio input component per the Gradio guide — using Whisper for real-time transcription and piping the resulting text into the chat flow

Browse files
Files changed (1) hide show
  1. src/app.py +37 -35
src/app.py CHANGED
@@ -1,11 +1,9 @@
1
  import os
2
  import gradio as gr
 
3
  from transformers import pipeline
4
  from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
5
- from llama_index.llm_predictor import HuggingFaceLLMPredictor, LLMPredictor
6
-
7
- # Optional OpenAI import remains for default predictor
8
- import openai
9
 
10
  # --- Whisper ASR setup ---
11
  asr = pipeline(
@@ -26,49 +24,53 @@ or, if you have enough info, output a final JSON with fields:
26
  {"diagnoses":[…], "confidences":[…]}.
27
  """
28
 
 
 
 
 
 
 
29
 
30
def transcribe_and_respond(audio, history):
    """Turn one recorded utterance into the next chat exchange.

    Transcribes *audio* with the Whisper pipeline, appends the user's
    turn to *history*, builds a flat text prompt from the system prompt
    plus the whole conversation, and asks the selected LLM predictor
    for a reply. A JSON-looking reply is treated as the final
    diagnosis; anything else is recorded as a follow-up question.

    Returns ("", history): an empty string (clears the mic input) and
    the updated conversation history.
    """
    if not history:
        history = []

    # Whisper ASR: audio -> text, recorded as the user's turn.
    history.append(("user", asr(audio)["text"]))

    # Flatten system prompt + conversation into one plain-text prompt,
    # ending with the assistant cue so the model continues from there.
    turns = [
        f"{role.capitalize()}: {text}"
        for role, text in [("system", SYSTEM_PROMPT), *history]
    ]
    prompt = "\n".join(turns) + "\nAssistant:"

    # get_llm_predictor() chooses OpenAI or a local model (defined elsewhere).
    resp = get_llm_predictor().predict(prompt)

    if resp.strip().startswith("{"):
        # JSON-style output => treat as the final structured diagnosis.
        result = query_symptoms(resp)
        history.append(("assistant", f"Here is your diagnosis: {result}"))
    else:
        # Plain text => the model is asking a follow-up question.
        history.append(("assistant", resp))

    return "", history
54
 
 
 
55
 
56
# --- Build Gradio app ---
with gr.Blocks() as demo:
    gr.Markdown("## Symptom to ICD-10 Diagnoser (audio & chat)")
    chatbot = gr.Chatbot(label="Conversation")
    mic = gr.Microphone(label="Describe your symptoms")
    state = gr.State([])

    def _on_mic_submit(audio, history):
        # Adapter: transcribe_and_respond returns only 2 values
        # ("", history), but this event wires 3 output components.
        # Duplicate the history so both the chatbot display and the
        # session state receive it; "" clears the mic input. Without
        # this adapter Gradio fails on the 2-vs-3 return-arity mismatch.
        cleared, new_history = transcribe_and_respond(audio, history)
        return cleared, new_history, new_history

    mic.submit(
        fn=_on_mic_submit,
        inputs=[mic, state],
        outputs=[mic, chatbot, state],
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        # NOTE(review): mcp_server requires Gradio's MCP extra — confirm
        # the installed version supports it.
        mcp_server=True,
    )
 
1
  import os
2
  import gradio as gr
3
+ import openai
4
  from transformers import pipeline
5
  from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
6
+ from llama_index.llm_predictor import HuggingFaceLLMPredictor
 
 
 
7
 
8
  # --- Whisper ASR setup ---
9
  asr = pipeline(
 
24
  {"diagnoses":[…], "confidences":[…]}.
25
  """
26
 
27
def transcribe_and_respond(audio_chunk, state):
    """Handle one streamed audio chunk: transcribe it and extend the chat.

    Parameters
    ----------
    audio_chunk : audio payload delivered by the streaming ``gr.Audio``
        component (a filepath here, given ``type="filepath"``).
    state : list of (role, message) tuples — the running conversation.

    Returns
    -------
    (chatbot_value, state): the updated conversation, returned twice so
    both the Chatbot display and the State component stay in sync.
    """
    # Transcribe the chunk with Whisper; skip silent/empty chunks.
    result = asr(audio_chunk)
    text = result.get('text', '').strip()
    if not text:
        # BUGFIX: the original returned (state, []), which re-displayed
        # the chat but silently wiped the accumulated State to [] on
        # every silent chunk. Return the state unchanged for both outputs.
        return state, state

    # Append user message
    state.append(("user", text))

    # Build LLM predictor (you can swap OpenAI / HuggingFace here).
    # NOTE(review): this constructs a fresh predictor per chunk — consider
    # caching it at module level; kept per-call to preserve behavior.
    llm_predictor = HuggingFaceLLMPredictor(
        model_name_or_path=os.getenv("HF_MODEL", "gpt2-medium")
    )

    # Query the vector index with the whole conversation as one prompt.
    # NOTE(review): `symptom_index` is not defined in this diff — it must
    # be created elsewhere in the module (a GPTVectorStoreIndex); confirm.
    prompt = "\n".join(f"{role}: {msg}" for role, msg in state)
    response = symptom_index.as_query_engine(
        llm_predictor=llm_predictor
    ).query(prompt)
    reply = response.response

    # Append assistant message
    state.append(("assistant", reply))

    # Return updated conversation to both chatbot and state outputs.
    return state, state
54
 
55
# --- Gradio interface: streaming mic input feeding the chat loop ---
with gr.Blocks() as demo:
    gr.Markdown("# Symptom to ICD-10 Code Lookup (Audio Input)")
    chatbot = gr.Chatbot(label="Conversation")
    state = gr.State([])

    # Streaming audio input for real-time transcription.
    # NOTE(review): `source=` is the Gradio 3.x kwarg (4.x uses
    # `sources=[...]`) — confirm against the pinned Gradio version.
    mic = gr.Audio(
        source="microphone",
        type="filepath",
        streaming=True,
        label="Describe your symptoms",
    )

    # Emit a chunk to transcribe_and_respond every 5 s, for at most 60 s
    # per stream, one stream at a time.
    mic.stream(
        fn=transcribe_and_respond,
        inputs=[mic, state],
        outputs=[chatbot, state],
        time_limit=60,
        stream_every=5,
        concurrency_limit=1,
    )

if __name__ == "__main__":
    # NOTE(review): mcp_server requires Gradio's MCP support — confirm
    # the installed version provides it.
    demo.launch(
        server_name="0.0.0.0", server_port=7860, mcp_server=True
    )