# MedLLM-Agent / pipeline.py
"""Main chat pipeline - stream_chat function"""
import os
import json
import time
import logging
import threading
import concurrent.futures
import hashlib
import gradio as gr
import spaces
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core import Settings
from llama_index.core.retrievers import AutoMergingRetriever
from logger import logger, ThoughtCaptureHandler
from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state
from utils import detect_language, translate_text, format_url_as_domain
from search import search_web, summarize_web_content
from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy
from supervisor import (
gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
gemini_supervisor_rag_brainstorm, execute_medswin_task,
gemini_supervisor_synthesize, gemini_supervisor_challenge,
gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity,
gemini_clinical_intake_triage, gemini_summarize_clinical_insights,
MAX_SEARCH_STRATEGIES
)
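# Clinical intake session state: at most MAX_CLINICAL_QA_ROUNDS clarification questions per session,
# with per-session state stored in _clinical_intake_sessions (keyed by the Gradio session hash)
# and guarded by _clinical_intake_lock.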
MAX_CLINICAL_QA_ROUNDS = 5
_clinical_intake_sessions = {}
_clinical_intake_lock = threading.Lock()
# Thread pool executor for running Gemini supervisor calls without blocking GPU task
_gemini_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="gemini-supervisor")
def run_gemini_in_thread(fn, *args, **kwargs):
"""
Run a Gemini supervisor function in a separate thread so it does not block the GPU task.
This keeps Gemini API latency from counting against the GPU time budget and causing timeouts.
On timeout or error, a function-specific fallback is returned so the pipeline can continue.
"""
try:
future = _gemini_executor.submit(fn, *args, **kwargs)
# Set a reasonable timeout (30s) to prevent hanging
result = future.result(timeout=30.0)
return result
except concurrent.futures.TimeoutError:
logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
# Return fallback based on function
if "breakdown" in fn.__name__:
return {
"sub_topics": [
{"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
],
"strategy": "Direct answer (timeout fallback)",
"exploration_note": "Gemini supervisor timeout"
}
elif "search_strategies" in fn.__name__:
return {
"search_strategies": [
{"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
],
"max_strategies": 1
}
elif "rag_brainstorm" in fn.__name__:
return {
"contexts": [
{"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
],
"max_contexts": 1
}
elif "synthesize" in fn.__name__:
return "\n\n".join(args[1] if len(args) > 1 else [])
elif "challenge" in fn.__name__:
return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
elif "enhance_answer" in fn.__name__:
return args[1] if len(args) > 1 else ""
elif "check_clarity" in fn.__name__:
return {"is_unclear": False, "needs_search": False, "search_queries": []}
elif "clinical_intake_triage" in fn.__name__:
return {
"needs_additional_info": False,
"decision_reason": "Timeout fallback",
"max_rounds": args[2] if len(args) > 2 else 5,
"questions": [],
"initial_hypotheses": []
}
elif "summarize_clinical_insights" in fn.__name__:
return {
"patient_profile": "",
"refined_problem_statement": args[0] if args else "",
"key_findings": [],
"handoff_note": "Proceed with regular workflow."
}
else:
logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn.__name__}, returning None")
return None
except Exception as e:
logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
# Return appropriate fallback
if "breakdown" in fn.__name__:
return {
"sub_topics": [
{"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
],
"strategy": "Direct answer (error fallback)",
"exploration_note": "Gemini supervisor error"
}
return None
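# Typical usage mirrors the call sites below, e.g.:
#   breakdown = run_gemini_in_thread(gemini_supervisor_breakdown, message, use_rag, use_web_search, elapsed(), 120, previous_answer)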
def _get_clinical_intake_state(session_id: str):
with _clinical_intake_lock:
return _clinical_intake_sessions.get(session_id)
def _set_clinical_intake_state(session_id: str, state: dict):
with _clinical_intake_lock:
_clinical_intake_sessions[session_id] = state
def _clear_clinical_intake_state(session_id: str):
with _clinical_intake_lock:
_clinical_intake_sessions.pop(session_id, None)
def _history_to_text(history: list, limit: int = 6) -> str:
if not history:
return "No prior conversation."
recent = history[-limit:]
lines = []
for turn in recent:
role = turn.get("role", "user")
content = turn.get("content", "")
lines.append(f"{role}: {content}")
return "\n".join(lines)
def _format_intake_question(question: dict, round_idx: int, max_rounds: int, target_lang: str) -> str:
header = f"🩺 Question for clarity {round_idx}/{max_rounds}"
body = question.get("question") or "Could you share a bit more detail so I can give an accurate answer?"
prompt_parts = [
header,
body,
"Please answer in 1-2 sentences so I can continue."
]
prompt_text = "\n\n".join(prompt_parts)
if target_lang and target_lang != "en":
try:
prompt_text = translate_text(prompt_text, target_lang=target_lang, source_lang="en")
except Exception as exc:
logger.warning(f"[INTAKE] Question translation failed: {exc}")
return prompt_text
def _format_qa_transcript(qa_pairs: list) -> str:
if not qa_pairs:
return ""
lines = []
for idx, qa in enumerate(qa_pairs, 1):
question = qa.get("question", "").strip()
answer = qa.get("answer", "").strip()
if question:
lines.append(f"Q{idx}: {question}")
if answer:
lines.append(f"A{idx}: {answer}")
lines.append("")
return "\n".join(lines).strip()
def _format_insights_block(insights: dict) -> str:
if not insights:
return ""
lines = []
profile = insights.get("patient_profile")
if profile:
lines.append(f"- Patient profile: {profile}")
for finding in insights.get("key_findings", []):
title = finding.get("title", "Insight")
detail = finding.get("detail", "")
implication = finding.get("clinical_implication", "")
line = f"- {title}: {detail}"
if implication:
line += f" (Clinical note: {implication})"
lines.append(line)
return "\n".join(lines)
def _build_refined_query(base_query: str, insights: dict, insights_block: str) -> str:
sections = [base_query.strip()] if base_query else []
if insights_block:
sections.append(f"Clinical intake summary:\n{insights_block}")
refined = insights.get("refined_problem_statement")
if refined:
sections.append(f"Refined problem statement:\n{refined}")
handoff = insights.get("handoff_note")
if handoff:
sections.append(f"Handoff note:\n{handoff}")
return "\n\n".join([section for section in sections if section])
def _hash_prompt_text(text: str) -> str:
if not text:
return ""
digest = hashlib.sha1()
digest.update(text.strip().encode("utf-8"))
return digest.hexdigest()
def _extract_pending_intake_prompt(history: list) -> str:
if not history:
return ""
for turn in reversed(history):
if turn.get("role") != "assistant":
continue
content = turn.get("content", "")
if content.startswith("🩺 Question for clarity"):
return content
return ""
def _rehydrate_intake_state(session_id: str, history: list):
state = _get_clinical_intake_state(session_id)
if state or not history:
return state
pending_prompt = _extract_pending_intake_prompt(history)
if not pending_prompt:
return None
prompt_hash = _hash_prompt_text(pending_prompt)
if not prompt_hash:
return None
with _clinical_intake_lock:
for existing_id, existing_state in list(_clinical_intake_sessions.items()):
if existing_state.get("awaiting_answer") and existing_state.get("last_prompt_hash") == prompt_hash:
if existing_id != session_id:
_clinical_intake_sessions.pop(existing_id, None)
_clinical_intake_sessions[session_id] = existing_state
return existing_state
return None
def _get_last_assistant_answer(history: list) -> str:
"""
Extract the last non-empty assistant answer from history.
Skips clinical intake clarification prompts so that follow-up
questions like "clarify your answer" refer to the real medical
answer, not an intake question.
"""
if not history:
return ""
for turn in reversed(history):
if turn.get("role") != "assistant":
continue
content = (turn.get("content") or "").strip()
if not content:
continue
# Skip intake prompts that start with the standard header
if content.startswith("🩺 Question for clarity"):
continue
return content
return ""
def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
questions = plan.get("questions", []) or []
if not questions:
return None
max_rounds = plan.get("max_rounds") or len(questions)
max_rounds = max(1, min(MAX_CLINICAL_QA_ROUNDS, max_rounds, len(questions)))
state = {
"base_query": base_query,
"original_language": original_language or "en",
"questions": questions,
"max_rounds": max_rounds,
"current_round": 1,
"pending_question_index": 0,
"awaiting_answer": True,
"answers": [],
"decision_reason": plan.get("decision_reason", ""),
"initial_hypotheses": plan.get("initial_hypotheses", []),
"started_at": time.time(),
"last_prompt_hash": ""
}
_set_clinical_intake_state(session_id, state)
first_prompt = _format_intake_question(
questions[0],
round_idx=1,
max_rounds=max_rounds,
target_lang=state["original_language"]
)
state["last_prompt_hash"] = _hash_prompt_text(first_prompt)
_set_clinical_intake_state(session_id, state)
return first_prompt
def _handle_clinical_answer(session_id: str, answer_text: str):
state = _get_clinical_intake_state(session_id)
if not state:
return {"type": "error"}
questions = state.get("questions", [])
idx = state.get("pending_question_index", 0)
if idx >= len(questions):
logger.warning("[INTAKE] Pending question index out of range, ending intake session")
_clear_clinical_intake_state(session_id)
return {"type": "error"}
question_meta = questions[idx] or {}
qa_entry = {
"question": question_meta.get("question", ""),
"focus": question_meta.get("clinical_focus"),
"why_it_matters": question_meta.get("why_it_matters"),
"round": state.get("current_round", len(state.get("answers", [])) + 1),
"answer": answer_text.strip()
}
state["answers"].append(qa_entry)
next_index = idx + 1
reached_round_limit = len(state["answers"]) >= state["max_rounds"]
if reached_round_limit or next_index >= len(questions):
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
insights = run_gemini_in_thread(gemini_summarize_clinical_insights, state["base_query"], state["answers"]) or {}
insights_block = _format_insights_block(insights)
refined_query = _build_refined_query(state["base_query"], insights, insights_block)
transcript = _format_qa_transcript(state["answers"])
_clear_clinical_intake_state(session_id)
return {
"type": "insights",
"insights": insights,
"insights_block": insights_block,
"refined_query": refined_query,
"qa_pairs": state["answers"],
"qa_transcript": transcript
}
state["pending_question_index"] = next_index
state["current_round"] = len(state["answers"]) + 1
state["awaiting_answer"] = True
_set_clinical_intake_state(session_id, state)
next_question = questions[next_index]
prompt = _format_intake_question(
next_question,
round_idx=state["current_round"],
max_rounds=state["max_rounds"],
target_lang=state["original_language"]
)
state["last_prompt_hash"] = _hash_prompt_text(prompt)
_set_clinical_intake_state(session_id, state)
return {"type": "question", "prompt": prompt}
@spaces.GPU(max_duration=120)
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float,
max_new_tokens: int,
top_p: float,
top_k: int,
penalty: float,
retriever_k: int,
merge_threshold: float,
use_rag: bool,
medical_model: str,
use_web_search: bool,
enable_clinical_intake: bool,
disable_agentic_reasoning: bool,
show_thoughts: bool,
request: gr.Request
):
"""Main chat pipeline implementing MAC architecture"""
if not request:
yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
return
# Check if model is loaded before proceeding
if not is_model_loaded(medical_model):
loading_state = get_model_loading_state(medical_model)
if loading_state == "loading":
error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
else:
error_msg = f"⚠️ {medical_model} is not ready. Please wait for the model to finish loading."
# Try to load it
try:
set_model_loading_state(medical_model, "loading")
initialize_medical_model(medical_model)
# If successful, continue
except Exception as e:
error_msg = f"⚠️ Error loading {medical_model}: {str(e)[:200]}. Please try again."
yield history + [{"role": "assistant", "content": error_msg}], ""
return
if not is_model_loaded(medical_model):
yield history + [{"role": "assistant", "content": error_msg}], ""
return
thought_handler = None
if show_thoughts:
thought_handler = ThoughtCaptureHandler()
thought_handler.setLevel(logging.INFO)
thought_handler.clear()
logger.addHandler(thought_handler)
session_start = time.time()
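# Soft and hard deadlines (in seconds) keep the whole request under the 120 s budget declared
# in @spaces.GPU(max_duration=120); later stages check elapsed() against these before starting.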
soft_timeout = 100
hard_timeout = 118
def elapsed():
return time.time() - session_start
user_id = request.session_hash or "anonymous"
index_dir = f"./{user_id}_index"
has_rag_index = os.path.exists(index_dir)
original_lang = detect_language(message)
original_message = message
needs_translation = original_lang != "en"
pipeline_diagnostics = {
"reasoning": None,
"plan": None,
"strategy_decisions": [],
"stage_metrics": {},
"search": {"strategies": [], "total_results": 0},
"clinical_intake": {
"enabled": enable_clinical_intake,
"activated": False,
"rounds": 0,
"reason": "",
"insights": [],
"plan": [],
"qa_pairs": [],
"transcript": "",
"insights_block": ""
}
}
def record_stage(stage_name: str, start_time: float):
pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
translation_stage_start = time.time()
if needs_translation:
logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
message = translate_text(message, target_lang="en", source_lang=original_lang)
logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
record_stage("translation", translation_stage_start)
final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
final_use_web_search = use_web_search and not disable_agentic_reasoning
# Initialize updated_history early to avoid UnboundLocalError
updated_history = history + [
{"role": "user", "content": original_message},
{"role": "assistant", "content": ""}
]
clinical_intake_context_block = ""
# Clinical intake currently uses Gemini-based supervisors.
# When agentic reasoning is disabled, we also skip all Gemini-driven
# intake planning and summarization so the flow is purely MedSwin.
if disable_agentic_reasoning or not enable_clinical_intake:
_clear_clinical_intake_state(user_id)
else:
intake_state = _rehydrate_intake_state(user_id, history)
if intake_state and intake_state.get("awaiting_answer"):
logger.info("[INTAKE] Awaiting patient response - processing answer")
intake_result = _handle_clinical_answer(user_id, message)
if intake_result.get("type") == "question":
logger.info("[INTAKE] Requesting additional follow-up")
updated_history[-1]["content"] = intake_result["prompt"]
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
if thought_handler:
logger.removeHandler(thought_handler)
return
if intake_result.get("type") == "insights":
pipeline_diagnostics["clinical_intake"]["activated"] = True
pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
pipeline_diagnostics["clinical_intake"]["qa_pairs"] = intake_result.get("qa_pairs", [])
pipeline_diagnostics["clinical_intake"]["transcript"] = intake_result.get("qa_transcript", "")
pipeline_diagnostics["clinical_intake"]["insights_block"] = intake_result.get("insights_block", "")
base_refined = intake_result.get("refined_query", message)
summary_section = ""
transcript_section = ""
if intake_result.get("insights_block"):
summary_section = f"Clinical intake summary:\n{intake_result['insights_block']}"
if intake_result.get("qa_transcript"):
transcript_section = f"Clinical intake Q&A transcript:\n{intake_result['qa_transcript']}"
sections = [base_refined, summary_section, transcript_section]
message = "\n\n---\n\n".join([section for section in sections if section])
clinical_intake_context_block = "\n\n".join([seg for seg in [summary_section, transcript_section] if seg])
else:
history_context = _history_to_text(history)
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
triage_plan = run_gemini_in_thread(gemini_clinical_intake_triage, message, history_context, MAX_CLINICAL_QA_ROUNDS) or {}
pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
pipeline_diagnostics["clinical_intake"]["plan"] = triage_plan.get("questions", [])
needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
if needs_intake:
first_prompt = _start_clinical_intake_session(
user_id,
triage_plan,
message,
original_lang
)
if first_prompt:
pipeline_diagnostics["clinical_intake"]["activated"] = True
updated_history[-1]["content"] = first_prompt
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
if thought_handler:
logger.removeHandler(thought_handler)
return
plan = None
if not disable_agentic_reasoning:
reasoning_stage_start = time.time()
reasoning = autonomous_reasoning(message, history)
record_stage("autonomous_reasoning", reasoning_stage_start)
pipeline_diagnostics["reasoning"] = reasoning
plan = create_execution_plan(reasoning, message, has_rag_index)
pipeline_diagnostics["plan"] = plan
execution_strategy = autonomous_execution_strategy(
reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
)
if final_use_rag and not reasoning.get("requires_rag", True):
final_use_rag = False
pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
if final_use_web_search and not reasoning.get("requires_web_search", False):
final_use_web_search = False
pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
elif not final_use_web_search and reasoning.get("requires_web_search", False):
if not use_web_search:
pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
else:
pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
else:
pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
# Update thoughts after reasoning stage
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
if disable_agentic_reasoning:
logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
breakdown = {
"sub_topics": [
{"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
],
"strategy": "Direct answer",
"exploration_note": "Direct mode - no breakdown"
}
else:
logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
# Provide previous assistant answer as context so Gemini can
# interpret follow-up queries like "clarify your answer".
previous_answer = _get_last_assistant_answer(history)
# Run in thread pool to avoid blocking GPU task
breakdown = run_gemini_in_thread(
gemini_supervisor_breakdown,
message,
final_use_rag,
final_use_web_search,
elapsed(),
120,
previous_answer,
)
logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
# Update thoughts after breakdown
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
search_contexts = []
web_urls = []
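# Web search stage: Gemini proposes up to MAX_SEARCH_STRATEGIES queries, which are executed
# (in parallel when there is more than one) and summarized into search_contexts.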
if final_use_web_search:
search_stage_start = time.time()
logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
search_strategies = run_gemini_in_thread(gemini_supervisor_search_strategies, message, elapsed()) or {}
all_search_results = []
strategy_jobs = []
for strategy in search_strategies.get("search_strategies", [])[:MAX_SEARCH_STRATEGIES]:
search_query = strategy.get("strategy", message)
target_sources = strategy.get("target_sources", 2)
strategy_jobs.append({
"query": search_query,
"target_sources": target_sources,
"meta": strategy
})
def execute_search(job):
job_start = time.time()
try:
results = search_web(job["query"], max_results=job["target_sources"])
duration = time.time() - job_start
return results, duration, None
except Exception as exc:
return [], time.time() - job_start, exc
def record_search_diag(job, duration, results_count, error=None):
entry = {
"query": job["query"],
"target_sources": job["target_sources"],
"duration": round(duration, 3),
"results": results_count
}
if error:
entry["error"] = str(error)
pipeline_diagnostics["search"]["strategies"].append(entry)
if strategy_jobs:
max_workers = min(len(strategy_jobs), 4)
if len(strategy_jobs) > 1:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
for future in concurrent.futures.as_completed(future_map):
job = future_map[future]
try:
results, duration, error = future.result()
except Exception as exc:
results, duration, error = [], 0.0, exc
record_search_diag(job, duration, len(results), error)
if not error and results:
all_search_results.extend(results)
web_urls.extend([r.get('url', '') for r in results if r.get('url')])
else:
job = strategy_jobs[0]
results, duration, error = execute_search(job)
record_search_diag(job, duration, len(results), error)
if not error and results:
all_search_results.extend(results)
web_urls.extend([r.get('url', '') for r in results if r.get('url')])
else:
pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
if all_search_results:
logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
search_summary = summarize_web_content(all_search_results, message)
if search_summary:
search_contexts.append(search_summary)
logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
record_stage("web_search", search_stage_start)
rag_contexts = []
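# RAG stage: load the per-user index, retrieve with an auto-merging retriever, then have Gemini
# brainstorm focused contexts from the retrieved text. Skipped entirely when time is short.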
if final_use_rag and has_rag_index:
rag_stage_start = time.time()
if elapsed() >= soft_timeout - 10:
logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
final_use_rag = False
else:
logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
embed_model = get_or_create_embed_model()
Settings.embed_model = embed_model
storage_context = StorageContext.from_defaults(persist_dir=index_dir)
index = load_index_from_storage(storage_context, settings=Settings)
base_retriever = index.as_retriever(similarity_top_k=retriever_k)
auto_merging_retriever = AutoMergingRetriever(
base_retriever,
storage_context=storage_context,
simple_ratio_thresh=merge_threshold,
verbose=False
)
merged_nodes = auto_merging_retriever.retrieve(message)
retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
rag_brainstorm = run_gemini_in_thread(gemini_supervisor_rag_brainstorm, message, retrieved_docs, elapsed()) or {}
rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
record_stage("rag_retrieval", rag_stage_start)
medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
context_sections = []
if clinical_intake_context_block:
context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
if rag_contexts:
context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
if search_contexts:
context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
combined_context = "\n\n".join(context_sections)
logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
medswin_answers = []
# Update thoughts before starting MedSwin tasks
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
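# MedSwin execution: each sub-topic runs as its own generation call; partial answers are
# streamed to the UI after every completed task.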
medswin_stage_start = time.time()
for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
if elapsed() >= hard_timeout - 5:
logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
break
task_instruction = sub_topic.get("instruction", "")
topic_name = sub_topic.get("topic", f"Topic {idx}")
priority = sub_topic.get("priority", "medium")
logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
task_context = combined_context
if len(rag_contexts) > 1 and idx <= len(rag_contexts):
task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context
# Add small delay between GPU requests to prevent ZeroGPU scheduler conflicts
if idx > 1:
delay = 0.5 # 500ms delay between sequential GPU requests
logger.debug(f"[MEDSWIN] Waiting {delay}s before next GPU request to avoid scheduler conflicts...")
time.sleep(delay)
try:
task_answer = execute_medswin_task(
medical_model_obj=medical_model_obj,
medical_tokenizer=medical_tokenizer,
task_instruction=task_instruction,
context=task_context if task_context else "",
system_prompt_base=base_system_prompt,
temperature=temperature,
max_new_tokens=min(max_new_tokens, 800),
top_p=top_p,
top_k=top_k,
penalty=penalty
)
formatted_answer = f"## {topic_name}\n\n{task_answer}"
medswin_answers.append(formatted_answer)
logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
partial_final = "\n\n".join(medswin_answers)
updated_history[-1]["content"] = partial_final
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
except Exception as e:
logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
continue
record_stage("medswin_tasks", medswin_stage_start)
# If agentic reasoning is disabled, we skip all Gemini-based synthesis,
# challenge, and enhancement loops. The final answer is just the
# concatenation of MedSwin task outputs.
if disable_agentic_reasoning:
logger.info("[MAC] Agentic reasoning disabled - skipping Gemini synthesis and challenge")
if medswin_answers:
final_answer = "\n\n".join(medswin_answers)
else:
final_answer = "I apologize, but I was unable to generate a response."
else:
logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]
synthesis_stage_start = time.time()
# Run in thread pool to avoid blocking GPU task
final_answer = run_gemini_in_thread(
gemini_supervisor_synthesize, message, raw_medswin_answers, rag_contexts, search_contexts, breakdown
)
record_stage("synthesis", synthesis_stage_start)
if not final_answer or len(final_answer.strip()) < 50:
logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
if "|" in final_answer and "---" in final_answer:
logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
lines = final_answer.split('\n')
cleaned_lines = []
for line in lines:
if '|' in line and '---' not in line:
cells = [cell.strip() for cell in line.split('|') if cell.strip()]
if cells:
cleaned_lines.append(f"- {' / '.join(cells)}")
elif '---' not in line:
cleaned_lines.append(line)
final_answer = '\n'.join(cleaned_lines)
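# Challenge loop: Gemini critiques the synthesized answer and, when it returns enhancement
# instructions, rewrites it; an enhancement is kept only if the new text is at least ~80% as
# long as the previous answer.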
max_challenge_iterations = 2
challenge_iteration = 0
challenge_stage_start = time.time()
while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
challenge_iteration += 1
logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
evaluation = run_gemini_in_thread(
gemini_supervisor_challenge, message, final_answer, raw_medswin_answers, rag_contexts, search_contexts
) or {}
if evaluation.get("is_optimal", False):
logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
break
enhancement_instructions = evaluation.get("enhancement_instructions", "")
if not enhancement_instructions:
logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
break
logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
# Run in thread pool to avoid blocking GPU task
enhanced_answer = run_gemini_in_thread(
gemini_supervisor_enhance_answer, message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
)
if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:
final_answer = enhanced_answer
logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
else:
logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
break
record_stage("challenge_loop", challenge_stage_start)
if final_use_web_search and elapsed() < soft_timeout - 10:
logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
clarity_stage_start = time.time()
# Run in thread pool to avoid blocking GPU task; guard against a None result on unexpected errors
clarity_check = run_gemini_in_thread(gemini_supervisor_check_clarity, message, final_answer, final_use_web_search) or {}
record_stage("clarity_check", clarity_stage_start)
if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
additional_search_results = []
followup_stage_start = time.time()
for search_query in clarity_check.get("search_queries", [])[:3]:
if elapsed() >= soft_timeout - 5:
break
extra_start = time.time()
results = search_web(search_query, max_results=2)
extra_duration = time.time() - extra_start
pipeline_diagnostics["search"]["strategies"].append({
"query": search_query,
"target_sources": 2,
"duration": round(extra_duration, 3),
"results": len(results),
"type": "followup"
})
additional_search_results.extend(results)
web_urls.extend([r.get('url', '') for r in results if r.get('url')])
if additional_search_results:
pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
additional_summary = summarize_web_content(additional_search_results, message)
if additional_summary:
search_contexts.append(additional_summary)
logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
# Run in thread pool to avoid blocking GPU task
enhanced_with_search = run_gemini_in_thread(
gemini_supervisor_enhance_answer, message, final_answer,
f"Incorporate the following additional information from web search: {additional_summary}",
raw_medswin_answers, rag_contexts, search_contexts
)
if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
final_answer = enhanced_with_search
logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
record_stage("followup_search", followup_stage_start)
# Update thoughts after followup search
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
yield updated_history, thoughts_text
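# Finalize: translate the answer back to the user's language when needed, append up to five
# deduplicated source links, and add the speaker icon suffix before the last UI update.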
citations_text = ""
if needs_translation and final_answer:
logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
if web_urls:
unique_urls = list(dict.fromkeys(web_urls))
citation_links = []
for url in unique_urls[:5]:
domain = format_url_as_domain(url)
if domain:
citation_links.append(f"[{domain}]({url})")
if citation_links:
citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
speaker_icon = ' 🔊'
final_answer_with_metadata = final_answer + citations_text + speaker_icon
updated_history[-1]["content"] = final_answer_with_metadata
thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
# Always yield thoughts_text, even if empty, to ensure UI updates
yield updated_history, thoughts_text
if thought_handler:
logger.removeHandler(thought_handler)
diag_summary = {
"stage_metrics": pipeline_diagnostics["stage_metrics"],
"decisions": pipeline_diagnostics["strategy_decisions"],
"search": pipeline_diagnostics["search"],
}
try:
logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
except Exception:
logger.info(f"[MAC] Diagnostics summary (non-serializable)")
logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")