| """Main chat pipeline - stream_chat function""" | |
| import os | |
| import json | |
| import time | |
| import logging | |
| import threading | |
| import concurrent.futures | |
| import hashlib | |
| import gradio as gr | |
| import spaces | |
| from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage | |
| from llama_index.core import Settings | |
| from llama_index.core.retrievers import AutoMergingRetriever | |
| from logger import logger, ThoughtCaptureHandler | |
| from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state | |
| from utils import detect_language, translate_text, format_url_as_domain | |
| from search import search_web, summarize_web_content | |
| from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy | |
| from supervisor import ( | |
| gemini_supervisor_breakdown, gemini_supervisor_search_strategies, | |
| gemini_supervisor_rag_brainstorm, execute_medswin_task, | |
| gemini_supervisor_synthesize, gemini_supervisor_challenge, | |
| gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity, | |
| gemini_clinical_intake_triage, gemini_summarize_clinical_insights, | |
| MAX_SEARCH_STRATEGIES | |
| ) | |
| MAX_CLINICAL_QA_ROUNDS = 5 | |
| _clinical_intake_sessions = {} | |
| _clinical_intake_lock = threading.Lock() | |
| # Thread pool executor for running Gemini supervisor calls without blocking GPU task | |
| _gemini_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="gemini-supervisor") | |
def run_gemini_in_thread(fn, *args, **kwargs):
    """
    Run a Gemini supervisor function in a separate thread so it does not block the GPU task.
    This ensures Gemini API calls don't consume GPU task time and cause timeouts.
    """
    try:
        future = _gemini_executor.submit(fn, *args, **kwargs)
        # Set a reasonable timeout (30s) to prevent hanging
        result = future.result(timeout=30.0)
        return result
    except concurrent.futures.TimeoutError:
        logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
        # Return a fallback based on the function name
        if "breakdown" in fn.__name__:
            return {
                "sub_topics": [
                    {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
                ],
                "strategy": "Direct answer (timeout fallback)",
                "exploration_note": "Gemini supervisor timeout"
            }
        elif "search_strategies" in fn.__name__:
            return {
                "search_strategies": [
                    {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
                ],
                "max_strategies": 1
            }
        elif "rag_brainstorm" in fn.__name__:
            return {
                "contexts": [
                    {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
                ],
                "max_contexts": 1
            }
        elif "synthesize" in fn.__name__:
            return "\n\n".join(args[1] if len(args) > 1 else [])
        elif "challenge" in fn.__name__:
            return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
        elif "enhance_answer" in fn.__name__:
            return args[1] if len(args) > 1 else ""
        elif "check_clarity" in fn.__name__:
            return {"is_unclear": False, "needs_search": False, "search_queries": []}
        elif "clinical_intake_triage" in fn.__name__:
            return {
                "needs_additional_info": False,
                "decision_reason": "Timeout fallback",
                "max_rounds": args[2] if len(args) > 2 else 5,
                "questions": [],
                "initial_hypotheses": []
            }
        elif "summarize_clinical_insights" in fn.__name__:
            return {
                "patient_profile": "",
                "refined_problem_statement": args[0] if args else "",
                "key_findings": [],
                "handoff_note": "Proceed with regular workflow."
            }
        else:
            logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn.__name__}, returning None")
            return None
    except Exception as e:
        logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
        # Return an appropriate fallback
        if "breakdown" in fn.__name__:
            return {
                "sub_topics": [
                    {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
                ],
                "strategy": "Direct answer (error fallback)",
                "exploration_note": "Gemini supervisor error"
            }
        return None
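# NOTE: the timeout/error fallbacks in run_gemini_in_thread rely on the positional calling
# conventions used at the call sites below (the query as args[0], answers/documents as
# args[1], ...), so supervisor helpers should keep being invoked positionally for the
# fallbacks to stay meaningful.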
def _get_clinical_intake_state(session_id: str):
    with _clinical_intake_lock:
        return _clinical_intake_sessions.get(session_id)

def _set_clinical_intake_state(session_id: str, state: dict):
    with _clinical_intake_lock:
        _clinical_intake_sessions[session_id] = state

def _clear_clinical_intake_state(session_id: str):
    with _clinical_intake_lock:
        _clinical_intake_sessions.pop(session_id, None)

def _history_to_text(history: list, limit: int = 6) -> str:
    if not history:
        return "No prior conversation."
    recent = history[-limit:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        lines.append(f"{role}: {content}")
    return "\n".join(lines)
def _format_intake_question(question: dict, round_idx: int, max_rounds: int, target_lang: str) -> str:
    header = f"🩺 Question for clarity {round_idx}/{max_rounds}"
    body = question.get("question") or "Could you share a bit more detail so I can give an accurate answer?"
    prompt_parts = [
        header,
        body,
        "Please answer in 1-2 sentences so I can continue."
    ]
    prompt_text = "\n\n".join(prompt_parts)
    if target_lang and target_lang != "en":
        try:
            prompt_text = translate_text(prompt_text, target_lang=target_lang, source_lang="en")
        except Exception as exc:
            logger.warning(f"[INTAKE] Question translation failed: {exc}")
    return prompt_text

def _format_qa_transcript(qa_pairs: list) -> str:
    if not qa_pairs:
        return ""
    lines = []
    for idx, qa in enumerate(qa_pairs, 1):
        question = qa.get("question", "").strip()
        answer = qa.get("answer", "").strip()
        if question:
            lines.append(f"Q{idx}: {question}")
        if answer:
            lines.append(f"A{idx}: {answer}")
        lines.append("")
    return "\n".join(lines).strip()
def _format_insights_block(insights: dict) -> str:
    if not insights:
        return ""
    lines = []
    profile = insights.get("patient_profile")
    if profile:
        lines.append(f"- Patient profile: {profile}")
    for finding in insights.get("key_findings", []):
        title = finding.get("title", "Insight")
        detail = finding.get("detail", "")
        implication = finding.get("clinical_implication", "")
        line = f"- {title}: {detail}"
        if implication:
            line += f" (Clinical note: {implication})"
        lines.append(line)
    return "\n".join(lines)

def _build_refined_query(base_query: str, insights: dict, insights_block: str) -> str:
    sections = [base_query.strip()] if base_query else []
    if insights_block:
        sections.append(f"Clinical intake summary:\n{insights_block}")
    refined = insights.get("refined_problem_statement")
    if refined:
        sections.append(f"Refined problem statement:\n{refined}")
    handoff = insights.get("handoff_note")
    if handoff:
        sections.append(f"Handoff note:\n{handoff}")
    return "\n\n".join([section for section in sections if section])

def _hash_prompt_text(text: str) -> str:
    if not text:
        return ""
    digest = hashlib.sha1()
    digest.update(text.strip().encode("utf-8"))
    return digest.hexdigest()
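# The two helpers below let an in-progress clinical intake survive a Gradio session
# change (e.g. after a page refresh): the last intake prompt shown in the chat history
# is hashed and matched against pending sessions so the Q&A can resume under the new
# session hash.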
def _extract_pending_intake_prompt(history: list) -> str:
    if not history:
        return ""
    for turn in reversed(history):
        if turn.get("role") != "assistant":
            continue
        content = turn.get("content", "")
        if content.startswith("🩺 Question for clarity"):
            return content
    return ""

def _rehydrate_intake_state(session_id: str, history: list):
    state = _get_clinical_intake_state(session_id)
    if state or not history:
        return state
    pending_prompt = _extract_pending_intake_prompt(history)
    if not pending_prompt:
        return None
    prompt_hash = _hash_prompt_text(pending_prompt)
    if not prompt_hash:
        return None
    with _clinical_intake_lock:
        for existing_id, existing_state in list(_clinical_intake_sessions.items()):
            if existing_state.get("awaiting_answer") and existing_state.get("last_prompt_hash") == prompt_hash:
                if existing_id != session_id:
                    _clinical_intake_sessions.pop(existing_id, None)
                _clinical_intake_sessions[session_id] = existing_state
                return existing_state
    return None
def _get_last_assistant_answer(history: list) -> str:
    """
    Extract the last non-empty assistant answer from history.
    Skips clinical intake clarification prompts so that follow-up
    questions like "clarify your answer" refer to the real medical
    answer, not an intake question.
    """
    if not history:
        return ""
    for turn in reversed(history):
        if turn.get("role") != "assistant":
            continue
        content = (turn.get("content") or "").strip()
        if not content:
            continue
        # Skip intake prompts that start with the standard header
        if content.startswith("🩺 Question for clarity"):
            continue
        return content
    return ""
def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
    questions = plan.get("questions", []) or []
    if not questions:
        return None
    max_rounds = plan.get("max_rounds") or len(questions)
    max_rounds = max(1, min(MAX_CLINICAL_QA_ROUNDS, max_rounds, len(questions)))
    state = {
        "base_query": base_query,
        "original_language": original_language or "en",
        "questions": questions,
        "max_rounds": max_rounds,
        "current_round": 1,
        "pending_question_index": 0,
        "awaiting_answer": True,
        "answers": [],
        "decision_reason": plan.get("decision_reason", ""),
        "initial_hypotheses": plan.get("initial_hypotheses", []),
        "started_at": time.time(),
        "last_prompt_hash": ""
    }
    _set_clinical_intake_state(session_id, state)
    first_prompt = _format_intake_question(
        questions[0],
        round_idx=1,
        max_rounds=max_rounds,
        target_lang=state["original_language"]
    )
    state["last_prompt_hash"] = _hash_prompt_text(first_prompt)
    _set_clinical_intake_state(session_id, state)
    return first_prompt
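# _handle_clinical_answer advances the intake state machine one step and returns one of:
#   {"type": "question", "prompt": ...}  - ask the next follow-up question
#   {"type": "insights", ...}            - intake finished; summarized insights and refined query
#   {"type": "error"}                    - no usable state; the caller falls back to the normal flow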
def _handle_clinical_answer(session_id: str, answer_text: str):
    state = _get_clinical_intake_state(session_id)
    if not state:
        return {"type": "error"}
    questions = state.get("questions", [])
    idx = state.get("pending_question_index", 0)
    if idx >= len(questions):
        logger.warning("[INTAKE] Pending question index out of range, ending intake session")
        _clear_clinical_intake_state(session_id)
        return {"type": "error"}
    question_meta = questions[idx] or {}
    qa_entry = {
        "question": question_meta.get("question", ""),
        "focus": question_meta.get("clinical_focus"),
        "why_it_matters": question_meta.get("why_it_matters"),
        "round": state.get("current_round", len(state.get("answers", [])) + 1),
        "answer": answer_text.strip()
    }
    state["answers"].append(qa_entry)
    next_index = idx + 1
    reached_round_limit = len(state["answers"]) >= state["max_rounds"]
    if reached_round_limit or next_index >= len(questions):
        # Run in thread pool to avoid blocking GPU task
        insights = run_gemini_in_thread(gemini_summarize_clinical_insights, state["base_query"], state["answers"])
        insights_block = _format_insights_block(insights)
        refined_query = _build_refined_query(state["base_query"], insights, insights_block)
        transcript = _format_qa_transcript(state["answers"])
        _clear_clinical_intake_state(session_id)
        return {
            "type": "insights",
            "insights": insights,
            "insights_block": insights_block,
            "refined_query": refined_query,
            "qa_pairs": state["answers"],
            "qa_transcript": transcript
        }
    state["pending_question_index"] = next_index
    state["current_round"] = len(state["answers"]) + 1
    state["awaiting_answer"] = True
    _set_clinical_intake_state(session_id, state)
    next_question = questions[next_index]
    prompt = _format_intake_question(
        next_question,
        round_idx=state["current_round"],
        max_rounds=state["max_rounds"],
        target_lang=state["original_language"]
    )
    state["last_prompt_hash"] = _hash_prompt_text(prompt)
    _set_clinical_intake_state(session_id, state)
    return {"type": "question", "prompt": prompt}
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    use_rag: bool,
    medical_model: str,
    use_web_search: bool,
    enable_clinical_intake: bool,
    disable_agentic_reasoning: bool,
    show_thoughts: bool,
    request: gr.Request
):
    """Main chat pipeline implementing MAC architecture"""
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
        return
    # Check if model is loaded before proceeding
    if not is_model_loaded(medical_model):
        loading_state = get_model_loading_state(medical_model)
        if loading_state == "loading":
            error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
        else:
            error_msg = f"⚠️ {medical_model} is not ready. Please wait for the model to finish loading."
        # Try to load it
        try:
            set_model_loading_state(medical_model, "loading")
            initialize_medical_model(medical_model)
            # If successful, continue
        except Exception as e:
            error_msg = f"⚠️ Error loading {medical_model}: {str(e)[:200]}. Please try again."
            yield history + [{"role": "assistant", "content": error_msg}], ""
            return
        if not is_model_loaded(medical_model):
            yield history + [{"role": "assistant", "content": error_msg}], ""
            return
    thought_handler = None
    if show_thoughts:
        thought_handler = ThoughtCaptureHandler()
        thought_handler.setLevel(logging.INFO)
        thought_handler.clear()
        logger.addHandler(thought_handler)
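    # Time budget (seconds): the soft/hard cutoffs below keep the whole turn inside the
    # ~120 s GPU-task window this pipeline assumes (the same budget passed to the
    # supervisor breakdown call further down).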
    session_start = time.time()
    soft_timeout = 100
    hard_timeout = 118
    def elapsed():
        return time.time() - session_start
    user_id = request.session_hash or "anonymous"
    index_dir = f"./{user_id}_index"
    has_rag_index = os.path.exists(index_dir)
    original_lang = detect_language(message)
    original_message = message
    needs_translation = original_lang != "en"
    pipeline_diagnostics = {
        "reasoning": None,
        "plan": None,
        "strategy_decisions": [],
        "stage_metrics": {},
        "search": {"strategies": [], "total_results": 0},
        "clinical_intake": {
            "enabled": enable_clinical_intake,
            "activated": False,
            "rounds": 0,
            "reason": "",
            "insights": [],
            "plan": [],
            "qa_pairs": [],
            "transcript": "",
            "insights_block": ""
        }
    }
    def record_stage(stage_name: str, start_time: float):
        pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
    translation_stage_start = time.time()
    if needs_translation:
        logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
        message = translate_text(message, target_lang="en", source_lang=original_lang)
        logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
    record_stage("translation", translation_stage_start)
    final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
    final_use_web_search = use_web_search and not disable_agentic_reasoning
    # Initialize updated_history early to avoid UnboundLocalError
    updated_history = history + [
        {"role": "user", "content": original_message},
        {"role": "assistant", "content": ""}
    ]
    clinical_intake_context_block = ""
    # Clinical intake currently uses Gemini-based supervisors.
    # When agentic reasoning is disabled, we also skip all Gemini-driven
    # intake planning and summarization so the flow is purely MedSwin.
    if disable_agentic_reasoning or not enable_clinical_intake:
        _clear_clinical_intake_state(user_id)
    else:
        intake_state = _rehydrate_intake_state(user_id, history)
        if intake_state and intake_state.get("awaiting_answer"):
            logger.info("[INTAKE] Awaiting patient response - processing answer")
            intake_result = _handle_clinical_answer(user_id, message)
            if intake_result.get("type") == "question":
                logger.info("[INTAKE] Requesting additional follow-up")
                updated_history[-1]["content"] = intake_result["prompt"]
                thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                yield updated_history, thoughts_text
                if thought_handler:
                    logger.removeHandler(thought_handler)
                return
            if intake_result.get("type") == "insights":
                pipeline_diagnostics["clinical_intake"]["activated"] = True
                pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
                pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
                pipeline_diagnostics["clinical_intake"]["qa_pairs"] = intake_result.get("qa_pairs", [])
                pipeline_diagnostics["clinical_intake"]["transcript"] = intake_result.get("qa_transcript", "")
                pipeline_diagnostics["clinical_intake"]["insights_block"] = intake_result.get("insights_block", "")
                base_refined = intake_result.get("refined_query", message)
                summary_section = ""
                transcript_section = ""
                if intake_result.get("insights_block"):
                    summary_section = f"Clinical intake summary:\n{intake_result['insights_block']}"
                if intake_result.get("qa_transcript"):
                    transcript_section = f"Clinical intake Q&A transcript:\n{intake_result['qa_transcript']}"
                sections = [base_refined, summary_section, transcript_section]
                message = "\n\n---\n\n".join([section for section in sections if section])
                clinical_intake_context_block = "\n\n".join([seg for seg in [summary_section, transcript_section] if seg])
        else:
            history_context = _history_to_text(history)
            # Run in thread pool to avoid blocking GPU task
            triage_plan = run_gemini_in_thread(gemini_clinical_intake_triage, message, history_context, MAX_CLINICAL_QA_ROUNDS)
            pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
            pipeline_diagnostics["clinical_intake"]["plan"] = triage_plan.get("questions", [])
            needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
            if needs_intake:
                first_prompt = _start_clinical_intake_session(
                    user_id,
                    triage_plan,
                    message,
                    original_lang
                )
                if first_prompt:
                    pipeline_diagnostics["clinical_intake"]["activated"] = True
                    updated_history[-1]["content"] = first_prompt
                    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                    yield updated_history, thoughts_text
                    if thought_handler:
                        logger.removeHandler(thought_handler)
                    return
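    # Autonomous reasoning and execution planning (skipped entirely in direct MedSwin
    # mode); the resulting strategy can veto RAG / web search even when the user
    # enabled them.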
    plan = None
    if not disable_agentic_reasoning:
        reasoning_stage_start = time.time()
        reasoning = autonomous_reasoning(message, history)
        record_stage("autonomous_reasoning", reasoning_stage_start)
        pipeline_diagnostics["reasoning"] = reasoning
        plan = create_execution_plan(reasoning, message, has_rag_index)
        pipeline_diagnostics["plan"] = plan
        execution_strategy = autonomous_execution_strategy(
            reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
        )
        if final_use_rag and not reasoning.get("requires_rag", True):
            final_use_rag = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
        elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
            pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
        if final_use_web_search and not reasoning.get("requires_web_search", False):
            final_use_web_search = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
        elif not final_use_web_search and reasoning.get("requires_web_search", False):
            if not use_web_search:
                pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
            else:
                pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
    else:
        pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
    # Update thoughts after reasoning stage
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
        breakdown = {
            "sub_topics": [
                {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
            ],
            "strategy": "Direct answer",
            "exploration_note": "Direct mode - no breakdown"
        }
    else:
        logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
        # Provide previous assistant answer as context so Gemini can
        # interpret follow-up queries like "clarify your answer".
        previous_answer = _get_last_assistant_answer(history)
        # Run in thread pool to avoid blocking GPU task
        breakdown = run_gemini_in_thread(
            gemini_supervisor_breakdown,
            message,
            final_use_rag,
            final_use_web_search,
            elapsed(),
            120,
            previous_answer,
        )
        logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
    # Update thoughts after breakdown
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
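    # Web search fan-out: Gemini proposes up to MAX_SEARCH_STRATEGIES queries, which are
    # executed in parallel worker threads and then condensed into one search context.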
    search_contexts = []
    web_urls = []
    if final_use_web_search:
        search_stage_start = time.time()
        logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
        # Run in thread pool to avoid blocking GPU task
        search_strategies = run_gemini_in_thread(gemini_supervisor_search_strategies, message, elapsed())
        all_search_results = []
        strategy_jobs = []
        for strategy in search_strategies.get("search_strategies", [])[:MAX_SEARCH_STRATEGIES]:
            search_query = strategy.get("strategy", message)
            target_sources = strategy.get("target_sources", 2)
            strategy_jobs.append({
                "query": search_query,
                "target_sources": target_sources,
                "meta": strategy
            })
        def execute_search(job):
            job_start = time.time()
            try:
                results = search_web(job["query"], max_results=job["target_sources"])
                duration = time.time() - job_start
                return results, duration, None
            except Exception as exc:
                return [], time.time() - job_start, exc
        def record_search_diag(job, duration, results_count, error=None):
            entry = {
                "query": job["query"],
                "target_sources": job["target_sources"],
                "duration": round(duration, 3),
                "results": results_count
            }
            if error:
                entry["error"] = str(error)
            pipeline_diagnostics["search"]["strategies"].append(entry)
        if strategy_jobs:
            max_workers = min(len(strategy_jobs), 4)
            if len(strategy_jobs) > 1:
                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
                    for future in concurrent.futures.as_completed(future_map):
                        job = future_map[future]
                        try:
                            results, duration, error = future.result()
                        except Exception as exc:
                            results, duration, error = [], 0.0, exc
                        record_search_diag(job, duration, len(results), error)
                        if not error and results:
                            all_search_results.extend(results)
                            web_urls.extend([r.get('url', '') for r in results if r.get('url')])
            else:
                job = strategy_jobs[0]
                results, duration, error = execute_search(job)
                record_search_diag(job, duration, len(results), error)
                if not error and results:
                    all_search_results.extend(results)
                    web_urls.extend([r.get('url', '') for r in results if r.get('url')])
        else:
            pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
        pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
        if all_search_results:
            logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
            search_summary = summarize_web_content(all_search_results, message)
            if search_summary:
                search_contexts.append(search_summary)
                logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
        record_stage("web_search", search_stage_start)
    rag_contexts = []
    if final_use_rag and has_rag_index:
        rag_stage_start = time.time()
        if elapsed() >= soft_timeout - 10:
            logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
            final_use_rag = False
        else:
            logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
            embed_model = get_or_create_embed_model()
            Settings.embed_model = embed_model
            storage_context = StorageContext.from_defaults(persist_dir=index_dir)
            index = load_index_from_storage(storage_context, settings=Settings)
            base_retriever = index.as_retriever(similarity_top_k=retriever_k)
            auto_merging_retriever = AutoMergingRetriever(
                base_retriever,
                storage_context=storage_context,
                simple_ratio_thresh=merge_threshold,
                verbose=False
            )
            merged_nodes = auto_merging_retriever.retrieve(message)
            retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
            logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
            logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
            # Run in thread pool to avoid blocking GPU task
            rag_brainstorm = run_gemini_in_thread(gemini_supervisor_rag_brainstorm, message, retrieved_docs, elapsed())
            rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
            logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
        record_stage("rag_retrieval", rag_stage_start)
    medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
    base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
    context_sections = []
    if clinical_intake_context_block:
        context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
    if rag_contexts:
        context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
    if search_contexts:
        context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
    combined_context = "\n\n".join(context_sections)
    logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
    medswin_answers = []
    # Update thoughts before starting MedSwin tasks
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
    medswin_stage_start = time.time()
    for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
        if elapsed() >= hard_timeout - 5:
            logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
            break
        task_instruction = sub_topic.get("instruction", "")
        topic_name = sub_topic.get("topic", f"Topic {idx}")
        priority = sub_topic.get("priority", "medium")
        logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
        task_context = combined_context
        if len(rag_contexts) > 1 and idx <= len(rag_contexts):
            task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context
        # Add small delay between GPU requests to prevent ZeroGPU scheduler conflicts
        if idx > 1:
            delay = 0.5  # 500ms delay between sequential GPU requests
            logger.debug(f"[MEDSWIN] Waiting {delay}s before next GPU request to avoid scheduler conflicts...")
            time.sleep(delay)
        try:
            task_answer = execute_medswin_task(
                medical_model_obj=medical_model_obj,
                medical_tokenizer=medical_tokenizer,
                task_instruction=task_instruction,
                context=task_context if task_context else "",
                system_prompt_base=base_system_prompt,
                temperature=temperature,
                max_new_tokens=min(max_new_tokens, 800),
                top_p=top_p,
                top_k=top_k,
                penalty=penalty
            )
            formatted_answer = f"## {topic_name}\n\n{task_answer}"
            medswin_answers.append(formatted_answer)
            logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
            partial_final = "\n\n".join(medswin_answers)
            updated_history[-1]["content"] = partial_final
            thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
            yield updated_history, thoughts_text
        except Exception as e:
            logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
            continue
    record_stage("medswin_tasks", medswin_stage_start)
    # If agentic reasoning is disabled, we skip all Gemini-based synthesis,
    # challenge, and enhancement loops. The final answer is just the
    # concatenation of MedSwin task outputs.
    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - skipping Gemini synthesis and challenge")
        if medswin_answers:
            final_answer = "\n\n".join(medswin_answers)
        else:
            final_answer = "I apologize, but I was unable to generate a response."
    else:
        logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
        raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]
        synthesis_stage_start = time.time()
        # Run in thread pool to avoid blocking GPU task
        final_answer = run_gemini_in_thread(
            gemini_supervisor_synthesize, message, raw_medswin_answers, rag_contexts, search_contexts, breakdown
        )
        record_stage("synthesis", synthesis_stage_start)
        if not final_answer or len(final_answer.strip()) < 50:
            logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
            final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
        if "|" in final_answer and "---" in final_answer:
            logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
            lines = final_answer.split('\n')
            cleaned_lines = []
            for line in lines:
                if '|' in line and '---' not in line:
                    cells = [cell.strip() for cell in line.split('|') if cell.strip()]
                    if cells:
                        cleaned_lines.append(f"- {' / '.join(cells)}")
                elif '---' not in line:
                    cleaned_lines.append(line)
            final_answer = '\n'.join(cleaned_lines)
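        # Challenge/enhance loop: Gemini critiques the synthesized answer and, when it
        # supplies enhancement instructions, rewrites it (at most two passes, and only
        # while the soft time budget allows).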
        max_challenge_iterations = 2
        challenge_iteration = 0
        challenge_stage_start = time.time()
        while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
            challenge_iteration += 1
            logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")
            # Run in thread pool to avoid blocking GPU task
            evaluation = run_gemini_in_thread(
                gemini_supervisor_challenge, message, final_answer, raw_medswin_answers, rag_contexts, search_contexts
            )
            if evaluation.get("is_optimal", False):
                logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
                break
            enhancement_instructions = evaluation.get("enhancement_instructions", "")
            if not enhancement_instructions:
                logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
                break
            logger.info("[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
            # Run in thread pool to avoid blocking GPU task
            enhanced_answer = run_gemini_in_thread(
                gemini_supervisor_enhance_answer, message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
            )
            if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:
                final_answer = enhanced_answer
                logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
            else:
                logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
                break
        record_stage("challenge_loop", challenge_stage_start)
    if final_use_web_search and elapsed() < soft_timeout - 10:
        logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
        clarity_stage_start = time.time()
        # Run in thread pool to avoid blocking GPU task
        clarity_check = run_gemini_in_thread(gemini_supervisor_check_clarity, message, final_answer, final_use_web_search)
        record_stage("clarity_check", clarity_stage_start)
        if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
            logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
            additional_search_results = []
            followup_stage_start = time.time()
            for search_query in clarity_check.get("search_queries", [])[:3]:
                if elapsed() >= soft_timeout - 5:
                    break
                extra_start = time.time()
                results = search_web(search_query, max_results=2)
                extra_duration = time.time() - extra_start
                pipeline_diagnostics["search"]["strategies"].append({
                    "query": search_query,
                    "target_sources": 2,
                    "duration": round(extra_duration, 3),
                    "results": len(results),
                    "type": "followup"
                })
                additional_search_results.extend(results)
                web_urls.extend([r.get('url', '') for r in results if r.get('url')])
            if additional_search_results:
                pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
                logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
                additional_summary = summarize_web_content(additional_search_results, message)
                if additional_summary:
                    search_contexts.append(additional_summary)
                    logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
                    # Run in thread pool to avoid blocking GPU task
                    enhanced_with_search = run_gemini_in_thread(
                        gemini_supervisor_enhance_answer, message, final_answer,
                        f"Incorporate the following additional information from web search: {additional_summary}",
                        raw_medswin_answers, rag_contexts, search_contexts
                    )
                    if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
                        final_answer = enhanced_with_search
                        logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
            record_stage("followup_search", followup_stage_start)
    # Update thoughts after followup search
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text
| citations_text = "" | |
| if needs_translation and final_answer: | |
| logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...") | |
| final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en") | |
| if web_urls: | |
| unique_urls = list(dict.fromkeys(web_urls)) | |
| citation_links = [] | |
| for url in unique_urls[:5]: | |
| domain = format_url_as_domain(url) | |
| if domain: | |
| citation_links.append(f"[{domain}]({url})") | |
| if citation_links: | |
| citations_text = "\n\n**Sources:** " + ", ".join(citation_links) | |
| speaker_icon = ' 🔊' | |
| final_answer_with_metadata = final_answer + citations_text + speaker_icon | |
| updated_history[-1]["content"] = final_answer_with_metadata | |
| thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else "" | |
| # Always yield thoughts_text, even if empty, to ensure UI updates | |
| yield updated_history, thoughts_text | |
| if thought_handler: | |
| logger.removeHandler(thought_handler) | |
| diag_summary = { | |
| "stage_metrics": pipeline_diagnostics["stage_metrics"], | |
| "decisions": pipeline_diagnostics["strategy_decisions"], | |
| "search": pipeline_diagnostics["search"], | |
| } | |
| try: | |
| logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}") | |
| except Exception: | |
| logger.info(f"[MAC] Diagnostics summary (non-serializable)") | |
| logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed") | |