LiamKhoaLe committed
Commit b720259 · Parent: e4c0a6a

Improve adaptive MAC pipeline


- cache MCP tool metadata and reuse embedding instances to cut roundtrips (TTL-cache pattern sketched below)
- add autonomous planning, parallel search, and diagnostics instrumentation in stream_chat
- update README to describe adaptive strategy, telemetry, and faster search flow
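
The tool-metadata cache behind the first bullet is a plain TTL cache. A minimal standalone sketch of the pattern, assuming a synchronous `fetch_tools` callable as an illustrative stand-in for the real `session.list_tools()` roundtrip:

import time

# Minimal sketch of the TTL cache behind get_cached_mcp_tools (illustrative names).
_cache = {"timestamp": 0.0, "tools": None}
TTL_SECONDS = 60.0  # the commit reads this from the MCP_TOOLS_CACHE_TTL env var

def get_tools(fetch_tools, force_refresh=False):
    """Return cached tool metadata, refreshing only when stale or forced."""
    now = time.time()
    if (not force_refresh
            and _cache["tools"]
            and now - _cache["timestamp"] < TTL_SECONDS):
        return _cache["tools"]  # cache hit: no MCP roundtrip
    _cache["tools"] = list(fetch_tools() or [])  # refresh from the server
    _cache["timestamp"] = now
    return _cache["tools"]

tools = get_tools(lambda: ["generate_content", "web_search"])  # first call populates the cache

Callers in the diff below consult the cache once and retry with `force_refresh=True` before giving up, so a stale or empty catalog costs at most one extra roundtrip.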

Files changed (2):
  1. README.md +6 -1
  2. app.py +217 -58
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 5.49.1
 app_file: app.py
 pinned: false
 license: mit
-short_description: 'MedicalMCP RAG & Search with MedSwin'
+short_description: 'Medical MCP agentic RAG & Search with MedSwin'
 tags:
 - mcp-in-action-track-enterprise
 - mcp-in-action-track-creative
@@ -70,6 +70,11 @@ tags:
 - **Markdown Format**: Final answers use bullet points (tables automatically converted)
 - **Deterministic Mode**: `Disable agentic reasoning` switch runs MedSwin alone for offline-friendly, model-only answers
 
+### ⚡ **Adaptive Strategy & Diagnostics**
+- **Autonomous Planner**: Gemini reasoning now enables/disables RAG and web search dynamically per query while respecting user toggles.
+- **Parallel Search Flow**: Multi-strategy web lookups run concurrently with cached MCP tool discovery and shared embeddings to cut latency.
+- **Pipeline Telemetry**: Every session logs stage durations, strategy decisions, and search outcomes for fast troubleshooting and quality tracking.
+
 ### ⚙️ **Advanced Configuration**
 - Customizable generation parameters (temperature, top-p, top-k)
 - Adjustable retrieval settings (top-k, merge threshold)
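
The `Pipeline Telemetry` bullet above materializes as a single `[MAC] Diagnostics summary` log line built at the end of `stream_chat` (see the app.py diff below). A hypothetical record, with field names taken from the code and all values invented purely for illustration:

# Hypothetical [MAC] diagnostics summary; keys mirror diag_summary in stream_chat,
# numbers are made up for illustration only.
diag_summary = {
    "stage_metrics": {"translation": 0.001, "autonomous_reasoning": 1.84,
                      "web_search": 3.42, "medswin_tasks": 21.7, "synthesis": 2.13},
    "decisions": ["Skipped RAG per autonomous reasoning"],
    "search": {
        "strategies": [
            {"query": "strategy query A", "target_sources": 2, "duration": 1.93, "results": 2},
            {"query": "strategy query B", "target_sources": 2, "duration": 2.41, "results": 3},
        ],
        "total_results": 5,
    },
}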
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
 import threading
 import time
 import json
+import concurrent.futures
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -241,6 +242,7 @@ global_medical_models = {}
 global_medical_tokenizers = {}
 global_file_info = {}
 global_tts_model = None
+global_embed_model = None
 
 # MCP client storage
 global_mcp_session = None
@@ -358,6 +360,45 @@ async def get_mcp_session():
         global_mcp_stdio_ctx = None
         return None
 
+MCP_TOOLS_CACHE_TTL = int(os.environ.get("MCP_TOOLS_CACHE_TTL", "60"))
+global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
+
+
+def invalidate_mcp_tools_cache():
+    """Invalidate cached MCP tool metadata"""
+    global global_mcp_tools_cache
+    global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
+
+
+async def get_cached_mcp_tools(force_refresh: bool = False):
+    """Return cached MCP tools list to avoid repeated list_tools calls"""
+    global global_mcp_tools_cache
+    if not MCP_AVAILABLE:
+        return []
+
+    now = time.time()
+    if (
+        not force_refresh
+        and global_mcp_tools_cache["tools"]
+        and now - global_mcp_tools_cache["timestamp"] < MCP_TOOLS_CACHE_TTL
+    ):
+        return global_mcp_tools_cache["tools"]
+
+    session = await get_mcp_session()
+    if session is None:
+        return []
+
+    try:
+        tools_resp = await session.list_tools()
+        tools_list = list(getattr(tools_resp, "tools", []) or [])
+        global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
+        return tools_list
+    except Exception as e:
+        logger.error(f"Failed to refresh MCP tools: {e}")
+        invalidate_mcp_tools_cache()
+        return []
+
+
 async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
     """Call Gemini MCP generate_content tool"""
     if not MCP_AVAILABLE:
@@ -370,27 +411,23 @@ async def call_agent(user_prompt: str, system_prompt: str = None, files: list =
         logger.warning("Failed to get MCP session for Gemini call")
         return ""
 
-    # List tools - session is fully initialized via ClientSession.initialize()
-    try:
-        tools = await session.list_tools()
-    except Exception as e:
-        logger.error(f" Failed to list MCP tools: {e}")
+    tools = await get_cached_mcp_tools()
+    if not tools:
+        tools = await get_cached_mcp_tools(force_refresh=True)
+    if not tools:
+        logger.error("Unable to obtain MCP tool catalog for Gemini calls")
         return ""
 
-    if not tools or not hasattr(tools, 'tools'):
-        logger.error("Invalid tools response from MCP server")
-        return ""
-
-    # Find generate_content tool
     generate_tool = None
-    for tool in tools.tools:
+    for tool in tools:
         if tool.name == "generate_content" or "generate_content" in tool.name.lower():
             generate_tool = tool
             logger.info(f"Found Gemini MCP tool: {tool.name}")
             break
 
     if not generate_tool:
-        logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools.tools]}")
+        logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
+        invalidate_mcp_tools_cache()
        return ""
 
     # Prepare arguments
@@ -457,6 +494,15 @@ def initialize_tts_model():
         global_tts_model = None
     return global_tts_model
 
+
+def get_or_create_embed_model():
+    """Reuse embedding model to avoid reloading weights each request"""
+    global global_embed_model
+    if global_embed_model is None:
+        logger.info("Initializing shared embedding model for RAG retrieval...")
+        global_embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+    return global_embed_model
+
 async def transcribe_audio_gemini(audio_path: str) -> str:
     """Transcribe audio using Gemini MCP"""
     if not MCP_AVAILABLE:
@@ -738,44 +784,32 @@ async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
         return []
 
     try:
-        session = await get_mcp_session()
-        if session is None:
-            return []
-
-        # List tools - session should be ready after proper initialization
-        # Add a small delay to ensure server has fully processed initialization
-        await asyncio.sleep(0.1)
-        try:
-            tools = await session.list_tools()
-        except Exception as e:
-            error_msg = str(e)
-            # Check if it's an initialization error
-            if "initialization" in error_msg.lower() or "before initialization" in error_msg.lower():
-                logger.warning(f"⚠️ Server not ready yet, waiting a bit more...: {error_msg}")
-                await asyncio.sleep(0.5)
-                try:
-                    tools = await session.list_tools()
-                except Exception as retry_error:
-                    logger.error(f"Failed to list MCP tools after retry: {retry_error}")
-                    return []
-            else:
-                logger.error(f"Failed to list MCP tools: {error_msg}")
-                return []
-
-        if not tools or not hasattr(tools, 'tools'):
+        tools = await get_cached_mcp_tools()
+        if not tools:
            return []
 
-        # Look for web search tools (DuckDuckGo, search, etc.)
         search_tool = None
-        for tool in tools.tools:
+        for tool in tools:
             tool_name_lower = tool.name.lower()
             if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
                 search_tool = tool
                 logger.info(f"Found web search MCP tool: {tool.name}")
                 break
 
+        if not search_tool:
+            tools = await get_cached_mcp_tools(force_refresh=True)
+            for tool in tools:
+                tool_name_lower = tool.name.lower()
+                if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
+                    search_tool = tool
+                    logger.info(f"Found web search MCP tool after refresh: {tool.name}")
+                    break
+
         if search_tool:
             try:
+                session = await get_mcp_session()
+                if session is None:
+                    return []
                 # Call the search tool
                 result = await session.call_tool(
                     search_tool.name,
@@ -823,7 +857,9 @@ async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
             except Exception as e:
                 logger.error(f"Error calling web search MCP tool: {e}")
 
-        return []
+        else:
+            logger.debug("No MCP web search tool discovered in current catalog")
+        return []
     except Exception as e:
         logger.error(f"Web search MCP tool error: {e}")
         return []
@@ -2084,7 +2120,7 @@ def create_or_update_index(files, request: gr.Request):
     save_dir = f"./{user_id}_index"
     # Initialize LlamaIndex modules
     llm = get_llm_for_rag()
-    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+    embed_model = get_or_create_embed_model()
     Settings.llm = llm
     Settings.embed_model = embed_model
     file_stats = []
@@ -2228,15 +2264,57 @@ def stream_chat(
     original_message = message
     needs_translation = original_lang != "en"
 
+    pipeline_diagnostics = {
+        "reasoning": None,
+        "plan": None,
+        "strategy_decisions": [],
+        "stage_metrics": {},
+        "search": {"strategies": [], "total_results": 0}
+    }
+
+    def record_stage(stage_name: str, start_time: float):
+        pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
+
+    translation_stage_start = time.time()
     if needs_translation:
         logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
         message = translate_text(message, target_lang="en", source_lang=original_lang)
         logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
+    record_stage("translation", translation_stage_start)
 
     # Determine final modes (respect user settings and availability)
     final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
     final_use_web_search = use_web_search and not disable_agentic_reasoning
 
+    plan = None
+    if not disable_agentic_reasoning:
+        reasoning_stage_start = time.time()
+        reasoning = autonomous_reasoning(message, history)
+        record_stage("autonomous_reasoning", reasoning_stage_start)
+        pipeline_diagnostics["reasoning"] = reasoning
+        plan = create_execution_plan(reasoning, message, has_rag_index)
+        pipeline_diagnostics["plan"] = plan
+        execution_strategy = autonomous_execution_strategy(
+            reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
+        )
+
+        if final_use_rag and not reasoning.get("requires_rag", True):
+            final_use_rag = False
+            pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
+        elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
+            pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
+
+        if final_use_web_search and not reasoning.get("requires_web_search", False):
+            final_use_web_search = False
+            pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
+        elif not final_use_web_search and reasoning.get("requires_web_search", False):
+            if not use_web_search:
+                pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
+            else:
+                pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
+    else:
+        pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
+
     # ===== STEP 1: GEMINI SUPERVISOR - Break query into sub-topics =====
     if disable_agentic_reasoning:
         logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
@@ -2257,19 +2335,68 @@ def stream_chat(
     search_contexts = []
     web_urls = []
     if final_use_web_search:
+        search_stage_start = time.time()
         logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
         search_strategies = gemini_supervisor_search_strategies(message, elapsed())
 
         # Execute searches for each strategy
         all_search_results = []
+        strategy_jobs = []
         for strategy in search_strategies.get("search_strategies", [])[:4]:  # Max 4 strategies
             search_query = strategy.get("strategy", message)
             target_sources = strategy.get("target_sources", 2)
-            logger.info(f"[GEMINI SUPERVISOR] Executing search: {search_query} (target: {target_sources} sources)")
-
-            results = search_web(search_query, max_results=target_sources)
-            all_search_results.extend(results)
-            web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+            strategy_jobs.append({
+                "query": search_query,
+                "target_sources": target_sources,
+                "meta": strategy
+            })
+
+        def execute_search(job):
+            job_start = time.time()
+            try:
+                results = search_web(job["query"], max_results=job["target_sources"])
+                duration = time.time() - job_start
+                return results, duration, None
+            except Exception as exc:
+                return [], time.time() - job_start, exc
+
+        def record_search_diag(job, duration, results_count, error=None):
+            entry = {
+                "query": job["query"],
+                "target_sources": job["target_sources"],
+                "duration": round(duration, 3),
+                "results": results_count
+            }
+            if error:
+                entry["error"] = str(error)
+            pipeline_diagnostics["search"]["strategies"].append(entry)
+
+        if strategy_jobs:
+            max_workers = min(len(strategy_jobs), 4)
+            if len(strategy_jobs) > 1:
+                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                    future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
+                    for future in concurrent.futures.as_completed(future_map):
+                        job = future_map[future]
+                        try:
+                            results, duration, error = future.result()
+                        except Exception as exc:
+                            results, duration, error = [], 0.0, exc
+                        record_search_diag(job, duration, len(results), error)
+                        if not error and results:
+                            all_search_results.extend(results)
+                            web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+            else:
+                job = strategy_jobs[0]
+                results, duration, error = execute_search(job)
+                record_search_diag(job, duration, len(results), error)
+                if not error and results:
+                    all_search_results.extend(results)
+                    web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+        else:
+            pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
+
+        pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
 
         # Summarize search results with Gemini
         if all_search_results:
@@ -2278,16 +2405,18 @@ def stream_chat(
             if search_summary:
                 search_contexts.append(search_summary)
                 logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
+        record_stage("web_search", search_stage_start)
 
     # ===== STEP 3: GEMINI SUPERVISOR - Handle RAG Mode =====
     rag_contexts = []
     if final_use_rag and has_rag_index:
+        rag_stage_start = time.time()
         if elapsed() >= soft_timeout - 10:
             logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
             final_use_rag = False
         else:
             logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
-            embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+            embed_model = get_or_create_embed_model()
             Settings.embed_model = embed_model
             storage_context = StorageContext.from_defaults(persist_dir=index_dir)
             index = load_index_from_storage(storage_context, settings=Settings)
@@ -2307,6 +2436,7 @@ def stream_chat(
             rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed())
             rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
             logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
+        record_stage("rag_retrieval", rag_stage_start)
 
     # ===== STEP 4: MEDSWIN SPECIALIST - Execute tasks sequentially =====
     # Initialize medical model
@@ -2335,6 +2465,7 @@ def stream_chat(
     thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
     yield updated_history, thoughts_text
 
+    medswin_stage_start = time.time()
     for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
         if elapsed() >= hard_timeout - 5:
             logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
@@ -2382,11 +2513,14 @@ def stream_chat(
             logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
             # Continue with next task
             continue
+    record_stage("medswin_tasks", medswin_stage_start)
 
     # ===== STEP 5: GEMINI SUPERVISOR - Synthesize final answer with clear context =====
     logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
     raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]  # Remove headers for synthesis
+    synthesis_stage_start = time.time()
     final_answer = gemini_supervisor_synthesize(message, raw_medswin_answers, rag_contexts, search_contexts, breakdown)
+    record_stage("synthesis", synthesis_stage_start)
 
     if not final_answer or len(final_answer.strip()) < 50:
         # Fallback to simple concatenation if synthesis fails
@@ -2411,6 +2545,7 @@ def stream_chat(
     # ===== STEP 6: GEMINI SUPERVISOR - Challenge and enhance answer iteratively =====
     max_challenge_iterations = 2  # Limit iterations to avoid timeout
     challenge_iteration = 0
+    challenge_stage_start = time.time()
 
     while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
         challenge_iteration += 1
@@ -2438,23 +2573,37 @@ def stream_chat(
         else:
             logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
            break
+    record_stage("challenge_loop", challenge_stage_start)
 
     # ===== STEP 7: Conditional search trigger (only when search mode enabled) =====
     if final_use_web_search and elapsed() < soft_timeout - 10:
         logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
+        clarity_stage_start = time.time()
         clarity_check = gemini_supervisor_check_clarity(message, final_answer, final_use_web_search)
+        record_stage("clarity_check", clarity_stage_start)
 
         if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
             logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
             additional_search_results = []
+            followup_stage_start = time.time()
             for search_query in clarity_check.get("search_queries", [])[:3]:  # Limit to 3 additional searches
                 if elapsed() >= soft_timeout - 5:
                     break
+                extra_start = time.time()
                 results = search_web(search_query, max_results=2)
+                extra_duration = time.time() - extra_start
+                pipeline_diagnostics["search"]["strategies"].append({
+                    "query": search_query,
+                    "target_sources": 2,
+                    "duration": round(extra_duration, 3),
+                    "results": len(results),
+                    "type": "followup"
+                })
                 additional_search_results.extend(results)
                 web_urls.extend([r.get('url', '') for r in results if r.get('url')])
 
             if additional_search_results:
+                pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
                 logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
                 additional_summary = summarize_web_content(additional_search_results, message)
                 if additional_summary:
@@ -2469,6 +2618,7 @@ def stream_chat(
             if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
                 final_answer = enhanced_with_search
                 logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
+            record_stage("followup_search", followup_stage_start)
 
     citations_text = ""
 
@@ -2477,18 +2627,18 @@ def stream_chat(
     if needs_translation and final_answer:
         logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
         final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
+
+    # Add citations if web sources were used
+    if web_urls:
+        unique_urls = list(dict.fromkeys(web_urls))  # Preserve order, remove duplicates
+        citation_links = []
+        for url in unique_urls[:5]:  # Limit to 5 citations
+            domain = format_url_as_domain(url)
+            if domain:
+                citation_links.append(f"[{domain}]({url})")
 
-    # Add citations if web sources were used
-    if web_urls:
-        unique_urls = list(dict.fromkeys(web_urls))  # Preserve order, remove duplicates
-        citation_links = []
-        for url in unique_urls[:5]:  # Limit to 5 citations
-            domain = format_url_as_domain(url)
-            if domain:
-                citation_links.append(f"[{domain}]({url})")
-
-        if citation_links:
-            citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
+        if citation_links:
+            citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
 
     # Add speaker icon
     speaker_icon = ' 🔊'
@@ -2504,6 +2654,15 @@ def stream_chat(
         logger.removeHandler(thought_handler)
 
     # Log completion
+    diag_summary = {
+        "stage_metrics": pipeline_diagnostics["stage_metrics"],
+        "decisions": pipeline_diagnostics["strategy_decisions"],
+        "search": pipeline_diagnostics["search"],
+    }
+    try:
+        logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
+    except Exception:
+        logger.info(f"[MAC] Diagnostics summary (non-serializable)")
     logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
 
 def generate_speech_for_message(text: str):
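
The parallel search flow in the `stream_chat` hunk above is easier to study in isolation. A self-contained sketch of the same fan-out and aggregation pattern, with a stubbed `search_web` (the stub body and sample queries are illustrative, not from the commit):

import concurrent.futures
import time

def search_web(query, max_results=2):
    """Stub standing in for the app's real search_web helper."""
    time.sleep(0.1)  # simulate network latency
    return [{"url": f"https://example.org/{i}"} for i in range(max_results)]

def execute_search(job):
    """Run one strategy and time it, mirroring the committed execute_search."""
    start = time.time()
    try:
        results = search_web(job["query"], max_results=job["target_sources"])
        return results, time.time() - start, None
    except Exception as exc:
        return [], time.time() - start, exc

jobs = [{"query": q, "target_sources": 2} for q in ("query a", "query b", "query c")]
all_results, diagnostics = [], []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(len(jobs), 4)) as pool:
    future_map = {pool.submit(execute_search, job): job for job in jobs}
    for future in concurrent.futures.as_completed(future_map):
        job = future_map[future]
        results, duration, error = future.result()
        diagnostics.append({"query": job["query"], "duration": round(duration, 3),
                            "results": len(results), "error": str(error) if error else None})
        if not error:
            all_results.extend(results)

print(len(all_results), diagnostics)

As in the commit, per-job timing happens inside the worker, so the arrival order imposed by as_completed cannot skew the recorded durations.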