diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,2969 +1,13 @@ -import gradio as gr +"""Main entry point for MedLLM Agent""" import os -import base64 -import logging -import torch -import threading -import time -import json -import concurrent.futures -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - TextIteratorStreamer, - StoppingCriteria, - StoppingCriteriaList, -) -from transformers import logging as hf_logging -import spaces -from llama_index.core import ( - StorageContext, - VectorStoreIndex, - load_index_from_storage, - Document as LlamaDocument, -) -from llama_index.core import Settings -from llama_index.core.node_parser import ( - HierarchicalNodeParser, - get_leaf_nodes, - get_root_nodes, -) -from llama_index.core.retrievers import AutoMergingRetriever -from llama_index.core.storage.docstore import SimpleDocumentStore -from llama_index.llms.huggingface import HuggingFaceLLM -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from tqdm import tqdm -from langdetect import detect, LangDetectException os.environ["TOKENIZERS_PARALLELISM"] = "false" -# Set logging to INFO level for cleaner output -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) -# Custom logger handler to capture agentic thoughts -class ThoughtCaptureHandler(logging.Handler): - """Custom handler to capture internal thoughts from MedSwin and supervisor""" - def __init__(self): - super().__init__() - self.thoughts = [] - self.lock = threading.Lock() - - def emit(self, record): - """Capture log messages that contain agentic thoughts""" - try: - msg = self.format(record) - # Only capture messages from GEMINI SUPERVISOR or MEDSWIN - if "[GEMINI SUPERVISOR]" in msg or "[MEDSWIN]" in msg or "[MAC]" in msg: - # Remove timestamp and logger name for cleaner display - # Format: "timestamp - logger - level - message" - parts = msg.split(" - ", 3) - if len(parts) >= 4: - clean_msg = parts[-1] # Get the message part - else: - clean_msg = msg - with self.lock: - self.thoughts.append(clean_msg) - except Exception: - pass # Ignore formatting errors - - def get_thoughts(self): - """Get all captured thoughts as a formatted string""" - with self.lock: - return "\n".join(self.thoughts) - - def clear(self): - """Clear captured thoughts""" - with self.lock: - self.thoughts = [] -# Set MCP client logging to WARNING to reduce noise -mcp_client_logger = logging.getLogger("mcp.client") -mcp_client_logger.setLevel(logging.WARNING) -hf_logging.set_verbosity_error() - -# MCP imports -MCP_CLIENT_INFO = None -try: - from mcp import ClientSession, StdioServerParameters - from mcp import types as mcp_types - from mcp.client.stdio import stdio_client - import asyncio - try: - import nest_asyncio - nest_asyncio.apply() # Allow nested event loops - except ImportError: - pass # nest_asyncio is optional - MCP_AVAILABLE = True - MCP_CLIENT_INFO = mcp_types.Implementation( - name="MedLLM-Agent", - version=os.environ.get("SPACE_VERSION", "local"), - ) -except ImportError as e: - logger.warning(f"MCP SDK not available: {e}") - MCP_AVAILABLE = False - # Fallback imports if MCP is not available -from ddgs import DDGS -import requests -from bs4 import BeautifulSoup -try: - from TTS.api import TTS - TTS_AVAILABLE = True -except ImportError: - TTS_AVAILABLE = False - TTS = None -import numpy as np -import soundfile as sf -import tempfile - -# Model configurations -MEDSWIN_MODELS = { - "MedSwin SFT": 
"MedSwin/MedSwin-7B-SFT", - "MedSwin KD": "MedSwin/MedSwin-7B-KD", - "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7" -} -DEFAULT_MEDICAL_MODEL = "MedSwin TA" -EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model -TTS_MODEL = "maya-research/maya1" -HF_TOKEN = os.environ.get("HF_TOKEN") -if not HF_TOKEN: - raise ValueError("HF_TOKEN not found in environment variables") - -# Gemini MCP configuration -GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") -GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash") # Default for harder tasks -GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite") # For parsing and simple tasks - -# Custom UI -TITLE = "

đŸŠē MedLLM Agent - Medical RAG & Web Search System"
-DESCRIPTION = """
-Advanced Medical AI Assistant powered by MedSwin models
-📄 Document RAG: Answer based on uploaded medical documents
-🌐 Web Search: Fetch knowledge from reliable online medical resources
-🌍 Multi-language: Automatic translation for non-English queries
-Upload PDF or text files to get started!
-""" -CSS = """ -.upload-section { - max-width: 400px; - margin: 0 auto; - padding: 10px; - border: 2px dashed #ccc; - border-radius: 10px; -} -.upload-button { - background: #34c759 !important; - color: white !important; - border-radius: 25px !important; -} -.chatbot-container { - margin-top: 20px; -} -.status-output { - margin-top: 10px; - font-size: 14px; -} -.processing-info { - margin-top: 5px; - font-size: 12px; - color: #666; -} -.info-container { - margin-top: 10px; - padding: 10px; - border-radius: 5px; -} -.file-list { - margin-top: 0; - max-height: 200px; - overflow-y: auto; - padding: 5px; - border: 1px solid #eee; - border-radius: 5px; -} -.stats-box { - margin-top: 10px; - padding: 10px; - border-radius: 5px; - font-size: 12px; -} -.submit-btn { - background: #1a73e8 !important; - color: white !important; - border-radius: 25px !important; - margin-left: 10px; - padding: 5px 10px; - font-size: 16px; -} -.input-row { - display: flex; - align-items: center; -} -.recording-timer { - font-size: 12px; - color: #666; - text-align: center; - margin-top: 5px; -} -.feature-badge { - display: inline-block; - padding: 3px 8px; - margin: 2px; - border-radius: 12px; - font-size: 11px; - font-weight: bold; -} -.badge-rag { - background: #e3f2fd; - color: #1976d2; -} -.badge-web { - background: #f3e5f5; - color: #7b1fa2; -} -@media (min-width: 768px) { - .main-container { - display: flex; - justify-content: space-between; - gap: 20px; - } - .upload-section { - flex: 1; - max-width: 300px; - } - .chatbot-container { - flex: 2; - margin-top: 0; - } -} -""" - -# Global model storage -global_medical_models = {} -global_medical_tokenizers = {} -global_file_info = {} -global_tts_model = None -global_embed_model = None - -# MCP client storage -global_mcp_session = None -global_mcp_stdio_ctx = None # Store stdio context to keep it alive -global_mcp_lock = threading.Lock() # Lock for thread-safe session access -# MCP server configuration via environment variables -# Gemini MCP server: Python-based server (agent.py) -# This works on Hugging Face Spaces without requiring npm/Node.js -# Make sure GEMINI_API_KEY is set in environment variables -# -# Default configuration uses the bundled agent.py script -# To override: -# export MCP_SERVER_COMMAND="python" -# export MCP_SERVER_ARGS="/path/to/agent.py" -script_dir = os.path.dirname(os.path.abspath(__file__)) -agent_path = os.path.join(script_dir, "agent.py") -MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python") -MCP_SERVER_ARGS = os.environ.get("MCP_SERVER_ARGS", agent_path).split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path] - -async def get_mcp_session(): - """Get or create MCP client session with proper context management""" - global global_mcp_session, global_mcp_stdio_ctx - - if not MCP_AVAILABLE: - logger.warning("MCP not available - SDK not installed") - return None - - # Check if session exists and is still valid - if global_mcp_session is not None: - # Trust that existing session is valid - verify only when actually using it - return global_mcp_session - - # Create new session using correct MCP SDK pattern - try: - # Prepare environment variables for MCP server - mcp_env = os.environ.copy() - if GEMINI_API_KEY: - mcp_env["GEMINI_API_KEY"] = GEMINI_API_KEY - else: - logger.warning("GEMINI_API_KEY not set in environment. 
Gemini MCP features may not work.") - - # Add other Gemini MCP configuration if set - if os.environ.get("GEMINI_MODEL"): - mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL") - if os.environ.get("GEMINI_TIMEOUT"): - mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT") - if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"): - mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS") - if os.environ.get("GEMINI_TEMPERATURE"): - mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE") - - logger.info("Creating MCP client session...") - - server_params = StdioServerParameters( - command=MCP_SERVER_COMMAND, - args=MCP_SERVER_ARGS, - env=mcp_env - ) - - # Correct MCP SDK usage: stdio_client is an async context manager - # that yields (read, write) streams - stdio_ctx = stdio_client(server_params) - read, write = await stdio_ctx.__aenter__() - - # Create ClientSession from the streams - session = ClientSession( - read, - write, - client_info=MCP_CLIENT_INFO, - ) - - # Initialize the session (this sends initialize request and waits for response + initialized notification) - # The __aenter__() method handles the complete initialization handshake: - # 1. Sends initialize request with client info - # 2. Waits for initialize response from server - # 3. Waits for initialized notification from server (this is critical!) - # According to MCP protocol spec, the client MUST wait for the initialized notification - # before sending any other requests (like list_tools) - try: - # The __aenter__() method properly handles the full initialization sequence - # including waiting for the server's initialized notification - # This is a blocking call that completes only after the server sends initialized - await session.__aenter__() - init_result = await session.initialize() - server_info = getattr(init_result, "serverInfo", None) - server_name = getattr(server_info, "name", "unknown") - server_version = getattr(server_info, "version", "unknown") - logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})") - except Exception as e: - error_msg = str(e) - error_type = type(e).__name__ - logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}") - - # Clean up and return None - try: - await session.__aexit__(None, None, None) - except Exception: - pass - try: - await stdio_ctx.__aexit__(None, None, None) - except Exception: - pass - return None - - # Store both the session and stdio context to keep them alive - global_mcp_session = session - global_mcp_stdio_ctx = stdio_ctx - logger.info("✅ MCP client session created successfully") - return session - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}") - global_mcp_session = None - global_mcp_stdio_ctx = None - return None - -MCP_TOOLS_CACHE_TTL = int(os.environ.get("MCP_TOOLS_CACHE_TTL", "60")) -global_mcp_tools_cache = {"timestamp": 0.0, "tools": None} - - -def invalidate_mcp_tools_cache(): - """Invalidate cached MCP tool metadata""" - global global_mcp_tools_cache - global_mcp_tools_cache = {"timestamp": 0.0, "tools": None} - - -async def get_cached_mcp_tools(force_refresh: bool = False): - """Return cached MCP tools list to avoid repeated list_tools calls""" - global global_mcp_tools_cache - if not MCP_AVAILABLE: - return [] - - now = time.time() - if ( - not force_refresh - and global_mcp_tools_cache["tools"] - and now - global_mcp_tools_cache["timestamp"] < 
MCP_TOOLS_CACHE_TTL - ): - return global_mcp_tools_cache["tools"] - - session = await get_mcp_session() - if session is None: - return [] - - try: - tools_resp = await session.list_tools() - tools_list = list(getattr(tools_resp, "tools", []) or []) - global_mcp_tools_cache = {"timestamp": now, "tools": tools_list} - return tools_list - except Exception as e: - logger.error(f"Failed to refresh MCP tools: {e}") - invalidate_mcp_tools_cache() - return [] - - -async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str: - """Call Gemini MCP generate_content tool""" - if not MCP_AVAILABLE: - logger.warning("MCP not available for Gemini call") - return "" - - try: - session = await get_mcp_session() - if session is None: - logger.warning("Failed to get MCP session for Gemini call") - return "" - - tools = await get_cached_mcp_tools() - if not tools: - tools = await get_cached_mcp_tools(force_refresh=True) - if not tools: - logger.error("Unable to obtain MCP tool catalog for Gemini calls") - return "" - - generate_tool = None - for tool in tools: - if tool.name == "generate_content" or "generate_content" in tool.name.lower(): - generate_tool = tool - logger.info(f"Found Gemini MCP tool: {tool.name}") - break - - if not generate_tool: - logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}") - invalidate_mcp_tools_cache() - return "" - - # Prepare arguments - arguments = { - "user_prompt": user_prompt - } - if system_prompt: - arguments["system_prompt"] = system_prompt - if files: - arguments["files"] = files - if model: - arguments["model"] = model - if temperature is not None: - arguments["temperature"] = temperature - - result = await session.call_tool(generate_tool.name, arguments=arguments) - - # Parse result - if hasattr(result, 'content') and result.content: - for item in result.content: - if hasattr(item, 'text'): - response_text = item.text.strip() - return response_text - logger.warning("âš ī¸ Gemini MCP returned empty or invalid result") - return "" - except Exception as e: - logger.error(f"Gemini MCP call error: {e}") - return "" - -def initialize_medical_model(model_name: str): - """Initialize medical model (MedSwin) - download on demand""" - global global_medical_models, global_medical_tokenizers - if model_name not in global_medical_models or global_medical_models[model_name] is None: - logger.info(f"Initializing medical model: {model_name}...") - model_path = MEDSWIN_MODELS[model_name] - tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN) - model = AutoModelForCausalLM.from_pretrained( - model_path, - device_map="auto", - trust_remote_code=True, - token=HF_TOKEN, - torch_dtype=torch.float16 - ) - global_medical_models[model_name] = model - global_medical_tokenizers[model_name] = tokenizer - logger.info(f"Medical model {model_name} initialized successfully") - return global_medical_models[model_name], global_medical_tokenizers[model_name] - - -def initialize_tts_model(): - """Initialize TTS model for text-to-speech""" - global global_tts_model - if not TTS_AVAILABLE: - logger.warning("TTS library not installed. 
TTS features will be disabled.") - return None - if global_tts_model is None: - try: - logger.info("Initializing TTS model for voice generation...") - global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False) - logger.info("TTS model initialized successfully") - except Exception as e: - logger.warning(f"TTS model initialization failed: {e}") - logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts") - global_tts_model = None - return global_tts_model - - -def get_or_create_embed_model(): - """Reuse embedding model to avoid reloading weights each request""" - global global_embed_model - if global_embed_model is None: - logger.info("Initializing shared embedding model for RAG retrieval...") - global_embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN) - return global_embed_model - -async def transcribe_audio_gemini(audio_path: str) -> str: - """Transcribe audio using Gemini MCP""" - if not MCP_AVAILABLE: - return "" - - try: - # Ensure we have an absolute path - audio_path_abs = os.path.abspath(audio_path) - - # Prepare file object for Gemini MCP using path (as per Gemini MCP documentation) - files = [{ - "path": audio_path_abs - }] - - # Use exact prompts from Gemini MCP documentation - system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts." - user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises." - - result = await call_agent( - user_prompt=user_prompt, - system_prompt=system_prompt, - files=files, - model=GEMINI_MODEL_LITE, # Use lite model for transcription - temperature=0.2 - ) - - return result.strip() - except Exception as e: - logger.error(f"Gemini transcription error: {e}") - return "" - -def transcribe_audio(audio): - """Transcribe audio to text using Gemini MCP""" - if audio is None: - return "" - - try: - # Handle file path (Gradio Audio component returns file path) - if isinstance(audio, str): - audio_path = audio - elif isinstance(audio, tuple): - # Handle tuple format (sample_rate, audio_data) - sample_rate, audio_data = audio - # Save to temp file - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: - sf.write(tmp_file.name, audio_data, samplerate=sample_rate) - audio_path = tmp_file.name - else: - audio_path = audio - - # Use Gemini MCP for transcription - if MCP_AVAILABLE: - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - transcribed = nest_asyncio.run(transcribe_audio_gemini(audio_path)) - if transcribed: - logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...") - return transcribed - except Exception as e: - logger.error(f"Error in nested async transcription: {e}") - else: - transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path)) - if transcribed: - logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...") - return transcribed - except Exception as e: - logger.error(f"Gemini MCP transcription error: {e}") - - logger.warning("Gemini MCP transcription not available") - return "" - except Exception as e: - logger.error(f"Transcription error: {e}") - return "" - -async def generate_speech_mcp(text: str) -> str: - """Generate speech using MCP TTS tool""" - if not MCP_AVAILABLE: - return None - - try: - # Get MCP session - session = await 
get_mcp_session() - if session is None: - return None - - # Find TTS tool - tools = await session.list_tools() - tts_tool = None - for tool in tools.tools: - if "tts" in tool.name.lower() or "speech" in tool.name.lower() or "synthesize" in tool.name.lower(): - tts_tool = tool - logger.info(f"Found MCP TTS tool: {tool.name}") - break - - if tts_tool: - result = await session.call_tool( - tts_tool.name, - arguments={"text": text, "language": "en"} - ) - - # Parse result - MCP might return audio data or file path - if hasattr(result, 'content') and result.content: - for item in result.content: - if hasattr(item, 'text'): - # If it's a file path - if os.path.exists(item.text): - return item.text - elif hasattr(item, 'data') and item.data: - # If it's binary audio data, save it - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: - tmp_file.write(item.data) - return tmp_file.name - return None - except Exception as e: - logger.warning(f"MCP TTS error: {e}") - return None - -def generate_speech(text: str): - """Generate speech from text using TTS model (with MCP fallback)""" - if not text or len(text.strip()) == 0: - return None - - # Try MCP first if available - if MCP_AVAILABLE: - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - audio_path = nest_asyncio.run(generate_speech_mcp(text)) - if audio_path: - logger.info("Generated speech via MCP") - return audio_path - except: - pass - else: - audio_path = loop.run_until_complete(generate_speech_mcp(text)) - if audio_path: - return audio_path - except Exception as e: - pass # MCP TTS not available, fallback to local - - # Fallback to local TTS model - if not TTS_AVAILABLE: - logger.error("TTS library not installed. Please install TTS to use voice generation.") - return None - - global global_tts_model - if global_tts_model is None: - initialize_tts_model() - - if global_tts_model is None: - logger.error("TTS model not available. 
Please check dependencies.") - return None - - try: - # Generate audio - wav = global_tts_model.tts(text) - - # Save to temporary file - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: - sf.write(tmp_file.name, wav, samplerate=22050) - return tmp_file.name - except Exception as e: - logger.error(f"TTS error: {e}") - return None - -def format_prompt_manually(messages: list, tokenizer) -> str: - """Manually format prompt for models without chat template""" - prompt_parts = [] - - # Combine system and user messages into a single instruction - system_content = "" - user_content = "" - - for msg in messages: - role = msg.get("role", "user") - content = msg.get("content", "") - - if role == "system": - system_content = content - elif role == "user": - user_content = content - elif role == "assistant": - # Skip assistant messages in history for now (can be added if needed) - pass - - # Format for MedAlpaca/LLaMA-based medical models - # Common format: Instruction + Input -> Response - if system_content: - prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:" - else: - prompt = f"Question: {user_content}\n\nAnswer:" - - return prompt - -def detect_language(text: str) -> str: - """Detect language of input text""" - try: - lang = detect(text) - return lang - except LangDetectException: - return "en" # Default to English if detection fails - -def format_url_as_domain(url: str) -> str: - """Format URL as simple domain name (e.g., www.mayoclinic.org)""" - if not url: - return "" - try: - from urllib.parse import urlparse - parsed = urlparse(url) - domain = parsed.netloc or parsed.path - # Remove www. prefix if present, but keep it for display - if domain.startswith('www.'): - return domain - elif domain: - return domain - return url - except Exception: - # Fallback: try to extract domain manually - if '://' in url: - domain = url.split('://')[1].split('/')[0] - return domain - return url - -async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str: - """Translate text using Gemini MCP""" - if source_lang: - user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}" - else: - user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}" - - # Use concise system prompt - system_prompt = "You are a professional translator. Translate accurately and concisely." 
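format_prompt_manually above is the fallback half of a two-step pattern this file applies whenever it prompts MedSwin: use the tokenizer's chat template when one exists, otherwise fall back to the manual Question/Answer layout. A condensed sketch of the combined pattern (build_prompt is an illustrative name, not a function in this file):

def build_prompt(messages: list, tokenizer) -> str:
    # Prefer the model's own chat template when one is defined
    if getattr(tokenizer, "chat_template", None):
        try:
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass  # template failed; fall back to manual formatting
    system = next((m["content"] for m in messages if m.get("role") == "system"), "")
    user = next((m["content"] for m in messages if m.get("role") == "user"), "")
    qa = f"Question: {user}\n\nAnswer:"
    return f"{system}\n\n{qa}" if system else qa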
- - result = await call_agent( - user_prompt=user_prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL_LITE, # Use lite model for translation - temperature=0.2 - ) - - return result.strip() - -def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str: - """Translate text using Gemini MCP""" - if not MCP_AVAILABLE: - logger.warning("Gemini MCP not available for translation") - return text # Return original text if translation fails - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - translated = nest_asyncio.run(translate_text_gemini(text, target_lang, source_lang)) - if translated: - logger.info(f"Translated via Gemini MCP: {translated[:50]}...") - return translated - except Exception as e: - logger.error(f"Error in nested async translation: {e}") - else: - translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang)) - if translated: - logger.info(f"Translated via Gemini MCP: {translated[:50]}...") - return translated - except Exception as e: - logger.error(f"Gemini MCP translation error: {e}") - - # Return original text if translation fails - return text - -async def search_web_mcp_tool(query: str, max_results: int = 5) -> list: - """Search web using MCP web search tool (e.g., DuckDuckGo MCP server)""" - if not MCP_AVAILABLE: - return [] - - try: - tools = await get_cached_mcp_tools() - if not tools: - return [] - - search_tool = None - for tool in tools: - tool_name_lower = tool.name.lower() - if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]): - search_tool = tool - logger.info(f"Found web search MCP tool: {tool.name}") - break - - if not search_tool: - tools = await get_cached_mcp_tools(force_refresh=True) - for tool in tools: - tool_name_lower = tool.name.lower() - if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]): - search_tool = tool - logger.info(f"Found web search MCP tool after refresh: {tool.name}") - break - - if search_tool: - try: - session = await get_mcp_session() - if session is None: - return [] - # Call the search tool - result = await session.call_tool( - search_tool.name, - arguments={"query": query, "max_results": max_results} - ) - - # Parse result - web_content = [] - if hasattr(result, 'content') and result.content: - for item in result.content: - if hasattr(item, 'text'): - try: - data = json.loads(item.text) - if isinstance(data, list): - for entry in data[:max_results]: - web_content.append({ - 'title': entry.get('title', ''), - 'url': entry.get('url', entry.get('href', '')), - 'content': entry.get('body', entry.get('snippet', entry.get('content', ''))) - }) - elif isinstance(data, dict): - if 'results' in data: - for entry in data['results'][:max_results]: - web_content.append({ - 'title': entry.get('title', ''), - 'url': entry.get('url', entry.get('href', '')), - 'content': entry.get('body', entry.get('snippet', entry.get('content', ''))) - }) - else: - web_content.append({ - 'title': data.get('title', ''), - 'url': data.get('url', data.get('href', '')), - 'content': data.get('body', data.get('snippet', data.get('content', ''))) - }) - except json.JSONDecodeError: - # If not JSON, treat as plain text - web_content.append({ - 'title': '', - 'url': '', - 'content': item.text[:1000] - }) - - if web_content: - return web_content - except Exception as e: - logger.error(f"Error calling web search MCP tool: {e}") - - else: - logger.debug("No MCP web search tool discovered in 
current catalog") - return [] - except Exception as e: - logger.error(f"Web search MCP tool error: {e}") - return [] - -async def search_web_mcp(query: str, max_results: int = 5) -> list: - """Search web using MCP tools - tries web search MCP tool first, then falls back to direct search""" - # First try to use a dedicated web search MCP tool (like DuckDuckGo MCP server) - results = await search_web_mcp_tool(query, max_results) - if results: - logger.info(f"✅ Web search via MCP tool: found {len(results)} results") - return results - - # If no web search MCP tool available, use direct search (ddgs) - # Note: Gemini MCP doesn't have web search capability, so we use direct API - # The results will then be summarized using Gemini MCP - logger.info("â„šī¸ [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)") - return search_web_fallback(query, max_results) - -def search_web_fallback(query: str, max_results: int = 5) -> list: - """Fallback web search using DuckDuckGo directly (when MCP is not available)""" - logger.info(f"🔍 [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...") - # Always import here to ensure availability - try: - from ddgs import DDGS - import requests - from bs4 import BeautifulSoup - except ImportError: - logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available") - return [] - - try: - with DDGS() as ddgs: - results = list(ddgs.text(query, max_results=max_results)) - web_content = [] - for result in results: - try: - url = result.get('href', '') - title = result.get('title', '') - snippet = result.get('body', '') - - # Try to fetch full content - try: - response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'}) - if response.status_code == 200: - soup = BeautifulSoup(response.content, 'html.parser') - # Extract main content - for script in soup(["script", "style"]): - script.decompose() - text = soup.get_text() - # Clean and limit text - lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) - text = ' '.join(chunk for chunk in chunks if chunk) - if len(text) > 1000: - text = text[:1000] + "..." 
- web_content.append({ - 'title': title, - 'url': url, - 'content': snippet + "\n" + text[:500] if text else snippet - }) - else: - web_content.append({ - 'title': title, - 'url': url, - 'content': snippet - }) - except: - web_content.append({ - 'title': title, - 'url': url, - 'content': snippet - }) - except Exception as e: - logger.error(f"Error processing search result: {e}") - continue - logger.info(f"✅ [Direct API] Web search completed: {len(web_content)} results") - return web_content - except Exception as e: - logger.error(f"❌ [Direct API] Web search error: {e}") - return [] - -def search_web(query: str, max_results: int = 5) -> list: - """Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs""" - # Always try MCP first if available - if MCP_AVAILABLE: - try: - # Run async MCP search - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if loop.is_running(): - # If loop is already running, use nest_asyncio or create new thread - try: - import nest_asyncio - results = nest_asyncio.run(search_web_mcp(query, max_results)) - if results: # Only return if we got results from MCP - return results - except (ImportError, AttributeError): - # Fallback: run in thread - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, search_web_mcp(query, max_results)) - results = future.result(timeout=30) - if results: # Only return if we got results from MCP - return results - else: - results = loop.run_until_complete(search_web_mcp(query, max_results)) - if results: # Only return if we got results from MCP - return results - except Exception as e: - logger.error(f"Error running async MCP search: {e}") - - # Only use ddgs fallback if MCP is not available or returned no results - logger.info("â„šī¸ [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)") - return search_web_fallback(query, max_results) - -async def summarize_web_content_gemini(content_list: list, query: str) -> str: - """Summarize web search results using Gemini MCP""" - combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]]) - - user_prompt = f"""Summarize the following web search results related to the query: "{query}" -Extract key medical information, facts, and insights. Be concise and focus on reliable information. -Search Results: -{combined_content} -Summary:""" - - # Use concise system prompt - system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately." 
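The fetch-and-clean step buried in search_web_fallback above is easier to read as a standalone helper. A minimal sketch under the same assumptions as the original (5-second timeout, plain html.parser, ~1000-character cap; fetch_page_text is an illustrative name):

import requests
from bs4 import BeautifulSoup

def fetch_page_text(url: str, limit: int = 1000) -> str:
    resp = requests.get(url, timeout=5, headers={"User-Agent": "Mozilla/5.0"})
    if resp.status_code != 200:
        return ""
    soup = BeautifulSoup(resp.content, "html.parser")
    for tag in soup(["script", "style"]):
        tag.decompose()  # drop non-visible markup before extracting text
    text = " ".join(soup.get_text().split())  # collapse all whitespace runs
    return text[:limit] + "..." if len(text) > limit else text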
- - result = await call_agent( - user_prompt=user_prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, # Use full model for summarization - temperature=0.5 - ) - - return result.strip() - -def summarize_web_content(content_list: list, query: str) -> str: - """Summarize web search results using Gemini MCP""" - if not MCP_AVAILABLE: - logger.warning("Gemini MCP not available for summarization") - # Fallback: return first result's content - if content_list: - return content_list[0].get('content', '')[:500] - return "" - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - summary = nest_asyncio.run(summarize_web_content_gemini(content_list, query)) - if summary: - return summary - except Exception as e: - logger.error(f"Error in nested async summarization: {e}") - else: - summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query)) - if summary: - return summary - except Exception as e: - logger.error(f"Gemini MCP summarization error: {e}") - - # Fallback: return first result's content - if content_list: - return content_list[0].get('content', '')[:500] - return "" - -def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50): - """Get LLM for RAG indexing (uses medical model)""" - # Use medical model for RAG indexing instead of translation model - medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL) - - return HuggingFaceLLM( - context_window=4096, - max_new_tokens=max_new_tokens, - tokenizer=medical_tokenizer, - model=medical_model_obj, - generate_kwargs={ - "do_sample": True, - "temperature": temperature, - "top_k": top_k, - "top_p": top_p - } - ) - -async def autonomous_reasoning_gemini(query: str) -> dict: - """Autonomous reasoning using Gemini MCP""" - reasoning_prompt = f"""Analyze this medical query and provide structured reasoning: -Query: "{query}" -Analyze: -1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info) -2. Complexity: (simple, moderate, complex, multi_faceted) -3. Information Needs: What specific information is required? -4. Requires RAG: (yes/no) - Does this need document context? -5. Requires Web Search: (yes/no) - Does this need current/updated information? -6. Sub-questions: Break down into key sub-questions if complex -Respond in JSON format: -{{ - "query_type": "...", - "complexity": "...", - "information_needs": ["..."], - "requires_rag": true/false, - "requires_web_search": true/false, - "sub_questions": ["..."] -}}""" - - # Use concise system prompt - system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses." 
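The reasoning, breakdown, and evaluation helpers in this file all recover JSON from free-form model output the same way: slice from the first '{' to the last '}' and fall back to a default when parsing fails. That repeated pattern as a single sketch (extract_json is an illustrative name):

import json

def extract_json(response: str, fallback: dict) -> dict:
    start, end = response.find("{"), response.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            return json.loads(response[start:end])
        except json.JSONDecodeError:
            pass  # malformed JSON; use the caller's fallback
    return fallback

# extract_json('Sure: {"complexity": "simple"}', {"complexity": "moderate"})
# -> {"complexity": "simple"}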
- - response = await call_agent( - user_prompt=reasoning_prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, # Use full model for reasoning - temperature=0.3 - ) - - # Parse JSON response (with fallback) - try: - # Extract JSON from response - json_start = response.find('{') - json_end = response.rfind('}') + 1 - if json_start >= 0 and json_end > json_start: - reasoning = json.loads(response[json_start:json_end]) - else: - raise ValueError("No JSON found") - except: - # Fallback reasoning - reasoning = { - "query_type": "general_info", - "complexity": "moderate", - "information_needs": ["medical information"], - "requires_rag": True, - "requires_web_search": False, - "sub_questions": [query] - } - - logger.info(f"Reasoning analysis: {reasoning}") - return reasoning - -def autonomous_reasoning(query: str, history: list) -> dict: - """ - Autonomous reasoning: Analyze query complexity, intent, and information needs. - Returns reasoning analysis with query type, complexity, and required information sources. - Uses Gemini MCP for reasoning. - """ - if not MCP_AVAILABLE: - logger.warning("âš ī¸ Gemini MCP not available for reasoning, using fallback") - # Fallback reasoning - return { - "query_type": "general_info", - "complexity": "moderate", - "information_needs": ["medical information"], - "requires_rag": True, - "requires_web_search": False, - "sub_questions": [query] - } - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - reasoning = nest_asyncio.run(autonomous_reasoning_gemini(query)) - return reasoning - except Exception as e: - logger.error(f"Error in nested async reasoning: {e}") - else: - reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query)) - return reasoning - except Exception as e: - logger.error(f"Gemini MCP reasoning error: {e}") - - # Fallback reasoning only if all attempts failed - logger.warning("âš ī¸ Falling back to default reasoning") - return { - "query_type": "general_info", - "complexity": "moderate", - "information_needs": ["medical information"], - "requires_rag": True, - "requires_web_search": False, - "sub_questions": [query] - } - -def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict: - """ - Planning: Create multi-step execution plan based on reasoning analysis. - Returns execution plan with steps and strategy. 
- """ - plan = { - "steps": [], - "strategy": "sequential", - "iterations": 1 - } - - # Determine execution strategy - if reasoning["complexity"] in ["complex", "multi_faceted"]: - plan["strategy"] = "iterative" - plan["iterations"] = 2 - - # Step 1: Language detection and translation - plan["steps"].append({ - "step": 1, - "action": "detect_language", - "description": "Detect query language and translate if needed" - }) - - # Step 2: RAG retrieval (if needed and available) - if reasoning.get("requires_rag", True) and has_rag_index: - plan["steps"].append({ - "step": 2, - "action": "rag_retrieval", - "description": "Retrieve relevant document context", - "parameters": {"top_k": 15, "merge_threshold": 0.5} - }) - - # Step 3: Web search (if needed) - if reasoning.get("requires_web_search", False): - plan["steps"].append({ - "step": 3, - "action": "web_search", - "description": "Search web for current/updated information", - "parameters": {"max_results": 5} - }) - - # Step 4: Sub-question processing (if complex) - if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1: - plan["steps"].append({ - "step": 4, - "action": "multi_step_reasoning", - "description": "Process sub-questions iteratively", - "sub_questions": reasoning["sub_questions"] - }) - - # Step 5: Synthesis and answer generation - plan["steps"].append({ - "step": len(plan["steps"]) + 1, - "action": "synthesize_answer", - "description": "Generate comprehensive answer from all sources" - }) - - # Step 6: Self-reflection (for complex queries) - if reasoning["complexity"] in ["complex", "multi_faceted"]: - plan["steps"].append({ - "step": len(plan["steps"]) + 1, - "action": "self_reflection", - "description": "Evaluate answer quality and completeness" - }) - - logger.info(f"Execution plan created: {len(plan['steps'])} steps") - return plan - -def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict: - """ - Autonomous execution: Make decisions on information gathering strategy. - Only suggests web search override, but respects user's RAG disable setting. - """ - strategy = { - "use_rag": use_rag, # Respect user's RAG setting - "use_web_search": use_web_search, - "reasoning_override": False, - "rationale": "" - } - - # Respect user toggle; just log recommendation if web search is disabled - if reasoning.get("requires_web_search", False) and not use_web_search: - strategy["rationale"] = "Reasoning suggests web search for current information, but the user kept it disabled." - - # Note: We don't override RAG setting because: - # 1. User may have explicitly disabled it - # 2. RAG requires documents to be uploaded - # 3. We should respect user's explicit choice - - if strategy["rationale"]: - logger.info(f"Autonomous reasoning note: {strategy['rationale']}") - - return strategy - -async def gemini_supervisor_breakdown_async(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict: - """ - Gemini Supervisor: Break user query into sub-topics (flexible number, explore different approaches) - This is the main supervisor function that orchestrates the MAC architecture. - All internal thoughts are logged, not displayed. 
- """ - remaining_time = max(15, max_duration - time_elapsed) - - mode_description = [] - if use_rag: - mode_description.append("RAG mode enabled - will use retrieved documents") - if use_web_search: - mode_description.append("Web search mode enabled - will search online sources") - if not mode_description: - mode_description.append("Direct answer mode - no additional context") - - # Calculate reasonable max topics based on time remaining - # Allow more subtasks if we have time, but be flexible - estimated_time_per_task = 8 # seconds per task - max_topics_by_time = max(2, int((remaining_time - 20) / estimated_time_per_task)) - max_topics = min(max_topics_by_time, 10) # Cap at 10, but allow more than 4 - - prompt = f"""You are a supervisor agent coordinating with a MedSwin medical specialist model. -Break the following medical query into focused sub-topics that MedSwin can answer sequentially. -Explore different potential approaches to comprehensively address the topic. - -Query: "{query}" -Mode: {', '.join(mode_description)} -Time Remaining: ~{remaining_time:.1f}s -Maximum Topics: {max_topics} (adjust based on complexity - use as many as needed for thorough coverage) - -Return ONLY valid JSON (no markdown, no tables, no explanations): -{{ - "sub_topics": [ - {{ - "id": 1, - "topic": "concise topic name", - "instruction": "specific directive for MedSwin to answer this topic", - "expected_tokens": 200, - "priority": "high|medium|low", - "approach": "brief description of approach/angle for this topic" - }}, - ... - ], - "strategy": "brief strategy description explaining the breakdown approach", - "exploration_note": "brief note on different approaches explored" -}} - -Guidelines: -- Break down the query into as many subtasks as needed for comprehensive coverage -- Explore different angles/approaches (e.g., clinical, diagnostic, treatment, prevention, research perspectives) -- Each topic should be focused and answerable in ~200 tokens by MedSwin -- Prioritize topics by importance (high priority first) -- Don't limit yourself to 4 topics - use more if the query is complex or multi-faceted""" - - system_prompt = "You are a medical query supervisor. Break queries into structured JSON sub-topics, exploring different approaches. Return ONLY valid JSON." 
- - response = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, - temperature=0.3 - ) - - try: - # Extract JSON from response - json_start = response.find('{') - json_end = response.rfind('}') + 1 - if json_start >= 0 and json_end > json_start: - breakdown = json.loads(response[json_start:json_end]) - logger.info(f"[GEMINI SUPERVISOR] Query broken into {len(breakdown.get('sub_topics', []))} sub-topics") - logger.debug(f"[GEMINI SUPERVISOR] Breakdown: {json.dumps(breakdown, indent=2)}") - return breakdown - else: - raise ValueError("Supervisor JSON not found") - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Breakdown parsing failed: {exc}") - # Fallback: simple breakdown - breakdown = { - "sub_topics": [ - {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"}, - {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"}, - ], - "strategy": "Sequential answer with key points", - "exploration_note": "Fallback breakdown - basic coverage" - } - logger.warning(f"[GEMINI SUPERVISOR] Using fallback breakdown") - return breakdown - -async def gemini_supervisor_search_strategies_async(query: str, time_elapsed: float) -> dict: - """ - Gemini Supervisor: In search mode, break query into 1-4 searching strategies - Returns JSON with search strategies that will be executed with ddgs - """ - prompt = f"""You are supervising web search for a medical query. -Break this query into 1-4 focused search strategies (each targeting 1-2 sources). - -Query: "{query}" - -Return ONLY valid JSON: -{{ - "search_strategies": [ - {{ - "id": 1, - "strategy": "search query string", - "target_sources": 1, - "focus": "what to search for" - }}, - ... - ], - "max_strategies": 4 -}} - -Keep strategies focused and avoid overlap.""" - - system_prompt = "You are a search strategy supervisor. Create focused search queries. Return ONLY valid JSON." - - response = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL_LITE, # Use lite model for search planning - temperature=0.2 - ) - - try: - json_start = response.find('{') - json_end = response.rfind('}') + 1 - if json_start >= 0 and json_end > json_start: - strategies = json.loads(response[json_start:json_end]) - logger.info(f"[GEMINI SUPERVISOR] Created {len(strategies.get('search_strategies', []))} search strategies") - logger.debug(f"[GEMINI SUPERVISOR] Strategies: {json.dumps(strategies, indent=2)}") - return strategies - else: - raise ValueError("Search strategies JSON not found") - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Search strategies parsing failed: {exc}") - return { - "search_strategies": [ - {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"} - ], - "max_strategies": 1 - } - -async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict: - """ - Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts - These contexts will be passed to MedSwin to support decision-making - """ - # Limit retrieved docs to avoid token overflow - max_doc_length = 3000 - if len(retrieved_docs) > max_doc_length: - retrieved_docs = retrieved_docs[:max_doc_length] + "..." - - prompt = f"""You are supervising RAG context preparation for a medical query. 
-Brainstorm the retrieved documents into 1-4 concise, focused contexts that MedSwin can use. - -Query: "{query}" -Retrieved Documents: -{retrieved_docs} - -Return ONLY valid JSON: -{{ - "contexts": [ - {{ - "id": 1, - "context": "concise summary of relevant information (keep under 500 chars)", - "focus": "what this context covers", - "relevance": "high|medium|low" - }}, - ... - ], - "max_contexts": 4 -}} - -Keep contexts brief and factual. Avoid redundancy.""" - - system_prompt = "You are a RAG context supervisor. Summarize documents into concise contexts. Return ONLY valid JSON." - - response = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL_LITE, # Use lite model for RAG brainstorming - temperature=0.2 - ) - - try: - json_start = response.find('{') - json_end = response.rfind('}') + 1 - if json_start >= 0 and json_end > json_start: - contexts = json.loads(response[json_start:json_end]) - logger.info(f"[GEMINI SUPERVISOR] Brainstormed {len(contexts.get('contexts', []))} RAG contexts") - logger.debug(f"[GEMINI SUPERVISOR] Contexts: {json.dumps(contexts, indent=2)}") - return contexts - else: - raise ValueError("RAG contexts JSON not found") - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] RAG brainstorming parsing failed: {exc}") - # Fallback: use retrieved docs as single context - return { - "contexts": [ - {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"} - ], - "max_contexts": 1 - } - -def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict: - """Wrapper to obtain supervisor breakdown synchronously""" - if not MCP_AVAILABLE: - logger.warning("[GEMINI SUPERVISOR] MCP unavailable, using fallback breakdown") - return { - "sub_topics": [ - {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"}, - {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"}, - ], - "strategy": "Sequential answer with key points", - "exploration_note": "Fallback breakdown - basic coverage" - } - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - return nest_asyncio.run( - gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration) - ) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Nested breakdown execution failed: {exc}") - raise - return loop.run_until_complete( - gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration) - ) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Breakdown request failed: {exc}") - return { - "sub_topics": [ - {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"}, - ], - "strategy": "Direct answer", - "exploration_note": "Fallback breakdown - single topic" - } - -def gemini_supervisor_search_strategies(query: str, time_elapsed: float) -> dict: - """Wrapper to obtain search strategies synchronously""" - if not MCP_AVAILABLE: - logger.warning("[GEMINI SUPERVISOR] MCP unavailable for search strategies") - return { - "search_strategies": [ - {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"} - ], - "max_strategies": 1 - } 
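Every synchronous wrapper in this file repeats the same sync-to-async bridge: get the event loop, handle the already-running case (Gradio executes inside a loop), and otherwise call run_until_complete. A sketch of that bridge as one reusable helper; note nest_asyncio's documented entry point is apply(), which patches the current loop so run_until_complete may nest (run_coro_sync is an illustrative name):

import asyncio

def run_coro_sync(coro, fallback=None):
    try:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        if loop.is_running():
            import nest_asyncio
            nest_asyncio.apply(loop)  # allow nested run_until_complete
        return loop.run_until_complete(coro)
    except Exception:
        return fallback

# translated = run_coro_sync(translate_text_gemini(text, "en"), fallback=text)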
- - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - return nest_asyncio.run(gemini_supervisor_search_strategies_async(query, time_elapsed)) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Nested search strategies execution failed: {exc}") - return { - "search_strategies": [ - {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"} - ], - "max_strategies": 1 - } - return loop.run_until_complete(gemini_supervisor_search_strategies_async(query, time_elapsed)) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Search strategies request failed: {exc}") - return { - "search_strategies": [ - {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"} - ], - "max_strategies": 1 - } - -def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elapsed: float) -> dict: - """Wrapper to obtain RAG brainstorm synchronously""" - if not MCP_AVAILABLE: - logger.warning("[GEMINI SUPERVISOR] MCP unavailable for RAG brainstorm") - return { - "contexts": [ - {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"} - ], - "max_contexts": 1 - } - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - import nest_asyncio - return nest_asyncio.run(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed)) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Nested RAG brainstorm execution failed: {exc}") - return { - "contexts": [ - {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"} - ], - "max_contexts": 1 - } - return loop.run_until_complete(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed)) - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] RAG brainstorm request failed: {exc}") - return { - "contexts": [ - {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"} - ], - "max_contexts": 1 - } - -@spaces.GPU(max_duration=120) -def execute_medswin_task( - medical_model_obj, - medical_tokenizer, - task_instruction: str, - context: str, - system_prompt_base: str, - temperature: float, - max_new_tokens: int, - top_p: float, - top_k: int, - penalty: float -) -> str: - """ - MedSwin Specialist: Execute a single task assigned by Gemini Supervisor - This function is tagged with @spaces.GPU to run on GPU (ZeroGPU equivalent) - All internal thoughts are logged, only final answer is returned - """ - # Build task-specific prompt - if context: - full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):" - else: - full_prompt = f"{system_prompt_base}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):" - - messages = [{"role": "system", "content": full_prompt}] - - # Format prompt - if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None: - try: - prompt = medical_tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - except Exception as e: - logger.warning(f"[MEDSWIN] Chat template failed, using manual formatting: {e}") - prompt = format_prompt_manually(messages, medical_tokenizer) - else: - prompt = format_prompt_manually(messages, medical_tokenizer) - - # Tokenize and generate - inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device) - 
- eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id - - with torch.no_grad(): - outputs = medical_model_obj.generate( - **inputs, - max_new_tokens=min(max_new_tokens, 800), # Limit per task - temperature=temperature, - top_p=top_p, - top_k=top_k, - repetition_penalty=penalty, - do_sample=True, - eos_token_id=eos_token_id, - pad_token_id=medical_tokenizer.pad_token_id or eos_token_id - ) - - # Decode response - response = medical_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True) - - # Clean response - remove any table-like formatting, ensure Markdown bullets - response = response.strip() - # Remove table markers if present - if "|" in response and "---" in response: - logger.warning("[MEDSWIN] Detected table format, converting to Markdown bullets") - # Simple conversion: split by lines and convert to bullets - lines = [line.strip() for line in response.split('\n') if line.strip() and not line.strip().startswith('|') and '---' not in line] - response = '\n'.join([f"- {line}" if not line.startswith('-') else line for line in lines]) - - logger.info(f"[MEDSWIN] Task completed: {len(response)} chars generated") - return response - -async def gemini_supervisor_synthesize_async(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str: - """ - Gemini Supervisor: Synthesize final answer from all MedSwin responses with clear context - Provides better context to create a comprehensive, well-structured final answer - """ - # Prepare context summary - context_summary = "" - if rag_contexts: - context_summary += f"Document Context Available: {len(rag_contexts)} context(s) from uploaded documents.\n" - if search_contexts: - context_summary += f"Web Search Context Available: {len(search_contexts)} search result(s).\n" - - # Combine all MedSwin answers - all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)]) - - prompt = f"""You are a supervisor agent synthesizing a comprehensive medical answer from multiple specialist responses. - -Original Query: "{query}" - -Context Available: -{context_summary} - -MedSwin Specialist Responses (from {len(medswin_answers)} sub-topics): -{all_answers_text} - -Your task: -1. Synthesize all responses into a coherent, comprehensive final answer -2. Integrate information from all sub-topics seamlessly -3. Ensure the answer directly addresses the original query -4. Maintain clinical accuracy and clarity -5. Use clear structure with appropriate headings and bullet points -6. Remove redundancy and contradictions -7. Ensure all important points from MedSwin responses are included - -Return the final synthesized answer in Markdown format. Do not add meta-commentary or explanations - just provide the final answer.""" - - system_prompt = "You are a medical answer synthesis supervisor. Create comprehensive, well-structured final answers from multiple specialist responses." 
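The table cleanup at the end of execute_medswin_task above is worth isolating: it drops pipe-delimited rows entirely and re-bullets whatever prose lines remain. The same logic as a small helper (tables_to_bullets is an illustrative name):

def tables_to_bullets(response: str) -> str:
    response = response.strip()
    if "|" not in response or "---" not in response:
        return response  # no table markers; leave the text alone
    lines = [ln.strip() for ln in response.splitlines()
             if ln.strip() and not ln.strip().startswith("|") and "---" not in ln]
    return "\n".join(ln if ln.startswith("-") else f"- {ln}" for ln in lines)

# tables_to_bullets("| Drug | Dose |\n|---|---|\nMetformin is first-line")
# -> "- Metformin is first-line"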
- - result = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, - temperature=0.3 - ) - - return result.strip() - -async def gemini_supervisor_challenge_async(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict: - """ - Gemini Supervisor: Challenge and evaluate the current answer, suggesting improvements - Returns evaluation with suggestions for enhancement - """ - context_info = "" - if rag_contexts: - context_info += f"Document contexts: {len(rag_contexts)} available.\n" - if search_contexts: - context_info += f"Search contexts: {len(search_contexts)} available.\n" - - all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)]) - - prompt = f"""You are a supervisor agent evaluating and challenging a medical answer for quality and completeness. - -Original Query: "{query}" - -Available Context: -{context_info} - -MedSwin Specialist Responses: -{all_answers_text} - -Current Synthesized Answer: -{current_answer[:2000]} - -Evaluate this answer and provide: -1. Completeness: Does it fully address the query? What's missing? -2. Accuracy: Are there any inaccuracies or contradictions? -3. Clarity: Is it well-structured and clear? -4. Context Usage: Are document/search contexts properly utilized? -5. Improvement Suggestions: Specific ways to enhance the answer - -Return ONLY valid JSON: -{{ - "is_optimal": true/false, - "completeness_score": 0-10, - "accuracy_score": 0-10, - "clarity_score": 0-10, - "missing_aspects": ["..."], - "inaccuracies": ["..."], - "improvement_suggestions": ["..."], - "needs_more_context": true/false, - "enhancement_instructions": "specific instructions for improving the answer" -}}""" - - system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback in JSON format. Return ONLY valid JSON." - - response = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, - temperature=0.3 - ) - - try: - json_start = response.find('{') - json_end = response.rfind('}') + 1 - if json_start >= 0 and json_end > json_start: - evaluation = json.loads(response[json_start:json_end]) - logger.info(f"[GEMINI SUPERVISOR] Challenge evaluation: optimal={evaluation.get('is_optimal', False)}, scores={evaluation.get('completeness_score', 'N/A')}/{evaluation.get('accuracy_score', 'N/A')}/{evaluation.get('clarity_score', 'N/A')}") - return evaluation - else: - raise ValueError("Evaluation JSON not found") - except Exception as exc: - logger.error(f"[GEMINI SUPERVISOR] Challenge evaluation parsing failed: {exc}") - return { - "is_optimal": True, - "completeness_score": 7, - "accuracy_score": 7, - "clarity_score": 7, - "missing_aspects": [], - "inaccuracies": [], - "improvement_suggestions": [], - "needs_more_context": False, - "enhancement_instructions": "" - } - -async def gemini_supervisor_enhance_answer_async(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str: - """ - Gemini Supervisor: Enhance the answer based on challenge feedback - """ - context_info = "" - if rag_contexts: - context_info += f"Document contexts: {len(rag_contexts)} available.\n" - if search_contexts: - context_info += f"Search contexts: {len(search_contexts)} available.\n" - - all_answers_text = "\n\n---\n\n".join([f"## {i+1}. 
{ans}" for i, ans in enumerate(medswin_answers)]) - - prompt = f"""You are a supervisor agent enhancing a medical answer based on evaluation feedback. - -Original Query: "{query}" - -Available Context: -{context_info} - -MedSwin Specialist Responses: -{all_answers_text} - -Current Answer (to enhance): -{current_answer} - -Enhancement Instructions: -{enhancement_instructions} - -Create an enhanced version of the answer that: -1. Addresses all improvement suggestions -2. Fills in missing aspects -3. Corrects any inaccuracies -4. Improves clarity and structure -5. Better utilizes available context -6. Maintains all valuable information from the current answer - -Return the enhanced answer in Markdown format. Do not add meta-commentary.""" - - system_prompt = "You are a medical answer enhancement supervisor. Improve answers based on evaluation feedback while maintaining accuracy." - - result = await call_agent( - user_prompt=prompt, - system_prompt=system_prompt, - model=GEMINI_MODEL, - temperature=0.3 - ) - - return result.strip() - -async def gemini_supervisor_check_clarity_async(query: str, answer: str, use_web_search: bool) -> dict: - """ - Gemini Supervisor: Check if answer is unclear or supervisor is unsure (only when search mode enabled) - Returns decision on whether to trigger additional search - """ - if not use_web_search: - # Only check clarity when search mode is enabled - return {"is_unclear": False, "needs_search": False, "search_queries": []} - - prompt = f"""You are a supervisor agent evaluating answer clarity and completeness. - -Query: "{query}" - -Current Answer: -{answer[:1500]} - -Evaluate: -1. Is the answer unclear or incomplete? -2. Are there gaps that web search could fill? -3. Is the supervisor (you) unsure about certain aspects? - -Return ONLY valid JSON: -{{ - "is_unclear": true/false, - "needs_search": true/false, - "uncertainty_areas": ["..."], - "search_queries": ["specific search queries to fill gaps"], - "rationale": "brief explanation" -}} - -Only suggest search if the answer is genuinely unclear or has significant gaps that search could address.""" - - system_prompt = "You are a clarity evaluator. Assess if additional web search is needed. Return ONLY valid JSON." 
-
-    response = await call_agent(
-        user_prompt=prompt,
-        system_prompt=system_prompt,
-        model=GEMINI_MODEL_LITE,
-        temperature=0.2
-    )
-
-    try:
-        json_start = response.find('{')
-        json_end = response.rfind('}') + 1
-        if json_start >= 0 and json_end > json_start:
-            evaluation = json.loads(response[json_start:json_end])
-            logger.info(f"[GEMINI SUPERVISOR] Clarity check: unclear={evaluation.get('is_unclear', False)}, needs_search={evaluation.get('needs_search', False)}")
-            return evaluation
-        else:
-            raise ValueError("Clarity check JSON not found")
-    except Exception as exc:
-        logger.error(f"[GEMINI SUPERVISOR] Clarity check parsing failed: {exc}")
-        return {"is_unclear": False, "needs_search": False, "search_queries": []}
-
-def _run_async(coro, label: str, fallback):
-    """Run a supervisor coroutine from synchronous code, returning `fallback` on failure.
-
-    nest_asyncio only provides apply(), which patches the current loop so that
-    run_until_complete() can be re-entered; the former per-call nest_asyncio.run(...)
-    does not exist and always raised AttributeError, silently forcing the fallback path.
-    """
-    try:
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            # No event loop in this thread (e.g. a worker thread)
-            return asyncio.run(coro)
-        if loop.is_running():
-            import nest_asyncio
-            nest_asyncio.apply()
-        return loop.run_until_complete(coro)
-    except Exception as exc:
-        logger.error(f"[GEMINI SUPERVISOR] {label} failed: {exc}")
-        return fallback
-
-def gemini_supervisor_synthesize(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
-    """Wrapper to synthesize answer synchronously"""
-    if not MCP_AVAILABLE:
-        logger.warning("[GEMINI SUPERVISOR] MCP unavailable for synthesis, using simple concatenation")
-        return "\n\n".join(medswin_answers)
-    return _run_async(
-        gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown),
-        label="Synthesis",
-        fallback="\n\n".join(medswin_answers)
-    )
-
-def _default_challenge_eval() -> dict:
-    """Neutral challenge evaluation used when the supervisor is unavailable or errors out."""
-    return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7,
-            "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [],
-            "needs_more_context": False, "enhancement_instructions": ""}
-
-def gemini_supervisor_challenge(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
-    """Wrapper to challenge answer synchronously"""
-    if not MCP_AVAILABLE:
-        return _default_challenge_eval()
-    return _run_async(
-        gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts),
-        label="Challenge",
-        fallback=_default_challenge_eval()
-    )
-
-def gemini_supervisor_enhance_answer(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
-    """Wrapper to enhance answer synchronously"""
-    if not MCP_AVAILABLE:
-        return current_answer
-    return _run_async(
-        gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts),
-        label="Enhancement",
-        fallback=current_answer
-    )
-
-def gemini_supervisor_check_clarity(query: str, answer: str, use_web_search: bool) -> dict:
-    """Wrapper to check clarity synchronously"""
-    no_search = {"is_unclear": False, "needs_search": False, "search_queries": []}
-    if not MCP_AVAILABLE or not use_web_search:
-        return no_search
-    return _run_async(
-        gemini_supervisor_check_clarity_async(query, answer, use_web_search),
-        label="Clarity check",
-        fallback=no_search
-    )
-
-async def self_reflection_gemini(answer: str, query: str) -> dict:
-    """Self-reflection using Gemini MCP"""
-    reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
-Query: "{query}"
-Answer: "{answer[:1000]}"
-Evaluate:
-1. Completeness: Does it address all aspects of the query?
-2. Accuracy: Is the medical information accurate?
-3. Clarity: Is it clear and well-structured?
-4. Sources: Are sources cited appropriately?
-5. Missing Information: What important information might be missing?
-Respond in JSON:
-{{
-    "completeness_score": 0-10,
-    "accuracy_score": 0-10,
-    "clarity_score": 0-10,
-    "overall_score": 0-10,
-    "missing_aspects": ["..."],
-    "improvement_suggestions": ["..."]
-}}"""
-
-    # Use concise system prompt
-    system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback."
-
-    response = await call_agent(
-        user_prompt=reflection_prompt,
-        system_prompt=system_prompt,
-        model=GEMINI_MODEL,  # Use full model for reflection
-        temperature=0.3
-    )
-
-    try:
-        json_start = response.find('{')
-        json_end = response.rfind('}') + 1
-        if json_start >= 0 and json_end > json_start:
-            reflection = json.loads(response[json_start:json_end])
-        else:
-            reflection = {"overall_score": 7, "improvement_suggestions": []}
-    except Exception:
-        reflection = {"overall_score": 7, "improvement_suggestions": []}
-
-    logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}")
-    return reflection
-
-def self_reflection(answer: str, query: str, reasoning: dict) -> dict:
-    """
-    Self-reflection: Evaluate answer quality and completeness.
-    Returns reflection with quality score and improvement suggestions.
-    """
-    if not MCP_AVAILABLE:
-        logger.warning("Gemini MCP not available for reflection, using fallback")
-        return {"overall_score": 7, "improvement_suggestions": []}
-    return _run_async(
-        self_reflection_gemini(answer, query),
-        label="Reflection",
-        fallback={"overall_score": 7, "improvement_suggestions": []}
-    )
-
-async def parse_document_gemini(file_path: str, file_extension: str) -> str:
-    """Parse document using Gemini MCP"""
-    if not MCP_AVAILABLE:
-        return ""
-
-    try:
-        # Read file and encode to base64
-        with open(file_path, 'rb') as f:
-            file_content = base64.b64encode(f.read()).decode('utf-8')
-
-        # Determine MIME type from file extension
-        mime_type_map = {
-            '.pdf': 'application/pdf',
-            '.doc': 'application/msword',
-            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
-            '.txt': 'text/plain',
-            '.md': 'text/markdown',
-            '.json': 'application/json',
-            '.xml': 'application/xml',
-            '.csv': 'text/csv'
-        }
-        mime_type = mime_type_map.get(file_extension, 'application/octet-stream')
-
-        # Prepare file object for Gemini MCP (use content for base64)
-        files = [{
-            "content": file_content,
-            "type": mime_type
-        }]
-
-        # Use concise system prompt
-        system_prompt = "Extract all text content from the document accurately."
-        user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible."
-
-        result = await call_agent(
-            user_prompt=user_prompt,
-            system_prompt=system_prompt,
-            files=files,
-            model=GEMINI_MODEL_LITE,  # Use lite model for parsing
-            temperature=0.2
-        )
-
-        return result.strip()
-    except Exception as e:
-        logger.error(f"Gemini document parsing error: {e}")
-        return ""
-
-def extract_text_from_document(file):
-    """Extract text from document using Gemini MCP"""
-    file_name = file.name
-    file_extension = os.path.splitext(file_name)[1].lower()
-
-    # Handle text files directly
-    if file_extension == '.txt':
-        text = file.read().decode('utf-8')
-        return text, len(text.split()), None
-
-    # For PDF, Word, and other documents, use Gemini MCP
-    # Save file to temporary location for processing
-    try:
-        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
-            # Write file content to temp file
-            file.seek(0)  # Reset file pointer
-            tmp_file.write(file.read())
-            tmp_file_path = tmp_file.name
-
-        # Use Gemini MCP to parse document
-        if MCP_AVAILABLE:
-            try:
-                text = _run_async(parse_document_gemini(tmp_file_path, file_extension), label="Document parsing", fallback="")
-
-                # Clean up temp file
-                try:
-                    os.unlink(tmp_file_path)
-                except Exception:
-                    pass
-
-                if text:
-                    return text, len(text.split()), None
-                else:
-                    return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP")
-            except Exception as e:
-                logger.error(f"Gemini MCP document parsing error: {e}")
-                # Clean up temp file
-                try:
-                    os.unlink(tmp_file_path)
-                except Exception:
-                    pass
-                return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}")
{file_extension} file: {str(e)}") - else: - # Clean up temp file - try: - os.unlink(tmp_file_path) - except: - pass - return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.") - except Exception as e: - logger.error(f"Error processing document: {e}") - return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}") - -@spaces.GPU(max_duration=120) -def create_or_update_index(files, request: gr.Request): - global global_file_info - - if not files: - return "Please provide files.", "" - - start_time = time.time() - user_id = request.session_hash - save_dir = f"./{user_id}_index" - # Initialize LlamaIndex modules - llm = get_llm_for_rag() - embed_model = get_or_create_embed_model() - Settings.llm = llm - Settings.embed_model = embed_model - file_stats = [] - new_documents = [] - - for file in tqdm(files, desc="Processing files"): - file_basename = os.path.basename(file.name) - text, word_count, error = extract_text_from_document(file) - if error: - logger.error(f"Error processing file {file_basename}: {str(error)}") - file_stats.append({ - "name": file_basename, - "words": 0, - "status": f"error: {str(error)}" - }) - continue - - doc = LlamaDocument( - text=text, - metadata={ - "file_name": file_basename, - "word_count": word_count, - "source": "user_upload" - } - ) - new_documents.append(doc) - - file_stats.append({ - "name": file_basename, - "words": word_count, - "status": "processed" - }) - - global_file_info[file_basename] = { - "word_count": word_count, - "processed_at": time.time() - } - - node_parser = HierarchicalNodeParser.from_defaults( - chunk_sizes=[2048, 512, 128], - chunk_overlap=20 - ) - logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes") - new_nodes = node_parser.get_nodes_from_documents(new_documents) - new_leaf_nodes = get_leaf_nodes(new_nodes) - new_root_nodes = get_root_nodes(new_nodes) - logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)") - - if os.path.exists(save_dir): - logger.info(f"Loading existing index from {save_dir}") - storage_context = StorageContext.from_defaults(persist_dir=save_dir) - index = load_index_from_storage(storage_context, settings=Settings) - docstore = storage_context.docstore - - docstore.add_documents(new_nodes) - for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"): - index.insert_nodes([node]) - - total_docs = len(docstore.docs) - logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files") - else: - logger.info("Creating new index") - docstore = SimpleDocumentStore() - storage_context = StorageContext.from_defaults(docstore=docstore) - docstore.add_documents(new_nodes) - - index = VectorStoreIndex( - new_leaf_nodes, - storage_context=storage_context, - settings=Settings - ) - total_docs = len(new_documents) - logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files") - - index.storage_context.persist(persist_dir=save_dir) - # custom outputs after processing files - file_list_html = "
" - for stat in file_stats: - status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336" - file_list_html += f"
● {stat['name']} - {stat['words']} words
" - file_list_html += "
" - processing_time = time.time() - start_time - stats_output = f"
" - stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds
" - stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)
" - stats_output += f"✓ Total documents in index: {total_docs}
" - stats_output += f"✓ Index saved to: {save_dir}
" - stats_output += "
" - output_container = f"
" - output_container += file_list_html - output_container += stats_output - output_container += "
" - return f"Successfully indexed {len(files)} files.", output_container - -@spaces.GPU(max_duration=120) -def stream_chat( - message: str, - history: list, - system_prompt: str, - temperature: float, - max_new_tokens: int, - top_p: float, - top_k: int, - penalty: float, - retriever_k: int, - merge_threshold: float, - use_rag: bool, - medical_model: str, - use_web_search: bool, - disable_agentic_reasoning: bool, - show_thoughts: bool, - request: gr.Request -): - if not request: - yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], "" - return - - # Set up thought capture handler if show_thoughts is enabled - thought_handler = None - if show_thoughts: - thought_handler = ThoughtCaptureHandler() - thought_handler.setLevel(logging.INFO) - thought_handler.clear() # Start fresh - logger.addHandler(thought_handler) - - session_start = time.time() - soft_timeout = 100 - hard_timeout = 118 # stop slightly before HF max duration (120s) - - def elapsed(): - return time.time() - session_start - - user_id = request.session_hash - index_dir = f"./{user_id}_index" - has_rag_index = os.path.exists(index_dir) - - # ===== MAC ARCHITECTURE: GEMINI SUPERVISOR + MEDSWIN SPECIALIST ===== - # All internal thoughts are logged, only final answer is displayed - - original_lang = detect_language(message) - original_message = message - needs_translation = original_lang != "en" - - pipeline_diagnostics = { - "reasoning": None, - "plan": None, - "strategy_decisions": [], - "stage_metrics": {}, - "search": {"strategies": [], "total_results": 0} - } - - def record_stage(stage_name: str, start_time: float): - pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3) - - translation_stage_start = time.time() - if needs_translation: - logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...") - message = translate_text(message, target_lang="en", source_lang=original_lang) - logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...") - record_stage("translation", translation_stage_start) - - # Determine final modes (respect user settings and availability) - final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning - final_use_web_search = use_web_search and not disable_agentic_reasoning - - plan = None - if not disable_agentic_reasoning: - reasoning_stage_start = time.time() - reasoning = autonomous_reasoning(message, history) - record_stage("autonomous_reasoning", reasoning_stage_start) - pipeline_diagnostics["reasoning"] = reasoning - plan = create_execution_plan(reasoning, message, has_rag_index) - pipeline_diagnostics["plan"] = plan - execution_strategy = autonomous_execution_strategy( - reasoning, plan, final_use_rag, final_use_web_search, has_rag_index - ) - - if final_use_rag and not reasoning.get("requires_rag", True): - final_use_rag = False - pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning") - elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index: - pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available") - - if final_use_web_search and not reasoning.get("requires_web_search", False): - final_use_web_search = False - pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning") - elif not final_use_web_search and reasoning.get("requires_web_search", False): - if not use_web_search: - 
pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request") - else: - pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode") - else: - pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user") - - # ===== STEP 1: GEMINI SUPERVISOR - Break query into sub-topics ===== - if disable_agentic_reasoning: - logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone") - # Simple breakdown for direct mode - breakdown = { - "sub_topics": [ - {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"} - ], - "strategy": "Direct answer", - "exploration_note": "Direct mode - no breakdown" - } - else: - logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...") - breakdown = gemini_supervisor_breakdown(message, final_use_rag, final_use_web_search, elapsed(), max_duration=120) - logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics") - - # ===== STEP 2: GEMINI SUPERVISOR - Handle Search Mode ===== - search_contexts = [] - web_urls = [] - if final_use_web_search: - search_stage_start = time.time() - logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...") - search_strategies = gemini_supervisor_search_strategies(message, elapsed()) - - # Execute searches for each strategy - all_search_results = [] - strategy_jobs = [] - for strategy in search_strategies.get("search_strategies", [])[:4]: # Max 4 strategies - search_query = strategy.get("strategy", message) - target_sources = strategy.get("target_sources", 2) - strategy_jobs.append({ - "query": search_query, - "target_sources": target_sources, - "meta": strategy - }) - - def execute_search(job): - job_start = time.time() - try: - results = search_web(job["query"], max_results=job["target_sources"]) - duration = time.time() - job_start - return results, duration, None - except Exception as exc: - return [], time.time() - job_start, exc - - def record_search_diag(job, duration, results_count, error=None): - entry = { - "query": job["query"], - "target_sources": job["target_sources"], - "duration": round(duration, 3), - "results": results_count - } - if error: - entry["error"] = str(error) - pipeline_diagnostics["search"]["strategies"].append(entry) - - if strategy_jobs: - max_workers = min(len(strategy_jobs), 4) - if len(strategy_jobs) > 1: - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_map = {executor.submit(execute_search, job): job for job in strategy_jobs} - for future in concurrent.futures.as_completed(future_map): - job = future_map[future] - try: - results, duration, error = future.result() - except Exception as exc: - results, duration, error = [], 0.0, exc - record_search_diag(job, duration, len(results), error) - if not error and results: - all_search_results.extend(results) - web_urls.extend([r.get('url', '') for r in results if r.get('url')]) - else: - job = strategy_jobs[0] - results, duration, error = execute_search(job) - record_search_diag(job, duration, len(results), error) - if not error and results: - all_search_results.extend(results) - web_urls.extend([r.get('url', '') for r in results if r.get('url')]) - else: - pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned") - - pipeline_diagnostics["search"]["total_results"] = len(all_search_results) - - # Summarize search results with Gemini - if 
all_search_results: - logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...") - search_summary = summarize_web_content(all_search_results, message) - if search_summary: - search_contexts.append(search_summary) - logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars") - record_stage("web_search", search_stage_start) - - # ===== STEP 3: GEMINI SUPERVISOR - Handle RAG Mode ===== - rag_contexts = [] - if final_use_rag and has_rag_index: - rag_stage_start = time.time() - if elapsed() >= soft_timeout - 10: - logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure") - final_use_rag = False - else: - logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...") - embed_model = get_or_create_embed_model() - Settings.embed_model = embed_model - storage_context = StorageContext.from_defaults(persist_dir=index_dir) - index = load_index_from_storage(storage_context, settings=Settings) - base_retriever = index.as_retriever(similarity_top_k=retriever_k) - auto_merging_retriever = AutoMergingRetriever( - base_retriever, - storage_context=storage_context, - simple_ratio_thresh=merge_threshold, - verbose=False # Reduce logging noise - ) - merged_nodes = auto_merging_retriever.retrieve(message) - retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes]) - logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes") - - # Brainstorm retrieved docs into contexts - logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...") - rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed()) - rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])] - logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts") - record_stage("rag_retrieval", rag_stage_start) - - # ===== STEP 4: MEDSWIN SPECIALIST - Execute tasks sequentially ===== - # Initialize medical model - medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model) - - # Base system prompt for MedSwin (clean, no internal thoughts) - base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables." 
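For reference, the auto-merging retrieval performed above, in isolation: hierarchical parsing produces 2048/512/128-token layers, only leaf nodes are embedded, and AutoMergingRetriever swaps a group of retrieved leaves for their parent node once enough siblings match. A minimal sketch assuming Settings.embed_model is configured as elsewhere in this file, with a placeholder corpus:

from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore

docs = [Document(text="...medical guideline text...")]  # placeholder corpus
parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])
nodes = parser.get_nodes_from_documents(docs)

docstore = SimpleDocumentStore()
docstore.add_documents(nodes)
storage_context = StorageContext.from_defaults(docstore=docstore)
index = VectorStoreIndex(get_leaf_nodes(nodes), storage_context=storage_context)

base_retriever = index.as_retriever(similarity_top_k=15)
retriever = AutoMergingRetriever(base_retriever, storage_context=storage_context,
                                 simple_ratio_thresh=0.5)
merged_nodes = retriever.retrieve("example query")  # leaves merge into parents past the threshold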
- - # Prepare context for MedSwin (combine RAG and search contexts) - combined_context = "" - if rag_contexts: - combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4]) # Max 4 contexts - if search_contexts: - if combined_context: - combined_context += "\n\n" - combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts) - - # Execute MedSwin tasks for each sub-topic - logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...") - medswin_answers = [] - - updated_history = history + [ - {"role": "user", "content": original_message}, - {"role": "assistant", "content": ""} - ] - thoughts_text = thought_handler.get_thoughts() if thought_handler else "" - yield updated_history, thoughts_text - - medswin_stage_start = time.time() - for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1): - if elapsed() >= hard_timeout - 5: - logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}") - break - - task_instruction = sub_topic.get("instruction", "") - topic_name = sub_topic.get("topic", f"Topic {idx}") - priority = sub_topic.get("priority", "medium") - - logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})") - - # Select relevant context for this task (if multiple contexts available) - task_context = combined_context - if len(rag_contexts) > 1 and idx <= len(rag_contexts): - # Use corresponding RAG context if available - task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context - - # Execute MedSwin task (with GPU tag) - try: - task_answer = execute_medswin_task( - medical_model_obj=medical_model_obj, - medical_tokenizer=medical_tokenizer, - task_instruction=task_instruction, - context=task_context if task_context else "", - system_prompt_base=base_system_prompt, - temperature=temperature, - max_new_tokens=min(max_new_tokens, 800), # Limit per task - top_p=top_p, - top_k=top_k, - penalty=penalty - ) - - # Format task answer with topic header - formatted_answer = f"## {topic_name}\n\n{task_answer}" - medswin_answers.append(formatted_answer) - logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars") - - # Stream partial answer as we complete each task - partial_final = "\n\n".join(medswin_answers) - updated_history[-1]["content"] = partial_final - thoughts_text = thought_handler.get_thoughts() if thought_handler else "" - yield updated_history, thoughts_text - - except Exception as e: - logger.error(f"[MEDSWIN] Task {idx} failed: {e}") - # Continue with next task - continue - record_stage("medswin_tasks", medswin_stage_start) - - # ===== STEP 5: GEMINI SUPERVISOR - Synthesize final answer with clear context ===== - logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...") - raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers] # Remove headers for synthesis - synthesis_stage_start = time.time() - final_answer = gemini_supervisor_synthesize(message, raw_medswin_answers, rag_contexts, search_contexts, breakdown) - record_stage("synthesis", synthesis_stage_start) - - if not final_answer or len(final_answer.strip()) < 50: - # Fallback to simple concatenation if synthesis fails - logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation") - final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response." 
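The challenge-and-enhance loop that follows (Step 6) boils down to a small control pattern: ask a critic for a verdict, apply its instructions, and stop when the critic is satisfied or an enhancement stops adding substance. A generic sketch, where challenge and enhance stand in for gemini_supervisor_challenge and gemini_supervisor_enhance_answer:

def refine(answer: str, challenge, enhance, max_iters: int = 2) -> str:
    """Iteratively improve an answer until the critic approves or improvements stall."""
    for _ in range(max_iters):
        verdict = challenge(answer)
        if verdict.get("is_optimal") or not verdict.get("enhancement_instructions"):
            break
        candidate = enhance(answer, verdict["enhancement_instructions"])
        if not candidate or len(candidate.strip()) <= len(answer.strip()) * 0.8:
            break  # reject suspiciously short rewrites (possible truncation)
        answer = candidate
    return answer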
- - # Clean final answer - ensure no tables, only Markdown bullets - if "|" in final_answer and "---" in final_answer: - logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets") - lines = final_answer.split('\n') - cleaned_lines = [] - for line in lines: - if '|' in line and '---' not in line: - # Convert table row to bullet points - cells = [cell.strip() for cell in line.split('|') if cell.strip()] - if cells: - cleaned_lines.append(f"- {' / '.join(cells)}") - elif '---' not in line: - cleaned_lines.append(line) - final_answer = '\n'.join(cleaned_lines) - - # ===== STEP 6: GEMINI SUPERVISOR - Challenge and enhance answer iteratively ===== - max_challenge_iterations = 2 # Limit iterations to avoid timeout - challenge_iteration = 0 - challenge_stage_start = time.time() - - while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15: - challenge_iteration += 1 - logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...") - - evaluation = gemini_supervisor_challenge(message, final_answer, raw_medswin_answers, rag_contexts, search_contexts) - - if evaluation.get("is_optimal", False): - logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)") - break - - enhancement_instructions = evaluation.get("enhancement_instructions", "") - if not enhancement_instructions: - logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal") - break - - logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...") - enhanced_answer = gemini_supervisor_enhance_answer( - message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts - ) - - if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8: # Ensure enhancement is substantial - final_answer = enhanced_answer - logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)") - else: - logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping") - break - record_stage("challenge_loop", challenge_stage_start) - - # ===== STEP 7: Conditional search trigger (only when search mode enabled) ===== - if final_use_web_search and elapsed() < soft_timeout - 10: - logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...") - clarity_stage_start = time.time() - clarity_check = gemini_supervisor_check_clarity(message, final_answer, final_use_web_search) - record_stage("clarity_check", clarity_stage_start) - - if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"): - logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}") - additional_search_results = [] - followup_stage_start = time.time() - for search_query in clarity_check.get("search_queries", [])[:3]: # Limit to 3 additional searches - if elapsed() >= soft_timeout - 5: - break - extra_start = time.time() - results = search_web(search_query, max_results=2) - extra_duration = time.time() - extra_start - pipeline_diagnostics["search"]["strategies"].append({ - "query": search_query, - "target_sources": 2, - "duration": round(extra_duration, 3), - "results": len(results), - "type": "followup" - }) - additional_search_results.extend(results) - web_urls.extend([r.get('url', '') for r in results if r.get('url')]) - - if additional_search_results: - pipeline_diagnostics["search"]["total_results"] += 
len(additional_search_results) - logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...") - additional_summary = summarize_web_content(additional_search_results, message) - if additional_summary: - # Enhance answer with additional search context - search_contexts.append(additional_summary) - logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...") - enhanced_with_search = gemini_supervisor_enhance_answer( - message, final_answer, - f"Incorporate the following additional information from web search: {additional_summary}", - raw_medswin_answers, rag_contexts, search_contexts - ) - if enhanced_with_search and len(enhanced_with_search.strip()) > 50: - final_answer = enhanced_with_search - logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context") - record_stage("followup_search", followup_stage_start) - - citations_text = "" - - # ===== STEP 8: Finalize answer (translate, add citations, format) ===== - # Translate back if needed - if needs_translation and final_answer: - logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...") - final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en") - - # Add citations if web sources were used - if web_urls: - unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates - citation_links = [] - for url in unique_urls[:5]: # Limit to 5 citations - domain = format_url_as_domain(url) - if domain: - citation_links.append(f"[{domain}]({url})") - - if citation_links: - citations_text = "\n\n**Sources:** " + ", ".join(citation_links) - - # Add speaker icon - speaker_icon = ' 🔊' - final_answer_with_metadata = final_answer + citations_text + speaker_icon - - # Update history with final answer (ONLY final answer, no internal thoughts) - updated_history[-1]["content"] = final_answer_with_metadata - thoughts_text = thought_handler.get_thoughts() if thought_handler else "" - yield updated_history, thoughts_text - - # Clean up thought handler - if thought_handler: - logger.removeHandler(thought_handler) - - # Log completion - diag_summary = { - "stage_metrics": pipeline_diagnostics["stage_metrics"], - "decisions": pipeline_diagnostics["strategy_decisions"], - "search": pipeline_diagnostics["search"], - } - try: - logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}") - except Exception: - logger.info(f"[MAC] Diagnostics summary (non-serializable)") - logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed") - -def generate_speech_for_message(text: str): - """Generate speech for a message and return audio file""" - audio_path = generate_speech(text) - if audio_path: - return audio_path - return None - -def create_demo(): - with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo: - gr.HTML(TITLE) - gr.HTML(DESCRIPTION) - - with gr.Row(elem_classes="main-container"): - with gr.Column(elem_classes="upload-section"): - file_upload = gr.File( - file_count="multiple", - label="Drag and Drop Files Here", - file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"], - elem_id="file-upload" - ) - upload_button = gr.Button("Upload & Index", elem_classes="upload-button") - status_output = gr.Textbox( - label="Status", - placeholder="Upload files to start...", - interactive=False - ) - file_info_output = gr.HTML( - label="File Information", - elem_classes="processing-info" - ) - 
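Review note on the citation block in stream_chat above: the dedupe-and-link logic is compact enough to live in one helper. A self-contained equivalent; the helper name and the urlparse-based domain extraction are mine, while the file's own format_url_as_domain is defined elsewhere:

from urllib.parse import urlparse

def format_citations(urls: list, limit: int = 5) -> str:
    """Order-preserving dedupe, then render URLs as Markdown domain links."""
    links = []
    for url in dict.fromkeys(urls):  # drops duplicates, keeps first-seen order
        domain = urlparse(url).netloc.removeprefix("www.")
        if domain:
            links.append(f"[{domain}]({url})")
        if len(links) == limit:
            break
    return ("\n\n**Sources:** " + ", ".join(links)) if links else ""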
upload_button.click(
-                fn=create_or_update_index,
-                inputs=[file_upload],
-                outputs=[status_output, file_info_output]
-            )
-
-        with gr.Column(elem_classes="chatbot-container"):
-            chatbot = gr.Chatbot(
-                height=500,
-                placeholder="Chat with MedSwin... Type your question below.",
-                show_label=False,
-                type="messages"
-            )
-            with gr.Row(elem_classes="input-row"):
-                message_input = gr.Textbox(
-                    placeholder="Type your medical question here...",
-                    show_label=False,
-                    container=False,
-                    lines=1,
-                    scale=10
-                )
-                mic_button = gr.Audio(
-                    sources=["microphone"],
-                    type="filepath",
-                    label="",
-                    show_label=False,
-                    container=False,
-                    scale=1
-                )
-                submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
-
-            # Timer display for recording (shown below input row)
-            recording_timer = gr.Textbox(
-                value="",
-                label="",
-                show_label=False,
-                interactive=False,
-                visible=False,
-                container=False,
-                elem_classes="recording-timer"
-            )
-
-            # Handle microphone transcription
-            recording_start_time = [None]  # mutable cell shared by the recording callbacks
-
-            def handle_recording_start():
-                """Called when recording starts"""
-                recording_start_time[0] = time.time()
-                return gr.update(visible=True, value="Recording... 0s")
-
-            def handle_recording_stop(audio):
-                """Called when recording stops"""
-                recording_start_time[0] = None
-                if audio is None:
-                    return gr.update(visible=False, value=""), ""
-                transcribed = transcribe_audio(audio)
-                return gr.update(visible=False, value=""), transcribed
-
-            # Note: the timer label is static ("Recording... 0s"); live updates
-            # would need a gr.Timer or custom JavaScript
-            mic_button.start_recording(
-                fn=handle_recording_start,
-                outputs=[recording_timer]
-            )
-
-            mic_button.stop_recording(
-                fn=handle_recording_stop,
-                inputs=[mic_button],
-                outputs=[recording_timer, message_input]
-            )
-
-            # TTS component for generating speech from messages
-            with gr.Row(visible=False) as tts_row:
-                tts_text = gr.Textbox(visible=False)
-                tts_audio = gr.Audio(label="Generated Speech", visible=False)
-
-            # Function to generate speech when speaker icon is clicked
-            def generate_speech_from_chat(history):
-                """Extract last assistant message and generate speech"""
-                if not history or len(history) == 0:
-                    return None
-                last_msg = history[-1]
-                if last_msg.get("role") == "assistant":
-                    text = last_msg.get("content", "").replace(" 🔊", "").strip()
-                    if text:
-                        audio_path = generate_speech(text)
-                        return audio_path
-                return None
-
-            # Add TTS button that appears when assistant responds
-            tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
-
-            # Update TTS button visibility and generate speech
-            def update_tts_button(history):
-                if history and len(history) > 0 and history[-1].get("role") == "assistant":
-                    return gr.update(visible=True)
-                return gr.update(visible=False)
-
-            chatbot.change(
-                fn=update_tts_button,
-                inputs=[chatbot],
-                outputs=[tts_button]
-            )
-
-            tts_button.click(
-                fn=generate_speech_from_chat,
-                inputs=[chatbot],
-                outputs=[tts_audio]
-            )
-
-        with gr.Accordion("⚙️ Advanced Settings", open=False):
-            with gr.Row():
-                disable_agentic_reasoning = gr.Checkbox(
-                    value=False,
-                    label="Disable agentic reasoning",
-                    info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
-                )
-                show_agentic_thought = gr.Button(
-                    "Show agentic thought",
-                    size="sm"
-                )
-            # Scrollable textbox for agentic thoughts (initially hidden)
-            agentic_thoughts_box = gr.Textbox(
-                label="Agentic Thoughts",
-                placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
-                lines=8,
-                max_lines=15,
-                interactive=False,
-                visible=False,
-
elem_classes="agentic-thoughts" - ) - with gr.Row(): - use_rag = gr.Checkbox( - value=False, - label="Enable Document RAG", - info="Answer based on uploaded documents (upload required)" - ) - use_web_search = gr.Checkbox( - value=False, - label="Enable Web Search (MCP)", - info="Fetch knowledge from online medical resources" - ) - - medical_model = gr.Radio( - choices=list(MEDSWIN_MODELS.keys()), - value=DEFAULT_MEDICAL_MODEL, - label="Medical Model", - info="MedSwin TA (default), others download on first use" - ) - - system_prompt = gr.Textbox( - value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.", - label="System Prompt", - lines=3 - ) - - with gr.Tab("Generation Parameters"): - temperature = gr.Slider( - minimum=0, - maximum=1, - step=0.1, - value=0.2, - label="Temperature" - ) - max_new_tokens = gr.Slider( - minimum=512, - maximum=4096, - step=128, - value=2048, - label="Max New Tokens", - info="Increased for medical models to prevent early stopping" - ) - top_p = gr.Slider( - minimum=0.0, - maximum=1.0, - step=0.1, - value=0.7, - label="Top P" - ) - top_k = gr.Slider( - minimum=1, - maximum=100, - step=1, - value=50, - label="Top K" - ) - penalty = gr.Slider( - minimum=0.0, - maximum=2.0, - step=0.1, - value=1.2, - label="Repetition Penalty" - ) - - with gr.Tab("Retrieval Parameters"): - retriever_k = gr.Slider( - minimum=5, - maximum=30, - step=1, - value=15, - label="Initial Retrieval Size (Top K)" - ) - merge_threshold = gr.Slider( - minimum=0.1, - maximum=0.9, - step=0.1, - value=0.5, - label="Merge Threshold (lower = more merging)" - ) - - # Toggle function for showing/hiding agentic thoughts - show_thoughts_state = gr.State(value=False) - - def toggle_thoughts_box(current_state): - """Toggle visibility of agentic thoughts box""" - new_state = not current_state - return gr.update(visible=new_state), new_state - - show_agentic_thought.click( - fn=toggle_thoughts_box, - inputs=[show_thoughts_state], - outputs=[agentic_thoughts_box, show_thoughts_state] - ) - - submit_button.click( - fn=stream_chat, - inputs=[ - message_input, - chatbot, - system_prompt, - temperature, - max_new_tokens, - top_p, - top_k, - penalty, - retriever_k, - merge_threshold, - use_rag, - medical_model, - use_web_search, - disable_agentic_reasoning, - show_thoughts_state - ], - outputs=[chatbot, agentic_thoughts_box] - ) - - message_input.submit( - fn=stream_chat, - inputs=[ - message_input, - chatbot, - system_prompt, - temperature, - max_new_tokens, - top_p, - top_k, - penalty, - retriever_k, - merge_threshold, - use_rag, - medical_model, - use_web_search, - disable_agentic_reasoning, - show_thoughts_state - ], - outputs=[chatbot, agentic_thoughts_box] - ) - - return demo +from logger import logger +from config import DEFAULT_MEDICAL_MODEL +import config +from models import initialize_medical_model, initialize_tts_model +from mcp import MCP_AVAILABLE +from ui import create_demo if __name__ == "__main__": # Preload models on startup @@ -2973,7 +17,7 @@ if __name__ == "__main__": logger.info("Preloading TTS model...") try: initialize_tts_model() - if global_tts_model is not None: + if config.global_tts_model is not None: logger.info("TTS model preloaded successfully!") else: logger.warning("TTS model not available - will use MCP or disable voice generation") @@ -2989,4 +33,4 @@ if __name__ == "__main__": logger.info("Model preloading complete!") demo = 
create_demo() - demo.launch() \ No newline at end of file + demo.launch()
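A closing note on the UI wiring: stream_chat is a generator, and Gradio streams each yielded (history, thoughts) pair into the chatbot and the thoughts textbox as it arrives. The same contract in a minimal, self-contained form; toy names throughout, assuming a Gradio version with message-format chatbots:

import gradio as gr

def toy_stream(message, history):
    """Toy stand-in for stream_chat: stream partial answers plus a side-channel string."""
    history = history + [{"role": "user", "content": message},
                         {"role": "assistant", "content": ""}]
    for i in range(3):
        history[-1]["content"] += f"chunk {i} "
        yield history, f"internal thought {i}"

with gr.Blocks() as toy_demo:
    chat = gr.Chatbot(type="messages")
    thoughts = gr.Textbox(label="Thoughts")
    box = gr.Textbox()
    box.submit(toy_stream, inputs=[box, chat], outputs=[chat, thoughts])

if __name__ == "__main__":
    toy_demo.launch()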