"""Main chat pipeline - stream_chat function""" import os import json import time import logging import threading import concurrent.futures import hashlib import gradio as gr import spaces from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.core import Settings from llama_index.core.retrievers import AutoMergingRetriever from logger import logger, ThoughtCaptureHandler from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state from utils import detect_language, translate_text, format_url_as_domain from search import search_web, summarize_web_content from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy from supervisor import ( gemini_supervisor_breakdown, gemini_supervisor_search_strategies, gemini_supervisor_rag_brainstorm, execute_medswin_task, gemini_supervisor_synthesize, gemini_supervisor_challenge, gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity, gemini_clinical_intake_triage, gemini_summarize_clinical_insights, MAX_SEARCH_STRATEGIES ) MAX_CLINICAL_QA_ROUNDS = 5 _clinical_intake_sessions = {} _clinical_intake_lock = threading.Lock() # Thread pool executor for running Gemini supervisor calls without blocking GPU task _gemini_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="gemini-supervisor") def run_gemini_in_thread(fn, *args, **kwargs): """ Run Gemini supervisor function in a separate thread to avoid blocking GPU task. This ensures Gemini API calls don't consume GPU task time and cause timeouts. """ try: future = _gemini_executor.submit(fn, *args, **kwargs) # Set a reasonable timeout (30s) to prevent hanging result = future.result(timeout=30.0) return result except concurrent.futures.TimeoutError: logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s") # Return fallback based on function if "breakdown" in fn.__name__: return { "sub_topics": [ {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"} ], "strategy": "Direct answer (timeout fallback)", "exploration_note": "Gemini supervisor timeout" } elif "search_strategies" in fn.__name__: return { "search_strategies": [ {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"} ], "max_strategies": 1 } elif "rag_brainstorm" in fn.__name__: return { "contexts": [ {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"} ], "max_contexts": 1 } elif "synthesize" in fn.__name__: return "\n\n".join(args[1] if len(args) > 1 else []) elif "challenge" in fn.__name__: return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""} elif "enhance_answer" in fn.__name__: return args[1] if len(args) > 1 else "" elif "check_clarity" in fn.__name__: return {"is_unclear": False, "needs_search": False, "search_queries": []} elif "clinical_intake_triage" in fn.__name__: return { "needs_additional_info": False, "decision_reason": "Timeout fallback", "max_rounds": args[2] if len(args) > 2 else 5, "questions": [], "initial_hypotheses": [] } elif "summarize_clinical_insights" in fn.__name__: return { "patient_profile": "", "refined_problem_statement": 
args[0] if args else "", "key_findings": [], "handoff_note": "Proceed with regular workflow." } else: logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn.__name__}, returning None") return None except Exception as e: logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}") # Return appropriate fallback if "breakdown" in fn.__name__: return { "sub_topics": [ {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"} ], "strategy": "Direct answer (error fallback)", "exploration_note": "Gemini supervisor error" } return None def _get_clinical_intake_state(session_id: str): with _clinical_intake_lock: return _clinical_intake_sessions.get(session_id) def _set_clinical_intake_state(session_id: str, state: dict): with _clinical_intake_lock: _clinical_intake_sessions[session_id] = state def _clear_clinical_intake_state(session_id: str): with _clinical_intake_lock: _clinical_intake_sessions.pop(session_id, None) def _history_to_text(history: list, limit: int = 6) -> str: if not history: return "No prior conversation." recent = history[-limit:] lines = [] for turn in recent: role = turn.get("role", "user") content = turn.get("content", "") lines.append(f"{role}: {content}") return "\n".join(lines) def _format_intake_question(question: dict, round_idx: int, max_rounds: int, target_lang: str) -> str: header = f"🩺 Question for clarity {round_idx}/{max_rounds}" body = question.get("question") or "Could you share a bit more detail so I can give an accurate answer?" prompt_parts = [ header, body, "Please answer in 1-2 sentences so I can continue." ] prompt_text = "\n\n".join(prompt_parts) if target_lang and target_lang != "en": try: prompt_text = translate_text(prompt_text, target_lang=target_lang, source_lang="en") except Exception as exc: logger.warning(f"[INTAKE] Question translation failed: {exc}") return prompt_text def _format_qa_transcript(qa_pairs: list) -> str: if not qa_pairs: return "" lines = [] for idx, qa in enumerate(qa_pairs, 1): question = qa.get("question", "").strip() answer = qa.get("answer", "").strip() if question: lines.append(f"Q{idx}: {question}") if answer: lines.append(f"A{idx}: {answer}") lines.append("") return "\n".join(lines).strip() def _format_insights_block(insights: dict) -> str: if not insights: return "" lines = [] profile = insights.get("patient_profile") if profile: lines.append(f"- Patient profile: {profile}") for finding in insights.get("key_findings", []): title = finding.get("title", "Insight") detail = finding.get("detail", "") implication = finding.get("clinical_implication", "") line = f"- {title}: {detail}" if implication: line += f" (Clinical note: {implication})" lines.append(line) return "\n".join(lines) def _build_refined_query(base_query: str, insights: dict, insights_block: str) -> str: sections = [base_query.strip()] if base_query else [] if insights_block: sections.append(f"Clinical intake summary:\n{insights_block}") refined = insights.get("refined_problem_statement") if refined: sections.append(f"Refined problem statement:\n{refined}") handoff = insights.get("handoff_note") if handoff: sections.append(f"Handoff note:\n{handoff}") return "\n\n".join([section for section in sections if section]) def _hash_prompt_text(text: str) -> str: if not text: return "" digest = hashlib.sha1() digest.update(text.strip().encode("utf-8")) return digest.hexdigest() def _extract_pending_intake_prompt(history: list) -> str: if not 
history: return "" for turn in reversed(history): if turn.get("role") != "assistant": continue content = turn.get("content", "") if content.startswith("🩺 Question for clarity"): return content return "" def _rehydrate_intake_state(session_id: str, history: list): state = _get_clinical_intake_state(session_id) if state or not history: return state pending_prompt = _extract_pending_intake_prompt(history) if not pending_prompt: return None prompt_hash = _hash_prompt_text(pending_prompt) if not prompt_hash: return None with _clinical_intake_lock: for existing_id, existing_state in list(_clinical_intake_sessions.items()): if existing_state.get("awaiting_answer") and existing_state.get("last_prompt_hash") == prompt_hash: if existing_id != session_id: _clinical_intake_sessions.pop(existing_id, None) _clinical_intake_sessions[session_id] = existing_state return existing_state return None def _get_last_assistant_answer(history: list) -> str: """ Extract the last non-empty assistant answer from history. Skips clinical intake clarification prompts so that follow-up questions like "clarify your answer" refer to the real medical answer, not an intake question. """ if not history: return "" for turn in reversed(history): if turn.get("role") != "assistant": continue content = (turn.get("content") or "").strip() if not content: continue # Skip intake prompts that start with the standard header if content.startswith("🩺 Question for clarity"): continue return content return "" def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str): questions = plan.get("questions", []) or [] if not questions: return None max_rounds = plan.get("max_rounds") or len(questions) max_rounds = max(1, min(MAX_CLINICAL_QA_ROUNDS, max_rounds, len(questions))) state = { "base_query": base_query, "original_language": original_language or "en", "questions": questions, "max_rounds": max_rounds, "current_round": 1, "pending_question_index": 0, "awaiting_answer": True, "answers": [], "decision_reason": plan.get("decision_reason", ""), "initial_hypotheses": plan.get("initial_hypotheses", []), "started_at": time.time(), "last_prompt_hash": "" } _set_clinical_intake_state(session_id, state) first_prompt = _format_intake_question( questions[0], round_idx=1, max_rounds=max_rounds, target_lang=state["original_language"] ) state["last_prompt_hash"] = _hash_prompt_text(first_prompt) _set_clinical_intake_state(session_id, state) return first_prompt def _handle_clinical_answer(session_id: str, answer_text: str): state = _get_clinical_intake_state(session_id) if not state: return {"type": "error"} questions = state.get("questions", []) idx = state.get("pending_question_index", 0) if idx >= len(questions): logger.warning("[INTAKE] Pending question index out of range, ending intake session") _clear_clinical_intake_state(session_id) return {"type": "error"} question_meta = questions[idx] or {} qa_entry = { "question": question_meta.get("question", ""), "focus": question_meta.get("clinical_focus"), "why_it_matters": question_meta.get("why_it_matters"), "round": state.get("current_round", len(state.get("answers", [])) + 1), "answer": answer_text.strip() } state["answers"].append(qa_entry) next_index = idx + 1 reached_round_limit = len(state["answers"]) >= state["max_rounds"] if reached_round_limit or next_index >= len(questions): # Run in thread pool to avoid blocking GPU task insights = run_gemini_in_thread(gemini_summarize_clinical_insights, state["base_query"], state["answers"]) insights_block = 
_format_insights_block(insights) refined_query = _build_refined_query(state["base_query"], insights, insights_block) transcript = _format_qa_transcript(state["answers"]) _clear_clinical_intake_state(session_id) return { "type": "insights", "insights": insights, "insights_block": insights_block, "refined_query": refined_query, "qa_pairs": state["answers"], "qa_transcript": transcript } state["pending_question_index"] = next_index state["current_round"] = len(state["answers"]) + 1 state["awaiting_answer"] = True _set_clinical_intake_state(session_id, state) next_question = questions[next_index] prompt = _format_intake_question( next_question, round_idx=state["current_round"], max_rounds=state["max_rounds"], target_lang=state["original_language"] ) state["last_prompt_hash"] = _hash_prompt_text(prompt) _set_clinical_intake_state(session_id, state) return {"type": "question", "prompt": prompt} @spaces.GPU(max_duration=120) def stream_chat( message: str, history: list, system_prompt: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float, retriever_k: int, merge_threshold: float, use_rag: bool, medical_model: str, use_web_search: bool, enable_clinical_intake: bool, disable_agentic_reasoning: bool, show_thoughts: bool, request: gr.Request ): """Main chat pipeline implementing MAC architecture""" if not request: yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], "" return # Check if model is loaded before proceeding if not is_model_loaded(medical_model): loading_state = get_model_loading_state(medical_model) if loading_state == "loading": error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages." else: error_msg = f"⚠️ {medical_model} is not ready. Please wait for the model to finish loading." # Try to load it try: set_model_loading_state(medical_model, "loading") initialize_medical_model(medical_model) # If successful, continue except Exception as e: error_msg = f"⚠️ Error loading {medical_model}: {str(e)[:200]}. Please try again." 
yield history + [{"role": "assistant", "content": error_msg}], "" return if not is_model_loaded(medical_model): yield history + [{"role": "assistant", "content": error_msg}], "" return thought_handler = None if show_thoughts: thought_handler = ThoughtCaptureHandler() thought_handler.setLevel(logging.INFO) thought_handler.clear() logger.addHandler(thought_handler) session_start = time.time() soft_timeout = 100 hard_timeout = 118 def elapsed(): return time.time() - session_start user_id = request.session_hash or "anonymous" index_dir = f"./{user_id}_index" has_rag_index = os.path.exists(index_dir) original_lang = detect_language(message) original_message = message needs_translation = original_lang != "en" pipeline_diagnostics = { "reasoning": None, "plan": None, "strategy_decisions": [], "stage_metrics": {}, "search": {"strategies": [], "total_results": 0}, "clinical_intake": { "enabled": enable_clinical_intake, "activated": False, "rounds": 0, "reason": "", "insights": [], "plan": [], "qa_pairs": [], "transcript": "", "insights_block": "" } } def record_stage(stage_name: str, start_time: float): pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3) translation_stage_start = time.time() if needs_translation: logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...") message = translate_text(message, target_lang="en", source_lang=original_lang) logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...") record_stage("translation", translation_stage_start) final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning final_use_web_search = use_web_search and not disable_agentic_reasoning # Initialize updated_history early to avoid UnboundLocalError updated_history = history + [ {"role": "user", "content": original_message}, {"role": "assistant", "content": ""} ] clinical_intake_context_block = "" # Clinical intake currently uses Gemini-based supervisors. # When agentic reasoning is disabled, we also skip all Gemini-driven # intake planning and summarization so the flow is purely MedSwin. 
    if disable_agentic_reasoning or not enable_clinical_intake:
        _clear_clinical_intake_state(user_id)
    else:
        intake_state = _rehydrate_intake_state(user_id, history)
        if intake_state and intake_state.get("awaiting_answer"):
            logger.info("[INTAKE] Awaiting patient response - processing answer")
            intake_result = _handle_clinical_answer(user_id, message)
            if intake_result.get("type") == "question":
                logger.info("[INTAKE] Requesting additional follow-up")
                updated_history[-1]["content"] = intake_result["prompt"]
                thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                yield updated_history, thoughts_text
                if thought_handler:
                    logger.removeHandler(thought_handler)
                return
            if intake_result.get("type") == "insights":
                pipeline_diagnostics["clinical_intake"]["activated"] = True
                pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
                pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
                pipeline_diagnostics["clinical_intake"]["qa_pairs"] = intake_result.get("qa_pairs", [])
                pipeline_diagnostics["clinical_intake"]["transcript"] = intake_result.get("qa_transcript", "")
                pipeline_diagnostics["clinical_intake"]["insights_block"] = intake_result.get("insights_block", "")
                base_refined = intake_result.get("refined_query", message)
                summary_section = ""
                transcript_section = ""
                if intake_result.get("insights_block"):
                    summary_section = f"Clinical intake summary:\n{intake_result['insights_block']}"
                if intake_result.get("qa_transcript"):
                    transcript_section = f"Clinical intake Q&A transcript:\n{intake_result['qa_transcript']}"
                sections = [base_refined, summary_section, transcript_section]
                message = "\n\n---\n\n".join([section for section in sections if section])
                clinical_intake_context_block = "\n\n".join([seg for seg in [summary_section, transcript_section] if seg])
        else:
            history_context = _history_to_text(history)
            # Run in thread pool to avoid blocking GPU task
            triage_plan = run_gemini_in_thread(gemini_clinical_intake_triage, message, history_context, MAX_CLINICAL_QA_ROUNDS)
            pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
            pipeline_diagnostics["clinical_intake"]["plan"] = triage_plan.get("questions", [])
            needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
            if needs_intake:
                first_prompt = _start_clinical_intake_session(
                    user_id, triage_plan, message, original_lang
                )
                if first_prompt:
                    pipeline_diagnostics["clinical_intake"]["activated"] = True
                    updated_history[-1]["content"] = first_prompt
                    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
                    yield updated_history, thoughts_text
                    if thought_handler:
                        logger.removeHandler(thought_handler)
                    return

    plan = None
    if not disable_agentic_reasoning:
        reasoning_stage_start = time.time()
        reasoning = autonomous_reasoning(message, history)
        record_stage("autonomous_reasoning", reasoning_stage_start)
        pipeline_diagnostics["reasoning"] = reasoning
        plan = create_execution_plan(reasoning, message, has_rag_index)
        pipeline_diagnostics["plan"] = plan
        execution_strategy = autonomous_execution_strategy(
            reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
        )
        if final_use_rag and not reasoning.get("requires_rag", True):
            final_use_rag = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
        elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available") if final_use_web_search and not reasoning.get("requires_web_search", False): final_use_web_search = False pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning") elif not final_use_web_search and reasoning.get("requires_web_search", False): if not use_web_search: pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request") else: pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode") else: pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user") # Update thoughts after reasoning stage thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else "" yield updated_history, thoughts_text if disable_agentic_reasoning: logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone") breakdown = { "sub_topics": [ {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"} ], "strategy": "Direct answer", "exploration_note": "Direct mode - no breakdown" } else: logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...") # Provide previous assistant answer as context so Gemini can # interpret follow-up queries like "clarify your answer". previous_answer = _get_last_assistant_answer(history) # Run in thread pool to avoid blocking GPU task breakdown = run_gemini_in_thread( gemini_supervisor_breakdown, message, final_use_rag, final_use_web_search, elapsed(), 120, previous_answer, ) logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics") # Update thoughts after breakdown thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else "" yield updated_history, thoughts_text search_contexts = [] web_urls = [] if final_use_web_search: search_stage_start = time.time() logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...") # Run in thread pool to avoid blocking GPU task search_strategies = run_gemini_in_thread(gemini_supervisor_search_strategies, message, elapsed()) all_search_results = [] strategy_jobs = [] for strategy in search_strategies.get("search_strategies", [])[:MAX_SEARCH_STRATEGIES]: search_query = strategy.get("strategy", message) target_sources = strategy.get("target_sources", 2) strategy_jobs.append({ "query": search_query, "target_sources": target_sources, "meta": strategy }) def execute_search(job): job_start = time.time() try: results = search_web(job["query"], max_results=job["target_sources"]) duration = time.time() - job_start return results, duration, None except Exception as exc: return [], time.time() - job_start, exc def record_search_diag(job, duration, results_count, error=None): entry = { "query": job["query"], "target_sources": job["target_sources"], "duration": round(duration, 3), "results": results_count } if error: entry["error"] = str(error) pipeline_diagnostics["search"]["strategies"].append(entry) if strategy_jobs: max_workers = min(len(strategy_jobs), 4) if len(strategy_jobs) > 1: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_map = {executor.submit(execute_search, job): job for job in strategy_jobs} for future in concurrent.futures.as_completed(future_map): job = future_map[future] try: results, duration, error = future.result() except Exception as exc: results, duration, error = 
[], 0.0, exc record_search_diag(job, duration, len(results), error) if not error and results: all_search_results.extend(results) web_urls.extend([r.get('url', '') for r in results if r.get('url')]) else: job = strategy_jobs[0] results, duration, error = execute_search(job) record_search_diag(job, duration, len(results), error) if not error and results: all_search_results.extend(results) web_urls.extend([r.get('url', '') for r in results if r.get('url')]) else: pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned") pipeline_diagnostics["search"]["total_results"] = len(all_search_results) if all_search_results: logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...") search_summary = summarize_web_content(all_search_results, message) if search_summary: search_contexts.append(search_summary) logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars") record_stage("web_search", search_stage_start) rag_contexts = [] if final_use_rag and has_rag_index: rag_stage_start = time.time() if elapsed() >= soft_timeout - 10: logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure") final_use_rag = False else: logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...") embed_model = get_or_create_embed_model() Settings.embed_model = embed_model storage_context = StorageContext.from_defaults(persist_dir=index_dir) index = load_index_from_storage(storage_context, settings=Settings) base_retriever = index.as_retriever(similarity_top_k=retriever_k) auto_merging_retriever = AutoMergingRetriever( base_retriever, storage_context=storage_context, simple_ratio_thresh=merge_threshold, verbose=False ) merged_nodes = auto_merging_retriever.retrieve(message) retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes]) logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes") logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...") # Run in thread pool to avoid blocking GPU task rag_brainstorm = run_gemini_in_thread(gemini_supervisor_rag_brainstorm, message, retrieved_docs, elapsed()) rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])] logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts") record_stage("rag_retrieval", rag_stage_start) medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model) base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables." 
    context_sections = []
    if clinical_intake_context_block:
        context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
    if rag_contexts:
        context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
    if search_contexts:
        context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
    combined_context = "\n\n".join(context_sections)

    logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
    medswin_answers = []

    # Update thoughts before starting MedSwin tasks
    thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
    yield updated_history, thoughts_text

    medswin_stage_start = time.time()
    for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
        if elapsed() >= hard_timeout - 5:
            logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
            break
        task_instruction = sub_topic.get("instruction", "")
        topic_name = sub_topic.get("topic", f"Topic {idx}")
        priority = sub_topic.get("priority", "medium")
        logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
        task_context = combined_context
        if len(rag_contexts) > 1 and idx <= len(rag_contexts):
            task_context = rag_contexts[idx - 1]
        # Add a small delay between GPU requests to prevent ZeroGPU scheduler conflicts
        if idx > 1:
            delay = 0.5  # 500ms delay between sequential GPU requests
            logger.debug(f"[MEDSWIN] Waiting {delay}s before next GPU request to avoid scheduler conflicts...")
            time.sleep(delay)
        try:
            task_answer = execute_medswin_task(
                medical_model_obj=medical_model_obj,
                medical_tokenizer=medical_tokenizer,
                task_instruction=task_instruction,
                context=task_context if task_context else "",
                system_prompt_base=base_system_prompt,
                temperature=temperature,
                max_new_tokens=min(max_new_tokens, 800),
                top_p=top_p,
                top_k=top_k,
                penalty=penalty
            )
            formatted_answer = f"## {topic_name}\n\n{task_answer}"
            medswin_answers.append(formatted_answer)
            logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
            partial_final = "\n\n".join(medswin_answers)
            updated_history[-1]["content"] = partial_final
            thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
            yield updated_history, thoughts_text
        except Exception as e:
            logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
            continue
    record_stage("medswin_tasks", medswin_stage_start)

    # If agentic reasoning is disabled, we skip all Gemini-based synthesis,
    # challenge, and enhancement loops. The final answer is just the
    # concatenation of MedSwin task outputs.
    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - skipping Gemini synthesis and challenge")
        if medswin_answers:
            final_answer = "\n\n".join(medswin_answers)
        else:
            final_answer = "I apologize, but I was unable to generate a response."
else: logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...") raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers] synthesis_stage_start = time.time() # Run in thread pool to avoid blocking GPU task final_answer = run_gemini_in_thread( gemini_supervisor_synthesize, message, raw_medswin_answers, rag_contexts, search_contexts, breakdown ) record_stage("synthesis", synthesis_stage_start) if not final_answer or len(final_answer.strip()) < 50: logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation") final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response." if "|" in final_answer and "---" in final_answer: logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets") lines = final_answer.split('\n') cleaned_lines = [] for line in lines: if '|' in line and '---' not in line: cells = [cell.strip() for cell in line.split('|') if cell.strip()] if cells: cleaned_lines.append(f"- {' / '.join(cells)}") elif '---' not in line: cleaned_lines.append(line) final_answer = '\n'.join(cleaned_lines) max_challenge_iterations = 2 challenge_iteration = 0 challenge_stage_start = time.time() while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15: challenge_iteration += 1 logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...") # Run in thread pool to avoid blocking GPU task evaluation = run_gemini_in_thread( gemini_supervisor_challenge, message, final_answer, raw_medswin_answers, rag_contexts, search_contexts ) if evaluation.get("is_optimal", False): logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)") break enhancement_instructions = evaluation.get("enhancement_instructions", "") if not enhancement_instructions: logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal") break logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...") # Run in thread pool to avoid blocking GPU task enhanced_answer = run_gemini_in_thread( gemini_supervisor_enhance_answer, message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts ) if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8: final_answer = enhanced_answer logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)") else: logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping") break record_stage("challenge_loop", challenge_stage_start) if final_use_web_search and elapsed() < soft_timeout - 10: logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...") clarity_stage_start = time.time() # Run in thread pool to avoid blocking GPU task clarity_check = run_gemini_in_thread(gemini_supervisor_check_clarity, message, final_answer, final_use_web_search) record_stage("clarity_check", clarity_stage_start) if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"): logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}") additional_search_results = [] followup_stage_start = time.time() for search_query in clarity_check.get("search_queries", [])[:3]: if elapsed() >= soft_timeout - 5: break extra_start = time.time() results = search_web(search_query, max_results=2) extra_duration = 
time.time() - extra_start pipeline_diagnostics["search"]["strategies"].append({ "query": search_query, "target_sources": 2, "duration": round(extra_duration, 3), "results": len(results), "type": "followup" }) additional_search_results.extend(results) web_urls.extend([r.get('url', '') for r in results if r.get('url')]) if additional_search_results: pipeline_diagnostics["search"]["total_results"] += len(additional_search_results) logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...") additional_summary = summarize_web_content(additional_search_results, message) if additional_summary: search_contexts.append(additional_summary) logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...") # Run in thread pool to avoid blocking GPU task enhanced_with_search = run_gemini_in_thread( gemini_supervisor_enhance_answer, message, final_answer, f"Incorporate the following additional information from web search: {additional_summary}", raw_medswin_answers, rag_contexts, search_contexts ) if enhanced_with_search and len(enhanced_with_search.strip()) > 50: final_answer = enhanced_with_search logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context") record_stage("followup_search", followup_stage_start) # Update thoughts after followup search thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else "" yield updated_history, thoughts_text citations_text = "" if needs_translation and final_answer: logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...") final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en") if web_urls: unique_urls = list(dict.fromkeys(web_urls)) citation_links = [] for url in unique_urls[:5]: domain = format_url_as_domain(url) if domain: citation_links.append(f"[{domain}]({url})") if citation_links: citations_text = "\n\n**Sources:** " + ", ".join(citation_links) speaker_icon = ' 🔊' final_answer_with_metadata = final_answer + citations_text + speaker_icon updated_history[-1]["content"] = final_answer_with_metadata thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else "" # Always yield thoughts_text, even if empty, to ensure UI updates yield updated_history, thoughts_text if thought_handler: logger.removeHandler(thought_handler) diag_summary = { "stage_metrics": pipeline_diagnostics["stage_metrics"], "decisions": pipeline_diagnostics["strategy_decisions"], "search": pipeline_diagnostics["search"], } try: logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}") except Exception: logger.info(f"[MAC] Diagnostics summary (non-serializable)") logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
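

# Minimal local usage sketch (illustrative assumption, not part of the production wiring).
# `stream_chat` needs a live Gradio request plus a loaded MedSwin model, so the simplest
# self-contained thing to exercise here is `run_gemini_in_thread`, which proxies any
# callable through the shared executor and applies the 30s timeout/fallback handling.
# `_demo_breakdown` below is a hypothetical stand-in, not a real supervisor call.
if __name__ == "__main__":
    def _demo_breakdown(query: str) -> dict:
        # Hypothetical stand-in; a real call would target gemini_supervisor_breakdown.
        return {"sub_topics": [{"id": 1, "topic": "Answer", "instruction": query}]}

    print(run_gemini_in_thread(_demo_breakdown, "What are common causes of iron-deficiency anemia?"))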