LiamKhoaLe committed on
Commit
927a9b8
1 Parent(s): eb6b193

Update MAC architecture with 4 supervised tasks for Gemini and responsive answers from MedSwin

Files changed (2)
  1. README.md +8 -1
  2. app.py +467 -371
README.md CHANGED
@@ -55,7 +55,14 @@ tags:
 
 ### 🎤 **Voice Features**
 - **Speech-to-Text**: Voice input transcription using Gemini MCP
-- **Text-to-Speech**: Voice output generation using Maya1 TTS model (optional, fallback to MCP if unavailable)
+- **Inline Mic Experience**: Built-in microphone widget with live recording timer that drops transcripts straight into the chat box
+- **Text-to-Speech**: Voice output generation using Maya1 TTS model (optional, fallback to MCP if unavailable) plus a one-click "Play Response" control for the latest answer
+
+### 🛡️ **Autonomous Guardrails**
+- **Gemini Supervisor Tasks**: Time-aware directives keep MedSwin within token budgets and can fast-track by skipping optional web search
+- **Self-Reflection Loop**: Gemini MCP scores complex answers and appends improvement hints when quality drops
+- **Automatic Citations**: Web-grounded replies include deduplicated source links from the latest search batch
+- **Deterministic Mode**: `Disable agentic reasoning` switch runs MedSwin alone for offline-friendly, model-only answers
 
 ### ⚙️ **Advanced Configuration**
 - Customizable generation parameters (temperature, top-p, top-k)
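For orientation, the guardrail bullets above all build on a JSON "breakdown" that the Gemini supervisor returns (see `gemini_supervisor_breakdown_async` in the app.py diff below). A minimal sketch of that structure, assuming nothing beyond the fallback values shipped in this commit; the single-task form is what the `Disable agentic reasoning` switch produces (the helper name `deterministic_breakdown` is hypothetical, used only for illustration):

```python
# Illustrative only: the supervisor breakdown shape consumed by MedSwin.
# The literals mirror the fallback breakdown defined in this commit.
FALLBACK_BREAKDOWN = {
    "sub_topics": [
        {"id": 1, "topic": "Core Question",
         "instruction": "Address the main medical question",
         "expected_tokens": 200, "priority": "high"},
        {"id": 2, "topic": "Clinical Details",
         "instruction": "Provide key clinical insights",
         "expected_tokens": 200, "priority": "medium"},
    ],
    "max_topics": 2,
    "strategy": "Sequential answer with key points",
}

def deterministic_breakdown(message: str) -> dict:
    """Single-task breakdown used when agentic reasoning is disabled."""
    return {
        "sub_topics": [
            {"id": 1, "topic": "Answer", "instruction": message,
             "expected_tokens": 400, "priority": "high"}
        ],
        "max_topics": 1,
        "strategy": "Direct answer",
    }
```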
app.py CHANGED
@@ -1170,37 +1170,48 @@ def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, us
 
     return strategy
 
-async def gemini_supervisor_directives_async(query: str, reasoning: dict, plan: dict, time_elapsed: float, max_duration: int = 120) -> dict:
-    """Request supervisor-style task breakdown from Gemini MCP"""
     remaining_time = max(15, max_duration - time_elapsed)
-    plan_json = json.dumps(plan)
-    reasoning_json = json.dumps(reasoning)
 
-    prompt = f"""You supervise a MedSwin medical specialist model that has a limited output window (~800 tokens).
-Break the following medical query into concise sequential tasks so MedSwin can answer step-by-step.
 
 Query: "{query}"
-Reasoning Analysis: {reasoning_json}
-Existing Plan: {plan_json}
-Time Remaining (soft limit): ~{remaining_time:.1f}s (hard limit {max_duration}s). Avoid more than 3 tasks.
 
-Return JSON with:
 {{
-  "overall_strategy": "short summary of approach (<=200 chars)",
-  "tasks": [
-    {{"id": 1, "instruction": "concrete directive for MedSwin", "expected_tokens": 200, "challenge": "optional short challenge to double-check"}},
   ...
 ],
-  "fast_track": true/false,  # true if remaining_time < 25s
-  "escalation_prompt": "optional single-line reminder to wrap up quickly if time is low"
 }}
 
-Ensure tasks reference medical reasoning and are ordered so MedSwin can execute one-by-one."""
 
-    system_prompt = (
-        "You are Gemini MCP supervising a constrained MedSwin model. "
-        "Produce structured JSON that keeps MedSwin focused and concise."
-    )
 
     response = await call_agent(
         user_prompt=prompt,
@@ -1210,38 +1221,157 @@ Ensure tasks reference medical reasoning and are ordered so MedSwin can execute
1210
  )
1211
 
1212
  try:
1213
- start = response.find('{')
1214
- end = response.rfind('}') + 1
1215
- if start >= 0 and end > start:
1216
- directives = json.loads(response[start:end])
1217
  else:
1218
  raise ValueError("Supervisor JSON not found")
1219
  except Exception as exc:
1220
- logger.error(f"Supervisor directive parsing failed: {exc}")
1221
- directives = {
1222
- "overall_strategy": "Address tasks sequentially with concise clinical bullets.",
1223
- "tasks": [
1224
- {"id": 1, "instruction": "Summarize the patient's core question.", "expected_tokens": 120},
1225
- {"id": 2, "instruction": "List key clinical insights or differential items.", "expected_tokens": 200},
1226
- {"id": 3, "instruction": "Deliver final guidance and follow-up recommendations.", "expected_tokens": 180},
1227
  ],
1228
- "fast_track": remaining_time < 25,
1229
- "escalation_prompt": "Wrap up immediately if time is almost over."
1230
  }
1231
- return directives
1232
 
1233
- def gemini_supervisor_directives(query: str, reasoning: dict, plan: dict, time_elapsed: float, max_duration: int = 120) -> dict:
1234
- """Wrapper to obtain supervisor directives synchronously"""
1235
  if not MCP_AVAILABLE:
1236
- logger.warning("Gemini MCP unavailable for supervisor directives, using fallback.")
1237
  return {
1238
- "overall_strategy": "Follow the internal plan sequentially with concise sections.",
1239
- "tasks": [
1240
- {"id": step_idx + 1, "instruction": step.get("description", step.get("action", "")), "expected_tokens": 180}
1241
- for step_idx, step in enumerate(plan.get("steps", [])[:3])
1242
  ],
1243
- "fast_track": False,
1244
- "escalation_prompt": ""
1245
  }
1246
 
1247
  try:
@@ -1250,54 +1380,166 @@ def gemini_supervisor_directives(query: str, reasoning: dict, plan: dict, time_e
1250
  try:
1251
  import nest_asyncio
1252
  return nest_asyncio.run(
1253
- gemini_supervisor_directives_async(query, reasoning, plan, time_elapsed, max_duration)
1254
  )
1255
  except Exception as exc:
1256
- logger.error(f"Nested supervisor directive execution failed: {exc}")
1257
  raise
1258
  return loop.run_until_complete(
1259
- gemini_supervisor_directives_async(query, reasoning, plan, time_elapsed, max_duration)
1260
  )
1261
  except Exception as exc:
1262
- logger.error(f"Supervisor directive request failed: {exc}")
1263
  return {
1264
- "overall_strategy": "Provide a concise clinical answer with numbered mini-sections.",
1265
- "tasks": [
1266
- {"id": 1, "instruction": "Clarify the medical problem and relevant context.", "expected_tokens": 150},
1267
- {"id": 2, "instruction": "Give evidence-based assessment or reasoning.", "expected_tokens": 200},
1268
- {"id": 3, "instruction": "State actionable guidance and cautions.", "expected_tokens": 150},
1269
  ],
1270
- "fast_track": True,
1271
- "escalation_prompt": "Deliver the final summary immediately if time is nearly done."
1272
  }
1273
 
1274
- def format_supervisor_directives_text(directives: dict) -> str:
1275
- """Convert supervisor directive dict into prompt-friendly text"""
1276
- if not directives:
1277
- return ""
1278
 
1279
- lines = []
1280
- overall = directives.get("overall_strategy")
1281
- if overall:
1282
- lines.append(f"Supervisor Strategy: {overall}")
1283
-
1284
- tasks = directives.get("tasks") or []
1285
- for task in tasks:
1286
- task_id = task.get("id")
1287
- instruction = task.get("instruction", "").strip()
1288
- challenge = task.get("challenge", "").strip()
1289
- expected_tokens = task.get("expected_tokens", 180)
1290
- if instruction:
1291
- task_line = f"{task_id}. {instruction} (target ≤{expected_tokens} tokens)"
1292
- if challenge:
1293
- task_line += f" | Challenge: {challenge}"
1294
- lines.append(task_line)
1295
-
1296
- escalation = directives.get("escalation_prompt")
1297
- if escalation:
1298
- lines.append(f"Escalation: {escalation}")
1299
-
1300
- return "\n".join(lines)
1301
 
1302
  async def self_reflection_gemini(answer: str, query: str) -> dict:
1303
  """Self-reflection using Gemini MCP"""
@@ -1615,104 +1857,72 @@ def stream_chat(
1615
  index_dir = f"./{user_id}_index"
1616
  has_rag_index = os.path.exists(index_dir)
1617
 
1618
- supervisor_directives = None
1619
- supervisor_directives_text = ""
1620
- time_pressure_flag = False
1621
- time_pressure_message = ""
1622
 
1623
- # If agentic reasoning is disabled, skip all reasoning/planning and use MedSwin model alone
1624
- if disable_agentic_reasoning:
1625
- logger.info("🚫 Agentic reasoning disabled - using MedSwin model alone")
1626
- reasoning = None
1627
- plan = None
1628
- execution_strategy = None
1629
- final_use_rag = False # Disable RAG when agentic reasoning is disabled
1630
- final_use_web_search = False # Disable web search when agentic reasoning is disabled
1631
- reasoning_note = ""
1632
- original_lang = detect_language(message)
1633
- original_message = message
1634
- needs_translation = False # Skip translation when agentic reasoning is disabled
1635
- else:
1636
- # ===== AUTONOMOUS REASONING =====
1637
- logger.info("🤔 Starting autonomous reasoning...")
1638
- reasoning = autonomous_reasoning(message, history)
1639
-
1640
- # ===== PLANNING =====
1641
- logger.info("📋 Creating execution plan...")
1642
- plan = create_execution_plan(reasoning, message, has_rag_index)
1643
-
1644
- # ===== AUTONOMOUS EXECUTION STRATEGY =====
1645
- logger.info("🎯 Determining execution strategy...")
1646
- execution_strategy = autonomous_execution_strategy(reasoning, plan, use_rag, use_web_search, has_rag_index)
1647
-
1648
- # Use autonomous strategy decisions (respect user's RAG setting and user toggles)
1649
- final_use_rag = execution_strategy["use_rag"] and has_rag_index # Only use RAG if enabled AND documents exist
1650
- final_use_web_search = execution_strategy["use_web_search"]
1651
-
1652
- reasoning_note = execution_strategy.get("rationale", "")
1653
- if reasoning_note:
1654
- logger.info(f"Autonomous reasoning note: {reasoning_note}")
1655
-
1656
- supervisor_directives = gemini_supervisor_directives(
1657
- message,
1658
- reasoning,
1659
- plan,
1660
- elapsed(),
1661
- max_duration=120
1662
- )
1663
- supervisor_directives_text = format_supervisor_directives_text(supervisor_directives)
1664
- if supervisor_directives_text:
1665
- logger.info(f"Gemini Supervisor Tasks:\n{supervisor_directives_text}")
1666
-
1667
- if supervisor_directives.get("fast_track"):
1668
- logger.info("⚡ Supervisor requested fast-track execution to respect time budget.")
1669
- final_use_web_search = False # Skip optional web search when supervisor requests fast track
1670
- logger.info("⚡ Supervisor: Fast-track requested due to limited time. Prioritizing concise synthesis.")
1671
-
1672
-
1673
- # Detect language and translate if needed (Step 1 of plan)
1674
  original_lang = detect_language(message)
1675
  original_message = message
1676
  needs_translation = original_lang != "en"
1677
 
1678
  if needs_translation:
1679
- logger.info(f"Detected non-English language: {original_lang}, translating to English...")
1680
  message = translate_text(message, target_lang="en", source_lang=original_lang)
1681
- logger.info(f"Translated query: {message}")
1682
 
1683
- # Initialize medical model
1684
- medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
 
1685
 
1686
- # Adjust system prompt based on RAG setting and reasoning
1687
  if disable_agentic_reasoning:
1688
- # Simple system prompt when agentic reasoning is disabled
1689
- base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers."
1690
- elif final_use_rag:
1691
- base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers based on the provided medical documents and context."
  else:
1693
- base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only."
 
 
1694
 
1695
- # Add reasoning context to system prompt for complex queries (only if reasoning is enabled)
1696
- if not disable_agentic_reasoning and reasoning and reasoning.get("complexity") in ["complex", "multi_faceted"]:
1697
- base_system_prompt += f"\n\nQuery Analysis: This is a {reasoning['complexity']} {reasoning['query_type']} query. Address all sub-questions: {', '.join(reasoning.get('sub_questions', [])[:3])}"
1698
-
1699
- if supervisor_directives_text:
1700
- base_system_prompt += (
1701
- f"\n\nGemini Supervisor Directives:\n{supervisor_directives_text}\n"
1702
- "Execute the tasks one-by-one, keeping each section within the suggested token budget. "
1703
- "If time becomes tight, summarize remaining insights immediately."
1704
- )
1705
-
1706
- # ===== EXECUTION: RAG Retrieval (Step 2) =====
1707
- rag_context = ""
1708
- source_info = ""
  if final_use_rag and has_rag_index:
1710
  if elapsed() >= soft_timeout - 10:
1711
- logger.warning("⏱️ Skipping RAG retrieval due to time pressure.")
1712
- time_pressure_flag = True
1713
- time_pressure_message = "Skipped some retrieval steps to finish within the time limit."
1714
  final_use_rag = False
1715
  else:
 
1716
  embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
1717
  Settings.embed_model = embed_model
1718
  storage_context = StorageContext.from_defaults(persist_dir=index_dir)
@@ -1722,178 +1932,37 @@ def stream_chat(
1722
  base_retriever,
1723
  storage_context=storage_context,
1724
  simple_ratio_thresh=merge_threshold,
1725
- verbose=True
1726
  )
1727
- logger.info(f"Query: {message}")
1728
- retrieval_start = time.time()
1729
  merged_nodes = auto_merging_retriever.retrieve(message)
1730
- logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - retrieval_start:.2f}s")
1731
- merged_file_sources = {}
1732
- for node in merged_nodes:
1733
- if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
1734
- file_name = node.node.metadata['file_name']
1735
- if file_name not in merged_file_sources:
1736
- merged_file_sources[file_name] = 0
1737
- merged_file_sources[file_name] += 1
1738
- logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
1739
- rag_context = "\n\n".join([n.node.text for n in merged_nodes])
1740
- if merged_file_sources:
1741
- source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
1742
-
1743
- # ===== EXECUTION: Web Search (Step 3) =====
1744
- web_context = ""
1745
- web_sources = []
1746
- web_urls = [] # Store URLs for citations
1747
- if final_use_web_search:
1748
- if elapsed() >= soft_timeout - 5:
1749
- logger.warning("⏱️ Skipping web search to stay within execution window.")
1750
- time_pressure_flag = True
1751
- time_pressure_message = "Web search skipped due to time constraints."
1752
- final_use_web_search = False
1753
- else:
1754
- logger.info("🌐 Performing web search (will use Gemini MCP for summarization)...")
1755
- web_results = search_web(message, max_results=5)
1756
- if web_results:
1757
- logger.info(f"📊 Found {len(web_results)} web search results, now summarizing with Gemini MCP...")
1758
- web_summary = summarize_web_content(web_results, message)
1759
- if web_summary and len(web_summary) > 50: # Check if we got a real summary
1760
- web_context = f"\n\nAdditional Web Sources (summarized with Gemini MCP):\n{web_summary}"
1761
- else:
1762
- # Fallback: use first result's content
1763
- web_context = f"\n\nAdditional Web Sources:\n{web_results[0].get('content', '')[:500]}"
1764
- web_sources = [r['title'] for r in web_results[:3]]
1765
- # Extract unique URLs for citations
1766
- web_urls = [r.get('url', '') for r in web_results if r.get('url')]
1767
- logger.info(f"✅ Web search completed: {len(web_results)} results, summarized with Gemini MCP")
1768
- else:
1769
- logger.warning("⚠️ Web search returned no results")
1770
-
1771
- # Build final context
1772
- context_parts = []
1773
- if rag_context:
1774
- context_parts.append(f"Document Context:\n{rag_context}")
1775
- if web_context:
1776
- context_parts.append(web_context)
1777
-
1778
- full_context = "\n\n".join(context_parts) if context_parts else ""
1779
-
1780
- # Build system prompt
1781
- if final_use_rag or final_use_web_search:
1782
- formatted_system_prompt = f"{base_system_prompt}\n\n{full_context}{source_info}"
1783
- else:
1784
- formatted_system_prompt = base_system_prompt
1785
-
1786
- # Prepare messages
1787
- messages = [{"role": "system", "content": formatted_system_prompt}]
1788
- for entry in history:
1789
- messages.append(entry)
1790
- messages.append({"role": "user", "content": message})
1791
-
1792
- # Get EOS token and adjust stopping criteria
1793
- eos_token_id = medical_tokenizer.eos_token_id
1794
- if eos_token_id is None:
1795
- eos_token_id = medical_tokenizer.pad_token_id
1796
-
1797
- # Increase max tokens for medical models (prevent early stopping)
1798
- max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
1799
- max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers
1800
-
1801
- # Check if tokenizer has chat template, otherwise format manually
1802
- if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
1803
- try:
1804
- prompt = medical_tokenizer.apply_chat_template(
1805
- messages,
1806
- tokenize=False,
1807
- add_generation_prompt=True
1808
- )
1809
- except Exception as e:
1810
- logger.warning(f"Chat template failed, using manual formatting: {e}")
1811
- # Fallback to manual formatting
1812
- prompt = format_prompt_manually(messages, medical_tokenizer)
1813
- else:
1814
- # Manual formatting for models without chat template
1815
- prompt = format_prompt_manually(messages, medical_tokenizer)
1816
-
1817
- inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
1818
- prompt_length = inputs['input_ids'].shape[1]
1819
-
1820
- stop_event = threading.Event()
1821
-
1822
- class StopOnEvent(StoppingCriteria):
1823
- def __init__(self, stop_event):
1824
- super().__init__()
1825
- self.stop_event = stop_event
1826
-
1827
- def __call__(self, input_ids, scores, **kwargs):
1828
- return self.stop_event.is_set()
1829
-
1830
- # Custom stopping criteria that doesn't stop on EOS too early
1831
- class MedicalStoppingCriteria(StoppingCriteria):
1832
- def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
1833
- super().__init__()
1834
- self.eos_token_id = eos_token_id
1835
- self.prompt_length = prompt_length
1836
- self.min_new_tokens = min_new_tokens
1837
-
1838
- def __call__(self, input_ids, scores, **kwargs):
1839
- current_length = input_ids.shape[1]
1840
- new_tokens = current_length - self.prompt_length
1841
- last_token = input_ids[0, -1].item()
1842
 
1843
- # Don't stop on EOS if we haven't generated enough new tokens
1844
- if new_tokens < self.min_new_tokens:
1845
- return False
1846
- # Allow EOS after minimum new tokens have been generated
1847
- return last_token == self.eos_token_id
1848
-
1849
- stopping_criteria = StoppingCriteriaList([
1850
- StopOnEvent(stop_event),
1851
- MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
1852
- ])
1853
-
1854
- def monitor_timeout():
1855
- nonlocal time_pressure_flag, time_pressure_message
1856
- while not stop_event.is_set():
1857
- current_elapsed = elapsed()
1858
- if current_elapsed >= hard_timeout:
1859
- logger.warning("⏳ Hard timeout reached – stopping generation thread.")
1860
- if not time_pressure_flag:
1861
- time_pressure_flag = True
1862
- if not time_pressure_message:
1863
- time_pressure_message = "Stopped early to respect the 120s execution window."
1864
- stop_event.set()
1865
- break
1866
- time.sleep(0.5)
1867
 
1868
- streamer = TextIteratorStreamer(
1869
- medical_tokenizer,
1870
- skip_prompt=True,
1871
- skip_special_tokens=True
1872
- )
1873
 
1874
- temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
1875
- top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
1876
- top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
1877
- penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
1878
 
1879
- generation_kwargs = dict(
1880
- inputs,
1881
- streamer=streamer,
1882
- max_new_tokens=max_new_tokens,
1883
- temperature=temperature,
1884
- top_p=top_p,
1885
- top_k=top_k,
1886
- repetition_penalty=penalty,
1887
- do_sample=True,
1888
- stopping_criteria=stopping_criteria,
1889
- eos_token_id=eos_token_id,
1890
- pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
1891
- )
1892
 
1893
- thread = threading.Thread(target=medical_model_obj.generate, kwargs=generation_kwargs)
1894
- thread.start()
1895
- timeout_thread = threading.Thread(target=monitor_timeout, daemon=True)
1896
- timeout_thread.start()
1897
 
1898
  updated_history = history + [
1899
  {"role": "user", "content": original_message},
@@ -1901,73 +1970,100 @@ def stream_chat(
1901
  ]
1902
  yield updated_history
1903
 
1904
- partial_response = ""
1905
- try:
1906
- for new_text in streamer:
1907
- partial_response += new_text
1908
- updated_history[-1]["content"] = partial_response
1909
- yield updated_history
1910
-
1911
- if not time_pressure_flag and elapsed() >= soft_timeout:
1912
- logger.warning("⏱️ Soft timeout reached – finalizing response.")
1913
- time_pressure_flag = True
1914
- if not time_pressure_message:
1915
- time_pressure_message = "Soft timeout reached. Delivering final answer early."
1916
- stop_event.set()
1917
  break
1918
 
1919
- # ===== SELF-REFLECTION (Step 6) =====
1920
- if not disable_agentic_reasoning and reasoning and reasoning.get("complexity") in ["complex", "multi_faceted"]:
1921
- logger.info("🔍 Performing self-reflection on answer quality...")
1922
- reflection = self_reflection(partial_response, message, reasoning)
1923
-
1924
- # Add reflection note if score is low or improvements suggested
1925
- if reflection.get("overall_score", 10) < 7 or reflection.get("improvement_suggestions"):
1926
- reflection_note = f"\n\n---\n**Self-Reflection** (Score: {reflection.get('overall_score', 'N/A')}/10)"
1927
- if reflection.get("improvement_suggestions"):
1928
- reflection_note += f"\n💡 Suggestions: {', '.join(reflection['improvement_suggestions'][:2])}"
1929
- partial_response += reflection_note
1930
- updated_history[-1]["content"] = partial_response
1931
 
1932
- # Add reasoning note if autonomous override occurred
1933
- # Internal planning notes stay in logs only; nothing is prepended to the user response
1935
  # Translate back if needed
1936
- if needs_translation and partial_response:
1937
- logger.info(f"Translating response back to {original_lang}...")
1938
- translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
1939
- partial_response = translated_response
1940
 
1941
  # Add citations if web sources were used
1942
  citations_text = ""
1943
  if web_urls:
1944
- # Get unique domains
1945
  unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates
1946
  citation_links = []
1947
  for url in unique_urls[:5]: # Limit to 5 citations
1948
  domain = format_url_as_domain(url)
1949
  if domain:
1950
- # Create markdown link: [domain](url)
1951
  citation_links.append(f"[{domain}]({url})")
1952
 
1953
  if citation_links:
1954
  citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
1955
 
1956
- if time_pressure_flag and time_pressure_message:
1957
- partial_response += f"\n\n⏱️ {time_pressure_message}"
1958
-
1959
- # Add speaker icon and citations to assistant message
1960
  speaker_icon = ' 🔊'
1961
- partial_response_with_speaker = partial_response + citations_text + speaker_icon
1962
- updated_history[-1]["content"] = partial_response_with_speaker
1963
 
1964
- stop_event.set() # Ensure timeout monitor thread exits once response is finalized
 
1965
  yield updated_history
1966
 
1967
- except GeneratorExit:
1968
- stop_event.set()
1969
- thread.join()
1970
- raise
1971
 
1972
  def generate_speech_for_message(text: str):
1973
  """Generate speech for a message and return audio file"""
 
1170
 
1171
  return strategy
1172
 
1173
+ async def gemini_supervisor_breakdown_async(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
1174
+ """
1175
+ Gemini Supervisor: Break user query into 2-4 sub-topics (JSON format)
1176
+ This is the main supervisor function that orchestrates the MAC architecture.
1177
+ All internal thoughts are logged, not displayed.
1178
+ """
1179
  remaining_time = max(15, max_duration - time_elapsed)
 
 
1180
 
1181
+ mode_description = []
1182
+ if use_rag:
1183
+ mode_description.append("RAG mode enabled - will use retrieved documents")
1184
+ if use_web_search:
1185
+ mode_description.append("Web search mode enabled - will search online sources")
1186
+ if not mode_description:
1187
+ mode_description.append("Direct answer mode - no additional context")
1188
+
1189
+ prompt = f"""You are a supervisor agent coordinating with a MedSwin medical specialist model.
1190
+ Break the following medical query into 2-4 focused sub-topics that MedSwin can answer sequentially.
1191
 
1192
  Query: "{query}"
1193
+ Mode: {', '.join(mode_description)}
1194
+ Time Remaining: ~{remaining_time:.1f}s
 
1195
 
1196
+ Return ONLY valid JSON (no markdown, no tables, no explanations):
1197
  {{
1198
+ "sub_topics": [
1199
+ {{
1200
+ "id": 1,
1201
+ "topic": "concise topic name",
1202
+ "instruction": "specific directive for MedSwin to answer this topic",
1203
+ "expected_tokens": 200,
1204
+ "priority": "high|medium|low"
1205
+ }},
1206
  ...
1207
  ],
1208
+ "max_topics": 4,
1209
+ "strategy": "brief strategy description"
1210
  }}
1211
 
1212
+ Keep topics focused and actionable. Each topic should be answerable in ~200 tokens by MedSwin."""
1213
 
1214
+ system_prompt = "You are a medical query supervisor. Break queries into structured JSON sub-topics. Return ONLY valid JSON."
 
 
 
1215
 
1216
  response = await call_agent(
1217
  user_prompt=prompt,
 
1221
  )
1222
 
1223
  try:
1224
+ # Extract JSON from response
1225
+ json_start = response.find('{')
1226
+ json_end = response.rfind('}') + 1
1227
+ if json_start >= 0 and json_end > json_start:
1228
+ breakdown = json.loads(response[json_start:json_end])
1229
+ logger.info(f"[GEMINI SUPERVISOR] Query broken into {len(breakdown.get('sub_topics', []))} sub-topics")
1230
+ logger.debug(f"[GEMINI SUPERVISOR] Breakdown: {json.dumps(breakdown, indent=2)}")
1231
+ return breakdown
1232
  else:
1233
  raise ValueError("Supervisor JSON not found")
1234
  except Exception as exc:
1235
+ logger.error(f"[GEMINI SUPERVISOR] Breakdown parsing failed: {exc}")
1236
+ # Fallback: simple breakdown
1237
+ breakdown = {
1238
+ "sub_topics": [
1239
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high"},
1240
+ {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium"},
1241
+ ],
1242
+ "max_topics": 2,
1243
+ "strategy": "Sequential answer with key points"
1244
+ }
1245
+ logger.warning(f"[GEMINI SUPERVISOR] Using fallback breakdown")
1246
+ return breakdown
1247
+
1248
+ async def gemini_supervisor_search_strategies_async(query: str, time_elapsed: float) -> dict:
1249
+ """
1250
+ Gemini Supervisor: In search mode, break query into 1-4 searching strategies
1251
+ Returns JSON with search strategies that will be executed with ddgs
1252
+ """
1253
+ prompt = f"""You are supervising web search for a medical query.
1254
+ Break this query into 1-4 focused search strategies (each targeting 1-2 sources).
1255
+
1256
+ Query: "{query}"
1257
+
1258
+ Return ONLY valid JSON:
1259
+ {{
1260
+ "search_strategies": [
1261
+ {{
1262
+ "id": 1,
1263
+ "strategy": "search query string",
1264
+ "target_sources": 1,
1265
+ "focus": "what to search for"
1266
+ }},
1267
+ ...
1268
+ ],
1269
+ "max_strategies": 4
1270
+ }}
1271
+
1272
+ Keep strategies focused and avoid overlap."""
1273
+
1274
+ system_prompt = "You are a search strategy supervisor. Create focused search queries. Return ONLY valid JSON."
1275
+
1276
+ response = await call_agent(
1277
+ user_prompt=prompt,
1278
+ system_prompt=system_prompt,
1279
+ model=GEMINI_MODEL_LITE, # Use lite model for search planning
1280
+ temperature=0.2
1281
+ )
1282
+
1283
+ try:
1284
+ json_start = response.find('{')
1285
+ json_end = response.rfind('}') + 1
1286
+ if json_start >= 0 and json_end > json_start:
1287
+ strategies = json.loads(response[json_start:json_end])
1288
+ logger.info(f"[GEMINI SUPERVISOR] Created {len(strategies.get('search_strategies', []))} search strategies")
1289
+ logger.debug(f"[GEMINI SUPERVISOR] Strategies: {json.dumps(strategies, indent=2)}")
1290
+ return strategies
1291
+ else:
1292
+ raise ValueError("Search strategies JSON not found")
1293
+ except Exception as exc:
1294
+ logger.error(f"[GEMINI SUPERVISOR] Search strategies parsing failed: {exc}")
1295
+ return {
1296
+ "search_strategies": [
1297
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
1298
  ],
1299
+ "max_strategies": 1
 
1300
  }
 
1301
 
1302
+ async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
1303
+ """
1304
+ Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts
1305
+ These contexts will be passed to MedSwin to support decision-making
1306
+ """
1307
+ # Limit retrieved docs to avoid token overflow
1308
+ max_doc_length = 3000
1309
+ if len(retrieved_docs) > max_doc_length:
1310
+ retrieved_docs = retrieved_docs[:max_doc_length] + "..."
1311
+
1312
+ prompt = f"""You are supervising RAG context preparation for a medical query.
1313
+ Brainstorm the retrieved documents into 1-4 concise, focused contexts that MedSwin can use.
1314
+
1315
+ Query: "{query}"
1316
+ Retrieved Documents:
1317
+ {retrieved_docs}
1318
+
1319
+ Return ONLY valid JSON:
1320
+ {{
1321
+ "contexts": [
1322
+ {{
1323
+ "id": 1,
1324
+ "context": "concise summary of relevant information (keep under 500 chars)",
1325
+ "focus": "what this context covers",
1326
+ "relevance": "high|medium|low"
1327
+ }},
1328
+ ...
1329
+ ],
1330
+ "max_contexts": 4
1331
+ }}
1332
+
1333
+ Keep contexts brief and factual. Avoid redundancy."""
1334
+
1335
+ system_prompt = "You are a RAG context supervisor. Summarize documents into concise contexts. Return ONLY valid JSON."
1336
+
1337
+ response = await call_agent(
1338
+ user_prompt=prompt,
1339
+ system_prompt=system_prompt,
1340
+ model=GEMINI_MODEL_LITE, # Use lite model for RAG brainstorming
1341
+ temperature=0.2
1342
+ )
1343
+
1344
+ try:
1345
+ json_start = response.find('{')
1346
+ json_end = response.rfind('}') + 1
1347
+ if json_start >= 0 and json_end > json_start:
1348
+ contexts = json.loads(response[json_start:json_end])
1349
+ logger.info(f"[GEMINI SUPERVISOR] Brainstormed {len(contexts.get('contexts', []))} RAG contexts")
1350
+ logger.debug(f"[GEMINI SUPERVISOR] Contexts: {json.dumps(contexts, indent=2)}")
1351
+ return contexts
1352
+ else:
1353
+ raise ValueError("RAG contexts JSON not found")
1354
+ except Exception as exc:
1355
+ logger.error(f"[GEMINI SUPERVISOR] RAG brainstorming parsing failed: {exc}")
1356
+ # Fallback: use retrieved docs as single context
1357
+ return {
1358
+ "contexts": [
1359
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
1360
+ ],
1361
+ "max_contexts": 1
1362
+ }
1363
+
1364
+ def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
1365
+ """Wrapper to obtain supervisor breakdown synchronously"""
1366
  if not MCP_AVAILABLE:
1367
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable, using fallback breakdown")
1368
  return {
1369
+ "sub_topics": [
1370
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high"},
1371
+ {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium"},
 
1372
  ],
1373
+ "max_topics": 2,
1374
+ "strategy": "Sequential answer with key points"
1375
  }
1376
 
1377
  try:
 
1380
  try:
1381
  import nest_asyncio
1382
  return nest_asyncio.run(
1383
+ gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
1384
  )
1385
  except Exception as exc:
1386
+ logger.error(f"[GEMINI SUPERVISOR] Nested breakdown execution failed: {exc}")
1387
  raise
1388
  return loop.run_until_complete(
1389
+ gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
1390
  )
1391
  except Exception as exc:
1392
+ logger.error(f"[GEMINI SUPERVISOR] Breakdown request failed: {exc}")
1393
  return {
1394
+ "sub_topics": [
1395
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high"},
 
 
 
1396
  ],
1397
+ "max_topics": 1,
1398
+ "strategy": "Direct answer"
1399
  }
1400
 
1401
+ def gemini_supervisor_search_strategies(query: str, time_elapsed: float) -> dict:
1402
+ """Wrapper to obtain search strategies synchronously"""
1403
+ if not MCP_AVAILABLE:
1404
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable for search strategies")
1405
+ return {
1406
+ "search_strategies": [
1407
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
1408
+ ],
1409
+ "max_strategies": 1
1410
+ }
1411
+
1412
+ try:
1413
+ loop = asyncio.get_event_loop()
1414
+ if loop.is_running():
1415
+ try:
1416
+ import nest_asyncio
1417
+ return nest_asyncio.run(gemini_supervisor_search_strategies_async(query, time_elapsed))
1418
+ except Exception as exc:
1419
+ logger.error(f"[GEMINI SUPERVISOR] Nested search strategies execution failed: {exc}")
1420
+ return {
1421
+ "search_strategies": [
1422
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
1423
+ ],
1424
+ "max_strategies": 1
1425
+ }
1426
+ return loop.run_until_complete(gemini_supervisor_search_strategies_async(query, time_elapsed))
1427
+ except Exception as exc:
1428
+ logger.error(f"[GEMINI SUPERVISOR] Search strategies request failed: {exc}")
1429
+ return {
1430
+ "search_strategies": [
1431
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
1432
+ ],
1433
+ "max_strategies": 1
1434
+ }
1435
+
1436
+ def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
1437
+ """Wrapper to obtain RAG brainstorm synchronously"""
1438
+ if not MCP_AVAILABLE:
1439
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable for RAG brainstorm")
1440
+ return {
1441
+ "contexts": [
1442
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
1443
+ ],
1444
+ "max_contexts": 1
1445
+ }
1446
+
1447
+ try:
1448
+ loop = asyncio.get_event_loop()
1449
+ if loop.is_running():
1450
+ try:
1451
+ import nest_asyncio
1452
+ return nest_asyncio.run(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
1453
+ except Exception as exc:
1454
+ logger.error(f"[GEMINI SUPERVISOR] Nested RAG brainstorm execution failed: {exc}")
1455
+ return {
1456
+ "contexts": [
1457
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
1458
+ ],
1459
+ "max_contexts": 1
1460
+ }
1461
+ return loop.run_until_complete(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
1462
+ except Exception as exc:
1463
+ logger.error(f"[GEMINI SUPERVISOR] RAG brainstorm request failed: {exc}")
1464
+ return {
1465
+ "contexts": [
1466
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
1467
+ ],
1468
+ "max_contexts": 1
1469
+ }
1470
+
1471
+ @spaces.GPU(max_duration=120)
1472
+ def execute_medswin_task(
1473
+ medical_model_obj,
1474
+ medical_tokenizer,
1475
+ task_instruction: str,
1476
+ context: str,
1477
+ system_prompt_base: str,
1478
+ temperature: float,
1479
+ max_new_tokens: int,
1480
+ top_p: float,
1481
+ top_k: int,
1482
+ penalty: float
1483
+ ) -> str:
1484
+ """
1485
+ MedSwin Specialist: Execute a single task assigned by Gemini Supervisor
1486
+ This function is tagged with @spaces.GPU to run on GPU (ZeroGPU equivalent)
1487
+ All internal thoughts are logged, only final answer is returned
1488
+ """
1489
+ # Build task-specific prompt
1490
+ if context:
1491
+ full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
1492
+ else:
1493
+ full_prompt = f"{system_prompt_base}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
1494
+
1495
+ messages = [{"role": "system", "content": full_prompt}]
1496
+
1497
+ # Format prompt
1498
+ if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
1499
+ try:
1500
+ prompt = medical_tokenizer.apply_chat_template(
1501
+ messages,
1502
+ tokenize=False,
1503
+ add_generation_prompt=True
1504
+ )
1505
+ except Exception as e:
1506
+ logger.warning(f"[MEDSWIN] Chat template failed, using manual formatting: {e}")
1507
+ prompt = format_prompt_manually(messages, medical_tokenizer)
1508
+ else:
1509
+ prompt = format_prompt_manually(messages, medical_tokenizer)
1510
+
1511
+ # Tokenize and generate
1512
+ inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
1513
 
1514
+ eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
1515
+
1516
+ with torch.no_grad():
1517
+ outputs = medical_model_obj.generate(
1518
+ **inputs,
1519
+ max_new_tokens=min(max_new_tokens, 800), # Limit per task
1520
+ temperature=temperature,
1521
+ top_p=top_p,
1522
+ top_k=top_k,
1523
+ repetition_penalty=penalty,
1524
+ do_sample=True,
1525
+ eos_token_id=eos_token_id,
1526
+ pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
1527
+ )
1528
+
1529
+ # Decode response
1530
+ response = medical_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
1531
+
1532
+ # Clean response - remove any table-like formatting, ensure Markdown bullets
1533
+ response = response.strip()
1534
+ # Remove table markers if present
1535
+ if "|" in response and "---" in response:
1536
+ logger.warning("[MEDSWIN] Detected table format, converting to Markdown bullets")
1537
+ # Simple conversion: split by lines and convert to bullets
1538
+ lines = [line.strip() for line in response.split('\n') if line.strip() and not line.strip().startswith('|') and '---' not in line]
1539
+ response = '\n'.join([f"- {line}" if not line.startswith('-') else line for line in lines])
1540
+
1541
+ logger.info(f"[MEDSWIN] Task completed: {len(response)} chars generated")
1542
+ return response
1543
 
1544
  async def self_reflection_gemini(answer: str, query: str) -> dict:
1545
  """Self-reflection using Gemini MCP"""
 
1857
  index_dir = f"./{user_id}_index"
1858
  has_rag_index = os.path.exists(index_dir)
1859
 
1860
+ # ===== MAC ARCHITECTURE: GEMINI SUPERVISOR + MEDSWIN SPECIALIST =====
1861
+ # All internal thoughts are logged, only final answer is displayed
1863
  original_lang = detect_language(message)
1864
  original_message = message
1865
  needs_translation = original_lang != "en"
1866
 
1867
  if needs_translation:
1868
+ logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
1869
  message = translate_text(message, target_lang="en", source_lang=original_lang)
1870
+ logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
1871
 
1872
+ # Determine final modes (respect user settings and availability)
1873
+ final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
1874
+ final_use_web_search = use_web_search and not disable_agentic_reasoning
1875
 
1876
+ # ===== STEP 1: GEMINI SUPERVISOR - Break query into sub-topics =====
1877
  if disable_agentic_reasoning:
1878
+ logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
1879
+ # Simple breakdown for direct mode
1880
+ breakdown = {
1881
+ "sub_topics": [
1882
+ {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high"}
1883
+ ],
1884
+ "max_topics": 1,
1885
+ "strategy": "Direct answer"
1886
+ }
1887
  else:
1888
+ logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
1889
+ breakdown = gemini_supervisor_breakdown(message, final_use_rag, final_use_web_search, elapsed(), max_duration=120)
1890
+ logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
1891
 
1892
+ # ===== STEP 2: GEMINI SUPERVISOR - Handle Search Mode =====
1893
+ search_contexts = []
1894
+ web_urls = []
1895
+ if final_use_web_search:
1896
+ logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
1897
+ search_strategies = gemini_supervisor_search_strategies(message, elapsed())
1898
+
1899
+ # Execute searches for each strategy
1900
+ all_search_results = []
1901
+ for strategy in search_strategies.get("search_strategies", [])[:4]: # Max 4 strategies
1902
+ search_query = strategy.get("strategy", message)
1903
+ target_sources = strategy.get("target_sources", 2)
1904
+ logger.info(f"[GEMINI SUPERVISOR] Executing search: {search_query} (target: {target_sources} sources)")
1905
+
1906
+ results = search_web(search_query, max_results=target_sources)
1907
+ all_search_results.extend(results)
1908
+ web_urls.extend([r.get('url', '') for r in results if r.get('url')])
1909
+
1910
+ # Summarize search results with Gemini
1911
+ if all_search_results:
1912
+ logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
1913
+ search_summary = summarize_web_content(all_search_results, message)
1914
+ if search_summary:
1915
+ search_contexts.append(search_summary)
1916
+ logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
1917
+
1918
+ # ===== STEP 3: GEMINI SUPERVISOR - Handle RAG Mode =====
1919
+ rag_contexts = []
1920
  if final_use_rag and has_rag_index:
1921
  if elapsed() >= soft_timeout - 10:
1922
+ logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
 
 
1923
  final_use_rag = False
1924
  else:
1925
+ logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
1926
  embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
1927
  Settings.embed_model = embed_model
1928
  storage_context = StorageContext.from_defaults(persist_dir=index_dir)
 
1932
  base_retriever,
1933
  storage_context=storage_context,
1934
  simple_ratio_thresh=merge_threshold,
1935
+ verbose=False # Reduce logging noise
1936
  )
 
 
1937
  merged_nodes = auto_merging_retriever.retrieve(message)
1938
+ retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
1939
+ logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
1940
 
1941
+ # Brainstorm retrieved docs into contexts
1942
+ logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
1943
+ rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed())
1944
+ rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
1945
+ logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
1946
 
1947
+ # ===== STEP 4: MEDSWIN SPECIALIST - Execute tasks sequentially =====
1948
+ # Initialize medical model
1949
+ medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
 
 
1950
 
1951
+ # Base system prompt for MedSwin (clean, no internal thoughts)
1952
+ base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
 
 
1953
 
1954
+ # Prepare context for MedSwin (combine RAG and search contexts)
1955
+ combined_context = ""
1956
+ if rag_contexts:
1957
+ combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4]) # Max 4 contexts
1958
+ if search_contexts:
1959
+ if combined_context:
1960
+ combined_context += "\n\n"
1961
+ combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)
1962
 
1963
+ # Execute MedSwin tasks for each sub-topic
1964
+ logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
1965
+ medswin_answers = []
 
1966
 
1967
  updated_history = history + [
1968
  {"role": "user", "content": original_message},
 
1970
  ]
1971
  yield updated_history
1972
 
1973
+ for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
1974
+ if elapsed() >= hard_timeout - 5:
1975
+ logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
1976
  break
1977
 
1978
+ task_instruction = sub_topic.get("instruction", "")
1979
+ topic_name = sub_topic.get("topic", f"Topic {idx}")
1980
+ priority = sub_topic.get("priority", "medium")
1981
+
1982
+ logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
1983
 
1984
+ # Select relevant context for this task (if multiple contexts available)
1985
+ task_context = combined_context
1986
+ if len(rag_contexts) > 1 and idx <= len(rag_contexts):
1987
+ # Use corresponding RAG context if available
1988
+ task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context
1989
 
1990
+ # Execute MedSwin task (with GPU tag)
1991
+ try:
1992
+ task_answer = execute_medswin_task(
1993
+ medical_model_obj=medical_model_obj,
1994
+ medical_tokenizer=medical_tokenizer,
1995
+ task_instruction=task_instruction,
1996
+ context=task_context if task_context else "",
1997
+ system_prompt_base=base_system_prompt,
1998
+ temperature=temperature,
1999
+ max_new_tokens=min(max_new_tokens, 800), # Limit per task
2000
+ top_p=top_p,
2001
+ top_k=top_k,
2002
+ penalty=penalty
2003
+ )
2004
+
2005
+ # Format task answer with topic header
2006
+ formatted_answer = f"## {topic_name}\n\n{task_answer}"
2007
+ medswin_answers.append(formatted_answer)
2008
+ logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
2009
+
2010
+ # Stream partial answer as we complete each task
2011
+ partial_final = "\n\n".join(medswin_answers)
2012
+ updated_history[-1]["content"] = partial_final
2013
+ yield updated_history
2014
+
2015
+ except Exception as e:
2016
+ logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
2017
+ # Continue with next task
2018
+ continue
2019
+
2020
+ # ===== STEP 5: Combine all MedSwin answers into final answer =====
2021
+ final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
2022
+
2023
+ # Clean final answer - ensure no tables, only Markdown bullets
2024
+ if "|" in final_answer and "---" in final_answer:
2025
+ logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
2026
+ lines = final_answer.split('\n')
2027
+ cleaned_lines = []
2028
+ for line in lines:
2029
+ if '|' in line and '---' not in line:
2030
+ # Convert table row to bullet points
2031
+ cells = [cell.strip() for cell in line.split('|') if cell.strip()]
2032
+ if cells:
2033
+ cleaned_lines.append(f"- {' / '.join(cells)}")
2034
+ elif '---' not in line:
2035
+ cleaned_lines.append(line)
2036
+ final_answer = '\n'.join(cleaned_lines)
2037
+
2038
+ # ===== STEP 6: Finalize answer (translate, add citations, format) =====
2039
  # Translate back if needed
2040
+ if needs_translation and final_answer:
2041
+ logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
2042
+ final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
 
2043
 
2044
  # Add citations if web sources were used
2045
  citations_text = ""
2046
  if web_urls:
 
2047
  unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates
2048
  citation_links = []
2049
  for url in unique_urls[:5]: # Limit to 5 citations
2050
  domain = format_url_as_domain(url)
2051
  if domain:
 
2052
  citation_links.append(f"[{domain}]({url})")
2053
 
2054
  if citation_links:
2055
  citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
2056
 
2057
+ # Add speaker icon
 
 
 
2058
  speaker_icon = ' 🔊'
2059
+ final_answer_with_metadata = final_answer + citations_text + speaker_icon
 
2060
 
2061
+ # Update history with final answer (ONLY final answer, no internal thoughts)
2062
+ updated_history[-1]["content"] = final_answer_with_metadata
2063
  yield updated_history
2064
 
2065
+ # Log completion
2066
+ logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
 
 
2067
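To complement the citation handling shown in the additions above (the **Automatic Citations** bullet in the README), here is a small, self-contained sketch of the dedup-and-format step. The link formatting and `dict.fromkeys` dedup come straight from this diff; the `format_url_as_domain` body shown here is a hypothetical stand-in for the helper already defined elsewhere in app.py:

```python
from urllib.parse import urlparse

def format_url_as_domain(url: str) -> str:
    # Hypothetical stand-in for the app.py helper of the same name.
    return urlparse(url).netloc.replace("www.", "")

def build_citations(web_urls: list) -> str:
    # Mirrors stream_chat: drop empty URLs, dedupe while preserving order,
    # keep at most 5 links, render as Markdown [domain](url) pairs.
    unique_urls = list(dict.fromkeys(u for u in web_urls if u))
    links = [f"[{format_url_as_domain(u)}]({u})" for u in unique_urls[:5]]
    return ("\n\n**Sources:** " + ", ".join(links)) if links else ""

print(build_citations([
    "https://www.who.int/news",
    "https://pubmed.ncbi.nlm.nih.gov/12345/",
    "https://www.who.int/news",  # duplicate is dropped
]))
```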
 
2068
  def generate_speech_for_message(text: str):
2069
  """Generate speech for a message and return audio file"""