diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,2969 +1,13 @@
-import gradio as gr
+"""Main entry point for MedLLM Agent"""
import os
-import base64
-import logging
-import torch
-import threading
-import time
-import json
-import concurrent.futures
-from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
- TextIteratorStreamer,
- StoppingCriteria,
- StoppingCriteriaList,
-)
-from transformers import logging as hf_logging
-import spaces
-from llama_index.core import (
- StorageContext,
- VectorStoreIndex,
- load_index_from_storage,
- Document as LlamaDocument,
-)
-from llama_index.core import Settings
-from llama_index.core.node_parser import (
- HierarchicalNodeParser,
- get_leaf_nodes,
- get_root_nodes,
-)
-from llama_index.core.retrievers import AutoMergingRetriever
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.llms.huggingface import HuggingFaceLLM
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-from tqdm import tqdm
-from langdetect import detect, LangDetectException
os.environ["TOKENIZERS_PARALLELISM"] = "false"
-# Set logging to INFO level for cleaner output
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-# Custom logger handler to capture agentic thoughts
-class ThoughtCaptureHandler(logging.Handler):
- """Custom handler to capture internal thoughts from MedSwin and supervisor"""
- def __init__(self):
- super().__init__()
- self.thoughts = []
- self.lock = threading.Lock()
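- # emit() can fire from several worker threads at once; the lock guards self.thoughts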
-
- def emit(self, record):
- """Capture log messages that contain agentic thoughts"""
- try:
- msg = self.format(record)
- # Only capture messages tagged [GEMINI SUPERVISOR], [MEDSWIN], or [MAC]
- if "[GEMINI SUPERVISOR]" in msg or "[MEDSWIN]" in msg or "[MAC]" in msg:
- # Remove timestamp and logger name for cleaner display
- # Format: "timestamp - logger - level - message"
- parts = msg.split(" - ", 3)
- if len(parts) >= 4:
- clean_msg = parts[-1] # Get the message part
- else:
- clean_msg = msg
- with self.lock:
- self.thoughts.append(clean_msg)
- except Exception:
- pass # Ignore formatting errors
-
- def get_thoughts(self):
- """Get all captured thoughts as a formatted string"""
- with self.lock:
- return "\n".join(self.thoughts)
-
- def clear(self):
- """Clear captured thoughts"""
- with self.lock:
- self.thoughts = []
-# Set MCP client logging to WARNING to reduce noise
-mcp_client_logger = logging.getLogger("mcp.client")
-mcp_client_logger.setLevel(logging.WARNING)
-hf_logging.set_verbosity_error()
-
-# MCP imports
-MCP_CLIENT_INFO = None
-try:
- from mcp import ClientSession, StdioServerParameters
- from mcp import types as mcp_types
- from mcp.client.stdio import stdio_client
- import asyncio
- try:
- import nest_asyncio
- nest_asyncio.apply() # Allow nested event loops
- except ImportError:
- pass # nest_asyncio is optional
- MCP_AVAILABLE = True
- MCP_CLIENT_INFO = mcp_types.Implementation(
- name="MedLLM-Agent",
- version=os.environ.get("SPACE_VERSION", "local"),
- )
-except ImportError as e:
- logger.warning(f"MCP SDK not available: {e}")
- MCP_AVAILABLE = False
- # Direct web-search imports (also used as the fallback when MCP is unavailable)
-from ddgs import DDGS
-import requests
-from bs4 import BeautifulSoup
-try:
- from TTS.api import TTS
- TTS_AVAILABLE = True
-except ImportError:
- TTS_AVAILABLE = False
- TTS = None
-import numpy as np
-import soundfile as sf
-import tempfile
-
-# Model configurations
-MEDSWIN_MODELS = {
- "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
- "MedSwin KD": "MedSwin/MedSwin-7B-KD",
- "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
-}
-DEFAULT_MEDICAL_MODEL = "MedSwin TA"
-EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
-TTS_MODEL = "maya-research/maya1"
-HF_TOKEN = os.environ.get("HF_TOKEN")
-if not HF_TOKEN:
- raise ValueError("HF_TOKEN not found in environment variables")
-
-# Gemini MCP configuration
-GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
-GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash") # Default for harder tasks
-GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite") # For parsing and simple tasks
-
-# Custom UI
-TITLE = "
đŠē MedLLM Agent - Medical RAG & Web Search System
"
-DESCRIPTION = """
-
-Advanced Medical AI Assistant powered by MedSwin models
-đ Document RAG: Answers based on uploaded medical documents
-đ Web Search: Fetches knowledge from reliable online medical sources
-đ Multi-language: Automatic translation for non-English queries
-Upload PDF or text files to get started!
-
-"""
-CSS = """
-.upload-section {
- max-width: 400px;
- margin: 0 auto;
- padding: 10px;
- border: 2px dashed #ccc;
- border-radius: 10px;
-}
-.upload-button {
- background: #34c759 !important;
- color: white !important;
- border-radius: 25px !important;
-}
-.chatbot-container {
- margin-top: 20px;
-}
-.status-output {
- margin-top: 10px;
- font-size: 14px;
-}
-.processing-info {
- margin-top: 5px;
- font-size: 12px;
- color: #666;
-}
-.info-container {
- margin-top: 10px;
- padding: 10px;
- border-radius: 5px;
-}
-.file-list {
- margin-top: 0;
- max-height: 200px;
- overflow-y: auto;
- padding: 5px;
- border: 1px solid #eee;
- border-radius: 5px;
-}
-.stats-box {
- margin-top: 10px;
- padding: 10px;
- border-radius: 5px;
- font-size: 12px;
-}
-.submit-btn {
- background: #1a73e8 !important;
- color: white !important;
- border-radius: 25px !important;
- margin-left: 10px;
- padding: 5px 10px;
- font-size: 16px;
-}
-.input-row {
- display: flex;
- align-items: center;
-}
-.recording-timer {
- font-size: 12px;
- color: #666;
- text-align: center;
- margin-top: 5px;
-}
-.feature-badge {
- display: inline-block;
- padding: 3px 8px;
- margin: 2px;
- border-radius: 12px;
- font-size: 11px;
- font-weight: bold;
-}
-.badge-rag {
- background: #e3f2fd;
- color: #1976d2;
-}
-.badge-web {
- background: #f3e5f5;
- color: #7b1fa2;
-}
-@media (min-width: 768px) {
- .main-container {
- display: flex;
- justify-content: space-between;
- gap: 20px;
- }
- .upload-section {
- flex: 1;
- max-width: 300px;
- }
- .chatbot-container {
- flex: 2;
- margin-top: 0;
- }
-}
-"""
-
-# Global model storage
-global_medical_models = {}
-global_medical_tokenizers = {}
-global_file_info = {}
-global_tts_model = None
-global_embed_model = None
-
-# MCP client storage
-global_mcp_session = None
-global_mcp_stdio_ctx = None # Store stdio context to keep it alive
-global_mcp_lock = threading.Lock() # Lock for thread-safe session access
-# MCP server configuration via environment variables
-# Gemini MCP server: Python-based server (agent.py)
-# This works on Hugging Face Spaces without requiring npm/Node.js
-# Make sure GEMINI_API_KEY is set in environment variables
-#
-# Default configuration uses the bundled agent.py script
-# To override:
-# export MCP_SERVER_COMMAND="python"
-# export MCP_SERVER_ARGS="/path/to/agent.py"
-script_dir = os.path.dirname(os.path.abspath(__file__))
-agent_path = os.path.join(script_dir, "agent.py")
-MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python")
-MCP_SERVER_ARGS = os.environ["MCP_SERVER_ARGS"].split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path]
-
-async def get_mcp_session():
- """Get or create MCP client session with proper context management"""
- global global_mcp_session, global_mcp_stdio_ctx
-
- if not MCP_AVAILABLE:
- logger.warning("MCP not available - SDK not installed")
- return None
-
- # Check if session exists and is still valid
- if global_mcp_session is not None:
- # Trust that existing session is valid - verify only when actually using it
- return global_mcp_session
-
- # Create new session using correct MCP SDK pattern
- try:
- # Prepare environment variables for MCP server
- mcp_env = os.environ.copy()
- if GEMINI_API_KEY:
- mcp_env["GEMINI_API_KEY"] = GEMINI_API_KEY
- else:
- logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")
-
- # Add other Gemini MCP configuration if set
- if os.environ.get("GEMINI_MODEL"):
- mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
- if os.environ.get("GEMINI_TIMEOUT"):
- mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
- if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
- mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
- if os.environ.get("GEMINI_TEMPERATURE"):
- mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")
-
- logger.info("Creating MCP client session...")
-
- server_params = StdioServerParameters(
- command=MCP_SERVER_COMMAND,
- args=MCP_SERVER_ARGS,
- env=mcp_env
- )
-
- # Correct MCP SDK usage: stdio_client is an async context manager
- # that yields (read, write) streams
- stdio_ctx = stdio_client(server_params)
- read, write = await stdio_ctx.__aenter__()
-
- # Create ClientSession from the streams
- session = ClientSession(
- read,
- write,
- client_info=MCP_CLIENT_INFO,
- )
-
- # Initialize the session: __aenter__() starts the session's background message
- # handling, and initialize() performs the MCP handshake (initialize request,
- # server response, initialized notification). Per the MCP spec, no other
- # request (e.g. list_tools) may be sent before initialization completes.
- try:
- await session.__aenter__()
- init_result = await session.initialize()
- server_info = getattr(init_result, "serverInfo", None)
- server_name = getattr(server_info, "name", "unknown")
- server_version = getattr(server_info, "version", "unknown")
- logger.info(f"â
MCP session initialized (server={server_name} v{server_version})")
- except Exception as e:
- error_msg = str(e)
- error_type = type(e).__name__
- logger.error(f"â MCP session initialization failed: {error_type}: {error_msg}")
-
- # Clean up and return None
- try:
- await session.__aexit__(None, None, None)
- except Exception:
- pass
- try:
- await stdio_ctx.__aexit__(None, None, None)
- except Exception:
- pass
- return None
-
- # Store both the session and stdio context to keep them alive
- global_mcp_session = session
- global_mcp_stdio_ctx = stdio_ctx
- logger.info("â
MCP client session created successfully")
- return session
- except Exception as e:
- error_type = type(e).__name__
- error_msg = str(e)
- logger.error(f"â Failed to create MCP client session: {error_type}: {error_msg}")
- global_mcp_session = None
- global_mcp_stdio_ctx = None
- return None
-
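-# Tool metadata changes rarely; cache list_tools results briefly to avoid a round-trip on every call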
-MCP_TOOLS_CACHE_TTL = int(os.environ.get("MCP_TOOLS_CACHE_TTL", "60"))
-global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
-
-
-def invalidate_mcp_tools_cache():
- """Invalidate cached MCP tool metadata"""
- global global_mcp_tools_cache
- global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
-
-
-async def get_cached_mcp_tools(force_refresh: bool = False):
- """Return cached MCP tools list to avoid repeated list_tools calls"""
- global global_mcp_tools_cache
- if not MCP_AVAILABLE:
- return []
-
- now = time.time()
- if (
- not force_refresh
- and global_mcp_tools_cache["tools"]
- and now - global_mcp_tools_cache["timestamp"] < MCP_TOOLS_CACHE_TTL
- ):
- return global_mcp_tools_cache["tools"]
-
- session = await get_mcp_session()
- if session is None:
- return []
-
- try:
- tools_resp = await session.list_tools()
- tools_list = list(getattr(tools_resp, "tools", []) or [])
- global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
- return tools_list
- except Exception as e:
- logger.error(f"Failed to refresh MCP tools: {e}")
- invalidate_mcp_tools_cache()
- return []
-
-
-async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
- """Call Gemini MCP generate_content tool"""
- if not MCP_AVAILABLE:
- logger.warning("MCP not available for Gemini call")
- return ""
-
- try:
- session = await get_mcp_session()
- if session is None:
- logger.warning("Failed to get MCP session for Gemini call")
- return ""
-
- tools = await get_cached_mcp_tools()
- if not tools:
- tools = await get_cached_mcp_tools(force_refresh=True)
- if not tools:
- logger.error("Unable to obtain MCP tool catalog for Gemini calls")
- return ""
-
- generate_tool = None
- for tool in tools:
- if tool.name == "generate_content" or "generate_content" in tool.name.lower():
- generate_tool = tool
- logger.info(f"Found Gemini MCP tool: {tool.name}")
- break
-
- if not generate_tool:
- logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
- invalidate_mcp_tools_cache()
- return ""
-
- # Prepare arguments
- arguments = {
- "user_prompt": user_prompt
- }
- if system_prompt:
- arguments["system_prompt"] = system_prompt
- if files:
- arguments["files"] = files
- if model:
- arguments["model"] = model
- if temperature is not None:
- arguments["temperature"] = temperature
-
- result = await session.call_tool(generate_tool.name, arguments=arguments)
-
- # Parse result
- if hasattr(result, 'content') and result.content:
- for item in result.content:
- if hasattr(item, 'text'):
- response_text = item.text.strip()
- return response_text
- logger.warning("â ī¸ Gemini MCP returned empty or invalid result")
- return ""
- except Exception as e:
- logger.error(f"Gemini MCP call error: {e}")
- return ""
-
-def initialize_medical_model(model_name: str):
- """Initialize medical model (MedSwin) - download on demand"""
- global global_medical_models, global_medical_tokenizers
- if model_name not in global_medical_models or global_medical_models[model_name] is None:
- logger.info(f"Initializing medical model: {model_name}...")
- model_path = MEDSWIN_MODELS[model_name]
- tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
- model = AutoModelForCausalLM.from_pretrained(
- model_path,
- device_map="auto",
- trust_remote_code=True,
- token=HF_TOKEN,
- torch_dtype=torch.float16
- )
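- # device_map="auto" lets Accelerate place layers across available devices; float16 halves memory use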
- global_medical_models[model_name] = model
- global_medical_tokenizers[model_name] = tokenizer
- logger.info(f"Medical model {model_name} initialized successfully")
- return global_medical_models[model_name], global_medical_tokenizers[model_name]
-
-
-def initialize_tts_model():
- """Initialize TTS model for text-to-speech"""
- global global_tts_model
- if not TTS_AVAILABLE:
- logger.warning("TTS library not installed. TTS features will be disabled.")
- return None
- if global_tts_model is None:
- try:
- logger.info("Initializing TTS model for voice generation...")
- global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False)
- logger.info("TTS model initialized successfully")
- except Exception as e:
- logger.warning(f"TTS model initialization failed: {e}")
- logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
- global_tts_model = None
- return global_tts_model
-
-
-def get_or_create_embed_model():
- """Reuse embedding model to avoid reloading weights each request"""
- global global_embed_model
- if global_embed_model is None:
- logger.info("Initializing shared embedding model for RAG retrieval...")
- global_embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
- return global_embed_model
-
-async def transcribe_audio_gemini(audio_path: str) -> str:
- """Transcribe audio using Gemini MCP"""
- if not MCP_AVAILABLE:
- return ""
-
- try:
- # Ensure we have an absolute path
- audio_path_abs = os.path.abspath(audio_path)
-
- # Prepare file object for Gemini MCP using path (as per Gemini MCP documentation)
- files = [{
- "path": audio_path_abs
- }]
-
- # Use exact prompts from Gemini MCP documentation
- system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
- user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
-
- result = await call_agent(
- user_prompt=user_prompt,
- system_prompt=system_prompt,
- files=files,
- model=GEMINI_MODEL_LITE, # Use lite model for transcription
- temperature=0.2
- )
-
- return result.strip()
- except Exception as e:
- logger.error(f"Gemini transcription error: {e}")
- return ""
-
-def transcribe_audio(audio):
- """Transcribe audio to text using Gemini MCP"""
- if audio is None:
- return ""
-
- try:
- # Handle file path (Gradio Audio component returns file path)
- if isinstance(audio, str):
- audio_path = audio
- elif isinstance(audio, tuple):
- # Handle tuple format (sample_rate, audio_data)
- sample_rate, audio_data = audio
- # Save to temp file
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
- sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
- audio_path = tmp_file.name
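- # delete=False keeps the WAV on disk after the context exits so Gemini can read it by path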
- else:
- audio_path = audio
-
- # Use Gemini MCP for transcription
- if MCP_AVAILABLE:
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()  # idempotent: patches asyncio.run to allow nested loops
- transcribed = asyncio.run(transcribe_audio_gemini(audio_path))
- if transcribed:
- logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
- return transcribed
- except Exception as e:
- logger.error(f"Error in nested async transcription: {e}")
- else:
- transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
- if transcribed:
- logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
- return transcribed
- except Exception as e:
- logger.error(f"Gemini MCP transcription error: {e}")
-
- logger.warning("Gemini MCP transcription not available")
- return ""
- except Exception as e:
- logger.error(f"Transcription error: {e}")
- return ""
-
-async def generate_speech_mcp(text: str) -> str:
- """Generate speech using MCP TTS tool"""
- if not MCP_AVAILABLE:
- return None
-
- try:
- # Get MCP session
- session = await get_mcp_session()
- if session is None:
- return None
-
- # Find TTS tool
- tools = await session.list_tools()
- tts_tool = None
- for tool in tools.tools:
- if "tts" in tool.name.lower() or "speech" in tool.name.lower() or "synthesize" in tool.name.lower():
- tts_tool = tool
- logger.info(f"Found MCP TTS tool: {tool.name}")
- break
-
- if tts_tool:
- result = await session.call_tool(
- tts_tool.name,
- arguments={"text": text, "language": "en"}
- )
-
- # Parse result - MCP might return audio data or file path
- if hasattr(result, 'content') and result.content:
- for item in result.content:
- if hasattr(item, 'text'):
- # If it's a file path
- if os.path.exists(item.text):
- return item.text
- elif hasattr(item, 'data') and item.data:
- # If it's binary audio data, save it
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
- tmp_file.write(item.data)
- return tmp_file.name
- return None
- except Exception as e:
- logger.warning(f"MCP TTS error: {e}")
- return None
-
-def generate_speech(text: str):
- """Generate speech from text using TTS model (with MCP fallback)"""
- if not text or len(text.strip()) == 0:
- return None
-
- # Try MCP first if available
- if MCP_AVAILABLE:
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- audio_path = asyncio.run(generate_speech_mcp(text))
- if audio_path:
- logger.info("Generated speech via MCP")
- return audio_path
- except Exception:
- pass  # MCP TTS failed; fall back to local TTS
- else:
- audio_path = loop.run_until_complete(generate_speech_mcp(text))
- if audio_path:
- return audio_path
- except Exception:
- pass  # MCP TTS not available; fall back to local
-
- # Fallback to local TTS model
- if not TTS_AVAILABLE:
- logger.error("TTS library not installed. Please install TTS to use voice generation.")
- return None
-
- global global_tts_model
- if global_tts_model is None:
- initialize_tts_model()
-
- if global_tts_model is None:
- logger.error("TTS model not available. Please check dependencies.")
- return None
-
- try:
- # Generate audio
- wav = global_tts_model.tts(text)
-
- # Save to temporary file
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
- sf.write(tmp_file.name, wav, samplerate=22050)
- return tmp_file.name
- except Exception as e:
- logger.error(f"TTS error: {e}")
- return None
-
-def format_prompt_manually(messages: list, tokenizer) -> str:
- """Manually format prompt for models without chat template"""
- prompt_parts = []
-
- # Combine system and user messages into a single instruction
- system_content = ""
- user_content = ""
-
- for msg in messages:
- role = msg.get("role", "user")
- content = msg.get("content", "")
-
- if role == "system":
- system_content = content
- elif role == "user":
- user_content = content
- elif role == "assistant":
- # Skip assistant messages in history for now (can be added if needed)
- pass
-
- # Format for MedAlpaca/LLaMA-based medical models
- # Common format: Instruction + Input -> Response
- if system_content:
- prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
- else:
- prompt = f"Question: {user_content}\n\nAnswer:"
-
- return prompt
-
-def detect_language(text: str) -> str:
- """Detect language of input text"""
- try:
- lang = detect(text)
- return lang
- except LangDetectException:
- return "en" # Default to English if detection fails
-
-def format_url_as_domain(url: str) -> str:
- """Format URL as simple domain name (e.g., www.mayoclinic.org)"""
- if not url:
- return ""
- try:
- from urllib.parse import urlparse
- parsed = urlparse(url)
- domain = parsed.netloc or parsed.path
- # Keep the www. prefix; the full host name reads better as a citation
- if domain:
- return domain
- return url
- except Exception:
- # Fallback: try to extract domain manually
- if '://' in url:
- domain = url.split('://')[1].split('/')[0]
- return domain
- return url
-
-async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str:
- """Translate text using Gemini MCP"""
- if source_lang:
- user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
- else:
- user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
-
- # Use concise system prompt
- system_prompt = "You are a professional translator. Translate accurately and concisely."
-
- result = await call_agent(
- user_prompt=user_prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL_LITE, # Use lite model for translation
- temperature=0.2
- )
-
- return result.strip()
-
-def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str:
- """Translate text using Gemini MCP"""
- if not MCP_AVAILABLE:
- logger.warning("Gemini MCP not available for translation")
- return text # Return original text if translation fails
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- translated = asyncio.run(translate_text_gemini(text, target_lang, source_lang))
- if translated:
- logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
- return translated
- except Exception as e:
- logger.error(f"Error in nested async translation: {e}")
- else:
- translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang))
- if translated:
- logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
- return translated
- except Exception as e:
- logger.error(f"Gemini MCP translation error: {e}")
-
- # Return original text if translation fails
- return text
-
-async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
- """Search web using MCP web search tool (e.g., DuckDuckGo MCP server)"""
- if not MCP_AVAILABLE:
- return []
-
- try:
- tools = await get_cached_mcp_tools()
- if not tools:
- return []
-
- search_tool = None
- for tool in tools:
- tool_name_lower = tool.name.lower()
- if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
- search_tool = tool
- logger.info(f"Found web search MCP tool: {tool.name}")
- break
-
- if not search_tool:
- tools = await get_cached_mcp_tools(force_refresh=True)
- for tool in tools:
- tool_name_lower = tool.name.lower()
- if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
- search_tool = tool
- logger.info(f"Found web search MCP tool after refresh: {tool.name}")
- break
-
- if search_tool:
- try:
- session = await get_mcp_session()
- if session is None:
- return []
- # Call the search tool
- result = await session.call_tool(
- search_tool.name,
- arguments={"query": query, "max_results": max_results}
- )
-
- # Parse result
- web_content = []
- if hasattr(result, 'content') and result.content:
- for item in result.content:
- if hasattr(item, 'text'):
- try:
- data = json.loads(item.text)
- if isinstance(data, list):
- for entry in data[:max_results]:
- web_content.append({
- 'title': entry.get('title', ''),
- 'url': entry.get('url', entry.get('href', '')),
- 'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
- })
- elif isinstance(data, dict):
- if 'results' in data:
- for entry in data['results'][:max_results]:
- web_content.append({
- 'title': entry.get('title', ''),
- 'url': entry.get('url', entry.get('href', '')),
- 'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
- })
- else:
- web_content.append({
- 'title': data.get('title', ''),
- 'url': data.get('url', data.get('href', '')),
- 'content': data.get('body', data.get('snippet', data.get('content', '')))
- })
- except json.JSONDecodeError:
- # If not JSON, treat as plain text
- web_content.append({
- 'title': '',
- 'url': '',
- 'content': item.text[:1000]
- })
-
- if web_content:
- return web_content
- except Exception as e:
- logger.error(f"Error calling web search MCP tool: {e}")
-
- else:
- logger.debug("No MCP web search tool discovered in current catalog")
- return []
- except Exception as e:
- logger.error(f"Web search MCP tool error: {e}")
- return []
-
-async def search_web_mcp(query: str, max_results: int = 5) -> list:
- """Search web using MCP tools - tries web search MCP tool first, then falls back to direct search"""
- # First try to use a dedicated web search MCP tool (like DuckDuckGo MCP server)
- results = await search_web_mcp_tool(query, max_results)
- if results:
- logger.info(f"â
Web search via MCP tool: found {len(results)} results")
- return results
-
- # If no web search MCP tool available, use direct search (ddgs)
- # Note: Gemini MCP doesn't have web search capability, so we use direct API
- # The results will then be summarized using Gemini MCP
- logger.info("âšī¸ [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)")
- return search_web_fallback(query, max_results)
-
-def search_web_fallback(query: str, max_results: int = 5) -> list:
- """Fallback web search using DuckDuckGo directly (when MCP is not available)"""
- logger.info(f"đ [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...")
- # Always import here to ensure availability
- try:
- from ddgs import DDGS
- import requests
- from bs4 import BeautifulSoup
- except ImportError:
- logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available")
- return []
-
- try:
- with DDGS() as ddgs:
- results = list(ddgs.text(query, max_results=max_results))
- web_content = []
- for result in results:
- try:
- url = result.get('href', '')
- title = result.get('title', '')
- snippet = result.get('body', '')
-
- # Try to fetch full content
- try:
- response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'})
- if response.status_code == 200:
- soup = BeautifulSoup(response.content, 'html.parser')
- # Extract main content
- for script in soup(["script", "style"]):
- script.decompose()
- text = soup.get_text()
- # Clean and limit text
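- # (splitlines plus the double-space split collapses whitespace runs left by get_text())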
- lines = (line.strip() for line in text.splitlines())
- chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
- text = ' '.join(chunk for chunk in chunks if chunk)
- if len(text) > 1000:
- text = text[:1000] + "..."
- web_content.append({
- 'title': title,
- 'url': url,
- 'content': snippet + "\n" + text[:500] if text else snippet
- })
- else:
- web_content.append({
- 'title': title,
- 'url': url,
- 'content': snippet
- })
- except Exception:  # full-page fetch failed; keep the snippet only
- web_content.append({
- 'title': title,
- 'url': url,
- 'content': snippet
- })
- except Exception as e:
- logger.error(f"Error processing search result: {e}")
- continue
- logger.info(f"â
[Direct API] Web search completed: {len(web_content)} results")
- return web_content
- except Exception as e:
- logger.error(f"â [Direct API] Web search error: {e}")
- return []
-
-def search_web(query: str, max_results: int = 5) -> list:
- """Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs"""
- # Always try MCP first if available
- if MCP_AVAILABLE:
- try:
- # Run async MCP search
- try:
- loop = asyncio.get_event_loop()
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
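- # (worker threads start without an event loop, hence the RuntimeError fallback above)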
-
- if loop.is_running():
- # If loop is already running, use nest_asyncio or create new thread
- try:
- import nest_asyncio
- nest_asyncio.apply()
- results = asyncio.run(search_web_mcp(query, max_results))
- if results: # Only return if we got results from MCP
- return results
- except (ImportError, AttributeError):
- # Fallback: run in thread
- import concurrent.futures
- with concurrent.futures.ThreadPoolExecutor() as executor:
- future = executor.submit(asyncio.run, search_web_mcp(query, max_results))
- results = future.result(timeout=30)
- if results: # Only return if we got results from MCP
- return results
- else:
- results = loop.run_until_complete(search_web_mcp(query, max_results))
- if results: # Only return if we got results from MCP
- return results
- except Exception as e:
- logger.error(f"Error running async MCP search: {e}")
-
- # Only use ddgs fallback if MCP is not available or returned no results
- logger.info("âšī¸ [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)")
- return search_web_fallback(query, max_results)
-
-async def summarize_web_content_gemini(content_list: list, query: str) -> str:
- """Summarize web search results using Gemini MCP"""
- combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
-
- user_prompt = f"""Summarize the following web search results related to the query: "{query}"
-Extract key medical information, facts, and insights. Be concise and focus on reliable information.
-Search Results:
-{combined_content}
-Summary:"""
-
- # Use concise system prompt
- system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately."
-
- result = await call_agent(
- user_prompt=user_prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL, # Use full model for summarization
- temperature=0.5
- )
-
- return result.strip()
-
-def summarize_web_content(content_list: list, query: str) -> str:
- """Summarize web search results using Gemini MCP"""
- if not MCP_AVAILABLE:
- logger.warning("Gemini MCP not available for summarization")
- # Fallback: return first result's content
- if content_list:
- return content_list[0].get('content', '')[:500]
- return ""
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- summary = asyncio.run(summarize_web_content_gemini(content_list, query))
- if summary:
- return summary
- except Exception as e:
- logger.error(f"Error in nested async summarization: {e}")
- else:
- summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query))
- if summary:
- return summary
- except Exception as e:
- logger.error(f"Gemini MCP summarization error: {e}")
-
- # Fallback: return first result's content
- if content_list:
- return content_list[0].get('content', '')[:500]
- return ""
-
-def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
- """Get LLM for RAG indexing (uses medical model)"""
- # Use medical model for RAG indexing instead of translation model
- medical_model_obj, medical_tokenizer = initialize_medical_model(DEFAULT_MEDICAL_MODEL)
-
- return HuggingFaceLLM(
- context_window=4096,
- max_new_tokens=max_new_tokens,
- tokenizer=medical_tokenizer,
- model=medical_model_obj,
- generate_kwargs={
- "do_sample": True,
- "temperature": temperature,
- "top_k": top_k,
- "top_p": top_p
- }
- )
-
-async def autonomous_reasoning_gemini(query: str) -> dict:
- """Autonomous reasoning using Gemini MCP"""
- reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
-Query: "{query}"
-Analyze:
-1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
-2. Complexity: (simple, moderate, complex, multi_faceted)
-3. Information Needs: What specific information is required?
-4. Requires RAG: (yes/no) - Does this need document context?
-5. Requires Web Search: (yes/no) - Does this need current/updated information?
-6. Sub-questions: Break down into key sub-questions if complex
-Respond in JSON format:
-{{
- "query_type": "...",
- "complexity": "...",
- "information_needs": ["..."],
- "requires_rag": true/false,
- "requires_web_search": true/false,
- "sub_questions": ["..."]
-}}"""
-
- # Use concise system prompt
- system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses."
-
- response = await call_agent(
- user_prompt=reasoning_prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL, # Use full model for reasoning
- temperature=0.3
- )
-
- # Parse JSON response (with fallback)
- try:
- # Extract JSON from response
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- reasoning = json.loads(response[json_start:json_end])
- else:
- raise ValueError("No JSON found")
- except Exception:
- # Fallback reasoning
- reasoning = {
- "query_type": "general_info",
- "complexity": "moderate",
- "information_needs": ["medical information"],
- "requires_rag": True,
- "requires_web_search": False,
- "sub_questions": [query]
- }
-
- logger.info(f"Reasoning analysis: {reasoning}")
- return reasoning
-
-def autonomous_reasoning(query: str, history: list) -> dict:
- """
- Autonomous reasoning: Analyze query complexity, intent, and information needs.
- Returns reasoning analysis with query type, complexity, and required information sources.
- Uses Gemini MCP for reasoning.
- """
- if not MCP_AVAILABLE:
- logger.warning("â ī¸ Gemini MCP not available for reasoning, using fallback")
- # Fallback reasoning
- return {
- "query_type": "general_info",
- "complexity": "moderate",
- "information_needs": ["medical information"],
- "requires_rag": True,
- "requires_web_search": False,
- "sub_questions": [query]
- }
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- reasoning = asyncio.run(autonomous_reasoning_gemini(query))
- return reasoning
- except Exception as e:
- logger.error(f"Error in nested async reasoning: {e}")
- else:
- reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
- return reasoning
- except Exception as e:
- logger.error(f"Gemini MCP reasoning error: {e}")
-
- # Fallback reasoning only if all attempts failed
- logger.warning("â ī¸ Falling back to default reasoning")
- return {
- "query_type": "general_info",
- "complexity": "moderate",
- "information_needs": ["medical information"],
- "requires_rag": True,
- "requires_web_search": False,
- "sub_questions": [query]
- }
-
-def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict:
- """
- Planning: Create multi-step execution plan based on reasoning analysis.
- Returns execution plan with steps and strategy.
- """
- plan = {
- "steps": [],
- "strategy": "sequential",
- "iterations": 1
- }
-
- # Determine execution strategy
- if reasoning["complexity"] in ["complex", "multi_faceted"]:
- plan["strategy"] = "iterative"
- plan["iterations"] = 2
-
- # Step 1: Language detection and translation
- plan["steps"].append({
- "step": 1,
- "action": "detect_language",
- "description": "Detect query language and translate if needed"
- })
-
- # Step 2: RAG retrieval (if needed and available)
- if reasoning.get("requires_rag", True) and has_rag_index:
- plan["steps"].append({
- "step": 2,
- "action": "rag_retrieval",
- "description": "Retrieve relevant document context",
- "parameters": {"top_k": 15, "merge_threshold": 0.5}
- })
-
- # Step 3: Web search (if needed)
- if reasoning.get("requires_web_search", False):
- plan["steps"].append({
- "step": 3,
- "action": "web_search",
- "description": "Search web for current/updated information",
- "parameters": {"max_results": 5}
- })
-
- # Step 4: Sub-question processing (if complex)
- if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1:
- plan["steps"].append({
- "step": 4,
- "action": "multi_step_reasoning",
- "description": "Process sub-questions iteratively",
- "sub_questions": reasoning["sub_questions"]
- })
-
- # Step 5: Synthesis and answer generation
- plan["steps"].append({
- "step": len(plan["steps"]) + 1,
- "action": "synthesize_answer",
- "description": "Generate comprehensive answer from all sources"
- })
-
- # Step 6: Self-reflection (for complex queries)
- if reasoning["complexity"] in ["complex", "multi_faceted"]:
- plan["steps"].append({
- "step": len(plan["steps"]) + 1,
- "action": "self_reflection",
- "description": "Evaluate answer quality and completeness"
- })
-
- logger.info(f"Execution plan created: {len(plan['steps'])} steps")
- return plan
-
-def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict:
- """
- Autonomous execution: Make decisions on information gathering strategy.
- Only suggests web search override, but respects user's RAG disable setting.
- """
- strategy = {
- "use_rag": use_rag, # Respect user's RAG setting
- "use_web_search": use_web_search,
- "reasoning_override": False,
- "rationale": ""
- }
-
- # Respect user toggle; just log recommendation if web search is disabled
- if reasoning.get("requires_web_search", False) and not use_web_search:
- strategy["rationale"] = "Reasoning suggests web search for current information, but the user kept it disabled."
-
- # Note: We don't override RAG setting because:
- # 1. User may have explicitly disabled it
- # 2. RAG requires documents to be uploaded
- # 3. We should respect user's explicit choice
-
- if strategy["rationale"]:
- logger.info(f"Autonomous reasoning note: {strategy['rationale']}")
-
- return strategy
-
-async def gemini_supervisor_breakdown_async(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
- """
- Gemini Supervisor: Break user query into sub-topics (flexible number, explore different approaches)
- This is the main supervisor function that orchestrates the MAC architecture.
- All internal thoughts are logged, not displayed.
- """
- remaining_time = max(15, max_duration - time_elapsed)
-
- mode_description = []
- if use_rag:
- mode_description.append("RAG mode enabled - will use retrieved documents")
- if use_web_search:
- mode_description.append("Web search mode enabled - will search online sources")
- if not mode_description:
- mode_description.append("Direct answer mode - no additional context")
-
- # Calculate reasonable max topics based on time remaining
- # Allow more subtasks if we have time, but be flexible
- estimated_time_per_task = 8 # seconds per task
- max_topics_by_time = max(2, int((remaining_time - 20) / estimated_time_per_task))
- max_topics = min(max_topics_by_time, 10) # Cap at 10, but allow more than 4
-
- prompt = f"""You are a supervisor agent coordinating with a MedSwin medical specialist model.
-Break the following medical query into focused sub-topics that MedSwin can answer sequentially.
-Explore different potential approaches to comprehensively address the topic.
-
-Query: "{query}"
-Mode: {', '.join(mode_description)}
-Time Remaining: ~{remaining_time:.1f}s
-Maximum Topics: {max_topics} (adjust based on complexity - use as many as needed for thorough coverage)
-
-Return ONLY valid JSON (no markdown, no tables, no explanations):
-{{
- "sub_topics": [
- {{
- "id": 1,
- "topic": "concise topic name",
- "instruction": "specific directive for MedSwin to answer this topic",
- "expected_tokens": 200,
- "priority": "high|medium|low",
- "approach": "brief description of approach/angle for this topic"
- }},
- ...
- ],
- "strategy": "brief strategy description explaining the breakdown approach",
- "exploration_note": "brief note on different approaches explored"
-}}
-
-Guidelines:
-- Break down the query into as many subtasks as needed for comprehensive coverage
-- Explore different angles/approaches (e.g., clinical, diagnostic, treatment, prevention, research perspectives)
-- Each topic should be focused and answerable in ~200 tokens by MedSwin
-- Prioritize topics by importance (high priority first)
-- Don't limit yourself to 4 topics - use more if the query is complex or multi-faceted"""
-
- system_prompt = "You are a medical query supervisor. Break queries into structured JSON sub-topics, exploring different approaches. Return ONLY valid JSON."
-
- response = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL,
- temperature=0.3
- )
-
- try:
- # Extract JSON from response
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- breakdown = json.loads(response[json_start:json_end])
- logger.info(f"[GEMINI SUPERVISOR] Query broken into {len(breakdown.get('sub_topics', []))} sub-topics")
- logger.debug(f"[GEMINI SUPERVISOR] Breakdown: {json.dumps(breakdown, indent=2)}")
- return breakdown
- else:
- raise ValueError("Supervisor JSON not found")
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Breakdown parsing failed: {exc}")
- # Fallback: simple breakdown
- breakdown = {
- "sub_topics": [
- {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
- {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
- ],
- "strategy": "Sequential answer with key points",
- "exploration_note": "Fallback breakdown - basic coverage"
- }
- logger.warning(f"[GEMINI SUPERVISOR] Using fallback breakdown")
- return breakdown
-
-async def gemini_supervisor_search_strategies_async(query: str, time_elapsed: float) -> dict:
- """
- Gemini Supervisor: In search mode, break query into 1-4 searching strategies
- Returns JSON with search strategies that will be executed with ddgs
- """
- prompt = f"""You are supervising web search for a medical query.
-Break this query into 1-4 focused search strategies (each targeting 1-2 sources).
-
-Query: "{query}"
-
-Return ONLY valid JSON:
-{{
- "search_strategies": [
- {{
- "id": 1,
- "strategy": "search query string",
- "target_sources": 1,
- "focus": "what to search for"
- }},
- ...
- ],
- "max_strategies": 4
-}}
-
-Keep strategies focused and avoid overlap."""
-
- system_prompt = "You are a search strategy supervisor. Create focused search queries. Return ONLY valid JSON."
-
- response = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL_LITE, # Use lite model for search planning
- temperature=0.2
- )
-
- try:
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- strategies = json.loads(response[json_start:json_end])
- logger.info(f"[GEMINI SUPERVISOR] Created {len(strategies.get('search_strategies', []))} search strategies")
- logger.debug(f"[GEMINI SUPERVISOR] Strategies: {json.dumps(strategies, indent=2)}")
- return strategies
- else:
- raise ValueError("Search strategies JSON not found")
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Search strategies parsing failed: {exc}")
- return {
- "search_strategies": [
- {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
- ],
- "max_strategies": 1
- }
-
-async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
- """
- Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts
- These contexts will be passed to MedSwin to support decision-making
- """
- # Limit retrieved docs to avoid token overflow
- max_doc_length = 3000
- if len(retrieved_docs) > max_doc_length:
- retrieved_docs = retrieved_docs[:max_doc_length] + "..."
-
- prompt = f"""You are supervising RAG context preparation for a medical query.
-Brainstorm the retrieved documents into 1-4 concise, focused contexts that MedSwin can use.
-
-Query: "{query}"
-Retrieved Documents:
-{retrieved_docs}
-
-Return ONLY valid JSON:
-{{
- "contexts": [
- {{
- "id": 1,
- "context": "concise summary of relevant information (keep under 500 chars)",
- "focus": "what this context covers",
- "relevance": "high|medium|low"
- }},
- ...
- ],
- "max_contexts": 4
-}}
-
-Keep contexts brief and factual. Avoid redundancy."""
-
- system_prompt = "You are a RAG context supervisor. Summarize documents into concise contexts. Return ONLY valid JSON."
-
- response = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL_LITE, # Use lite model for RAG brainstorming
- temperature=0.2
- )
-
- try:
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- contexts = json.loads(response[json_start:json_end])
- logger.info(f"[GEMINI SUPERVISOR] Brainstormed {len(contexts.get('contexts', []))} RAG contexts")
- logger.debug(f"[GEMINI SUPERVISOR] Contexts: {json.dumps(contexts, indent=2)}")
- return contexts
- else:
- raise ValueError("RAG contexts JSON not found")
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] RAG brainstorming parsing failed: {exc}")
- # Fallback: use retrieved docs as single context
- return {
- "contexts": [
- {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
- ],
- "max_contexts": 1
- }
-
-def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
- """Wrapper to obtain supervisor breakdown synchronously"""
- if not MCP_AVAILABLE:
- logger.warning("[GEMINI SUPERVISOR] MCP unavailable, using fallback breakdown")
- return {
- "sub_topics": [
- {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
- {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
- ],
- "strategy": "Sequential answer with key points",
- "exploration_note": "Fallback breakdown - basic coverage"
- }
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- return asyncio.run(
- gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
- )
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested breakdown execution failed: {exc}")
- raise
- return loop.run_until_complete(
- gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
- )
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Breakdown request failed: {exc}")
- return {
- "sub_topics": [
- {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
- ],
- "strategy": "Direct answer",
- "exploration_note": "Fallback breakdown - single topic"
- }
-
-def gemini_supervisor_search_strategies(query: str, time_elapsed: float) -> dict:
- """Wrapper to obtain search strategies synchronously"""
- if not MCP_AVAILABLE:
- logger.warning("[GEMINI SUPERVISOR] MCP unavailable for search strategies")
- return {
- "search_strategies": [
- {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
- ],
- "max_strategies": 1
- }
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- return asyncio.run(gemini_supervisor_search_strategies_async(query, time_elapsed))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested search strategies execution failed: {exc}")
- return {
- "search_strategies": [
- {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
- ],
- "max_strategies": 1
- }
- return loop.run_until_complete(gemini_supervisor_search_strategies_async(query, time_elapsed))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Search strategies request failed: {exc}")
- return {
- "search_strategies": [
- {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
- ],
- "max_strategies": 1
- }
-
-def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
- """Wrapper to obtain RAG brainstorm synchronously"""
- if not MCP_AVAILABLE:
- logger.warning("[GEMINI SUPERVISOR] MCP unavailable for RAG brainstorm")
- return {
- "contexts": [
- {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
- ],
- "max_contexts": 1
- }
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
- import nest_asyncio
- nest_asyncio.apply()
- return asyncio.run(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested RAG brainstorm execution failed: {exc}")
- return {
- "contexts": [
- {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
- ],
- "max_contexts": 1
- }
- return loop.run_until_complete(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] RAG brainstorm request failed: {exc}")
- return {
- "contexts": [
- {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
- ],
- "max_contexts": 1
- }
-
-@spaces.GPU(max_duration=120)
-def execute_medswin_task(
- medical_model_obj,
- medical_tokenizer,
- task_instruction: str,
- context: str,
- system_prompt_base: str,
- temperature: float,
- max_new_tokens: int,
- top_p: float,
- top_k: int,
- penalty: float
-) -> str:
- """
- MedSwin Specialist: Execute a single task assigned by Gemini Supervisor
- Decorated with @spaces.GPU so it runs on GPU (ZeroGPU on Hugging Face Spaces)
- All internal thoughts are logged, only final answer is returned
- """
- # Build task-specific prompt
- if context:
- full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
- else:
- full_prompt = f"{system_prompt_base}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
-
- messages = [{"role": "system", "content": full_prompt}]
-
- # Format prompt
- if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
- try:
- prompt = medical_tokenizer.apply_chat_template(
- messages,
- tokenize=False,
- add_generation_prompt=True
- )
- except Exception as e:
- logger.warning(f"[MEDSWIN] Chat template failed, using manual formatting: {e}")
- prompt = format_prompt_manually(messages, medical_tokenizer)
- else:
- prompt = format_prompt_manually(messages, medical_tokenizer)
-
- # Tokenize and generate
- inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
-
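- # Some checkpoints ship without an eos token; fall back to pad so generation can stop cleanly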
- eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
-
- with torch.no_grad():
- outputs = medical_model_obj.generate(
- **inputs,
- max_new_tokens=min(max_new_tokens, 800), # Limit per task
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- repetition_penalty=penalty,
- do_sample=True,
- eos_token_id=eos_token_id,
- pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
- )
-
- # Decode response
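- # (slice off the prompt tokens so only newly generated text is decoded)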
- response = medical_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-
- # Clean response - remove any table-like formatting, ensure Markdown bullets
- response = response.strip()
- # Remove table markers if present
- if "|" in response and "---" in response:
- logger.warning("[MEDSWIN] Detected table format, converting to Markdown bullets")
- # Simple conversion: split by lines and convert to bullets
- lines = [line.strip() for line in response.split('\n') if line.strip() and not line.strip().startswith('|') and '---' not in line]
- response = '\n'.join([f"- {line}" if not line.startswith('-') else line for line in lines])
-
- logger.info(f"[MEDSWIN] Task completed: {len(response)} chars generated")
- return response
-
-async def gemini_supervisor_synthesize_async(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
- """
- Gemini Supervisor: Synthesize final answer from all MedSwin responses with clear context
- Provides better context to create a comprehensive, well-structured final answer
- """
- # Prepare context summary
- context_summary = ""
- if rag_contexts:
- context_summary += f"Document Context Available: {len(rag_contexts)} context(s) from uploaded documents.\n"
- if search_contexts:
- context_summary += f"Web Search Context Available: {len(search_contexts)} search result(s).\n"
-
- # Combine all MedSwin answers
- all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
-
- prompt = f"""You are a supervisor agent synthesizing a comprehensive medical answer from multiple specialist responses.
-
-Original Query: "{query}"
-
-Context Available:
-{context_summary}
-
-MedSwin Specialist Responses (from {len(medswin_answers)} sub-topics):
-{all_answers_text}
-
-Your task:
-1. Synthesize all responses into a coherent, comprehensive final answer
-2. Integrate information from all sub-topics seamlessly
-3. Ensure the answer directly addresses the original query
-4. Maintain clinical accuracy and clarity
-5. Use clear structure with appropriate headings and bullet points
-6. Remove redundancy and contradictions
-7. Ensure all important points from MedSwin responses are included
-
-Return the final synthesized answer in Markdown format. Do not add meta-commentary or explanations - just provide the final answer."""
-
- system_prompt = "You are a medical answer synthesis supervisor. Create comprehensive, well-structured final answers from multiple specialist responses."
-
- result = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL,
- temperature=0.3
- )
-
- return result.strip()
-
-async def gemini_supervisor_challenge_async(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
- """
- Gemini Supervisor: Challenge and evaluate the current answer, suggesting improvements
- Returns evaluation with suggestions for enhancement
- """
- context_info = ""
- if rag_contexts:
- context_info += f"Document contexts: {len(rag_contexts)} available.\n"
- if search_contexts:
- context_info += f"Search contexts: {len(search_contexts)} available.\n"
-
- all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
-
- prompt = f"""You are a supervisor agent evaluating and challenging a medical answer for quality and completeness.
-
-Original Query: "{query}"
-
-Available Context:
-{context_info}
-
-MedSwin Specialist Responses:
-{all_answers_text}
-
-Current Synthesized Answer:
-{current_answer[:2000]}
-
-Evaluate this answer and provide:
-1. Completeness: Does it fully address the query? What's missing?
-2. Accuracy: Are there any inaccuracies or contradictions?
-3. Clarity: Is it well-structured and clear?
-4. Context Usage: Are document/search contexts properly utilized?
-5. Improvement Suggestions: Specific ways to enhance the answer
-
-Return ONLY valid JSON:
-{{
- "is_optimal": true/false,
- "completeness_score": 0-10,
- "accuracy_score": 0-10,
- "clarity_score": 0-10,
- "missing_aspects": ["..."],
- "inaccuracies": ["..."],
- "improvement_suggestions": ["..."],
- "needs_more_context": true/false,
- "enhancement_instructions": "specific instructions for improving the answer"
-}}"""
-
- system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback in JSON format. Return ONLY valid JSON."
-
- response = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL,
- temperature=0.3
- )
-
- try:
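-        # Heuristic JSON recovery: slice from the first '{' to the last '}' in the model reply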
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- evaluation = json.loads(response[json_start:json_end])
- logger.info(f"[GEMINI SUPERVISOR] Challenge evaluation: optimal={evaluation.get('is_optimal', False)}, scores={evaluation.get('completeness_score', 'N/A')}/{evaluation.get('accuracy_score', 'N/A')}/{evaluation.get('clarity_score', 'N/A')}")
- return evaluation
- else:
- raise ValueError("Evaluation JSON not found")
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Challenge evaluation parsing failed: {exc}")
- return {
- "is_optimal": True,
- "completeness_score": 7,
- "accuracy_score": 7,
- "clarity_score": 7,
- "missing_aspects": [],
- "inaccuracies": [],
- "improvement_suggestions": [],
- "needs_more_context": False,
- "enhancement_instructions": ""
- }
-
-async def gemini_supervisor_enhance_answer_async(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
- """
- Gemini Supervisor: Enhance the answer based on challenge feedback
- """
- context_info = ""
- if rag_contexts:
- context_info += f"Document contexts: {len(rag_contexts)} available.\n"
- if search_contexts:
- context_info += f"Search contexts: {len(search_contexts)} available.\n"
-
- all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
-
- prompt = f"""You are a supervisor agent enhancing a medical answer based on evaluation feedback.
-
-Original Query: "{query}"
-
-Available Context:
-{context_info}
-
-MedSwin Specialist Responses:
-{all_answers_text}
-
-Current Answer (to enhance):
-{current_answer}
-
-Enhancement Instructions:
-{enhancement_instructions}
-
-Create an enhanced version of the answer that:
-1. Addresses all improvement suggestions
-2. Fills in missing aspects
-3. Corrects any inaccuracies
-4. Improves clarity and structure
-5. Better utilizes available context
-6. Maintains all valuable information from the current answer
-
-Return the enhanced answer in Markdown format. Do not add meta-commentary."""
-
- system_prompt = "You are a medical answer enhancement supervisor. Improve answers based on evaluation feedback while maintaining accuracy."
-
- result = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL,
- temperature=0.3
- )
-
- return result.strip()
-
-async def gemini_supervisor_check_clarity_async(query: str, answer: str, use_web_search: bool) -> dict:
- """
- Gemini Supervisor: Check if answer is unclear or supervisor is unsure (only when search mode enabled)
- Returns decision on whether to trigger additional search
- """
- if not use_web_search:
- # Only check clarity when search mode is enabled
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
-
- prompt = f"""You are a supervisor agent evaluating answer clarity and completeness.
-
-Query: "{query}"
-
-Current Answer:
-{answer[:1500]}
-
-Evaluate:
-1. Is the answer unclear or incomplete?
-2. Are there gaps that web search could fill?
-3. Is the supervisor (you) unsure about certain aspects?
-
-Return ONLY valid JSON:
-{{
- "is_unclear": true/false,
- "needs_search": true/false,
- "uncertainty_areas": ["..."],
- "search_queries": ["specific search queries to fill gaps"],
- "rationale": "brief explanation"
-}}
-
-Only suggest search if the answer is genuinely unclear or has significant gaps that search could address."""
-
- system_prompt = "You are a clarity evaluator. Assess if additional web search is needed. Return ONLY valid JSON."
-
- response = await call_agent(
- user_prompt=prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL_LITE,
- temperature=0.2
- )
-
- try:
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- evaluation = json.loads(response[json_start:json_end])
- logger.info(f"[GEMINI SUPERVISOR] Clarity check: unclear={evaluation.get('is_unclear', False)}, needs_search={evaluation.get('needs_search', False)}")
- return evaluation
- else:
- raise ValueError("Clarity check JSON not found")
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Clarity check parsing failed: {exc}")
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
-
-def gemini_supervisor_synthesize(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
- """Wrapper to synthesize answer synchronously"""
- if not MCP_AVAILABLE:
- logger.warning("[GEMINI SUPERVISOR] MCP unavailable for synthesis, using simple concatenation")
- return "\n\n".join(medswin_answers)
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                import nest_asyncio
-                nest_asyncio.apply()  # nest_asyncio exposes apply(), not run(); patch the loop, then re-enter it
-                return loop.run_until_complete(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested synthesis failed: {exc}")
- return "\n\n".join(medswin_answers)
- return loop.run_until_complete(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Synthesis failed: {exc}")
- return "\n\n".join(medswin_answers)
-
-def gemini_supervisor_challenge(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
- """Wrapper to challenge answer synchronously"""
- if not MCP_AVAILABLE:
- return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                import nest_asyncio
-                nest_asyncio.apply()
-                return loop.run_until_complete(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested challenge failed: {exc}")
- return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
- return loop.run_until_complete(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Challenge failed: {exc}")
- return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
-
-def gemini_supervisor_enhance_answer(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
- """Wrapper to enhance answer synchronously"""
- if not MCP_AVAILABLE:
- return current_answer
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                import nest_asyncio
-                nest_asyncio.apply()
-                return loop.run_until_complete(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested enhancement failed: {exc}")
- return current_answer
- return loop.run_until_complete(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Enhancement failed: {exc}")
- return current_answer
-
-def gemini_supervisor_check_clarity(query: str, answer: str, use_web_search: bool) -> dict:
- """Wrapper to check clarity synchronously"""
- if not MCP_AVAILABLE or not use_web_search:
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                import nest_asyncio
-                nest_asyncio.apply()
-                return loop.run_until_complete(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Nested clarity check failed: {exc}")
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
- return loop.run_until_complete(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
- except Exception as exc:
- logger.error(f"[GEMINI SUPERVISOR] Clarity check failed: {exc}")
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
-
-async def self_reflection_gemini(answer: str, query: str) -> dict:
- """Self-reflection using Gemini MCP"""
- reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
-Query: "{query}"
-Answer: "{answer[:1000]}"
-Evaluate:
-1. Completeness: Does it address all aspects of the query?
-2. Accuracy: Is the medical information accurate?
-3. Clarity: Is it clear and well-structured?
-4. Sources: Are sources cited appropriately?
-5. Missing Information: What important information might be missing?
-Respond in JSON:
-{{
- "completeness_score": 0-10,
- "accuracy_score": 0-10,
- "clarity_score": 0-10,
- "overall_score": 0-10,
- "missing_aspects": ["..."],
- "improvement_suggestions": ["..."]
-}}"""
-
- # Use concise system prompt
- system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback."
-
- response = await call_agent(
- user_prompt=reflection_prompt,
- system_prompt=system_prompt,
- model=GEMINI_MODEL, # Use full model for reflection
- temperature=0.3
- )
-
- try:
- json_start = response.find('{')
- json_end = response.rfind('}') + 1
- if json_start >= 0 and json_end > json_start:
- reflection = json.loads(response[json_start:json_end])
- else:
- reflection = {"overall_score": 7, "improvement_suggestions": []}
-    except Exception:
- reflection = {"overall_score": 7, "improvement_suggestions": []}
-
- logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}")
- return reflection
-
-def self_reflection(answer: str, query: str, reasoning: dict) -> dict:
- """
- Self-reflection: Evaluate answer quality and completeness.
- Returns reflection with quality score and improvement suggestions.
- """
- if not MCP_AVAILABLE:
- logger.warning("Gemini MCP not available for reflection, using fallback")
- return {"overall_score": 7, "improvement_suggestions": []}
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                import nest_asyncio
-                nest_asyncio.apply()
-                return loop.run_until_complete(self_reflection_gemini(answer, query))
- except Exception as e:
- logger.error(f"Error in nested async reflection: {e}")
- else:
- return loop.run_until_complete(self_reflection_gemini(answer, query))
- except Exception as e:
- logger.error(f"Gemini MCP reflection error: {e}")
-
- return {"overall_score": 7, "improvement_suggestions": []}
-
-async def parse_document_gemini(file_path: str, file_extension: str) -> str:
- """Parse document using Gemini MCP"""
- if not MCP_AVAILABLE:
- return ""
-
- try:
- # Read file and encode to base64
- with open(file_path, 'rb') as f:
- file_content = base64.b64encode(f.read()).decode('utf-8')
-
- # Determine MIME type from file extension
- mime_type_map = {
- '.pdf': 'application/pdf',
- '.doc': 'application/msword',
- '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
- '.txt': 'text/plain',
- '.md': 'text/markdown',
- '.json': 'application/json',
- '.xml': 'application/xml',
- '.csv': 'text/csv'
- }
- mime_type = mime_type_map.get(file_extension, 'application/octet-stream')
-
- # Prepare file object for Gemini MCP (use content for base64)
- files = [{
- "content": file_content,
- "type": mime_type
- }]
-
- # Use concise system prompt
- system_prompt = "Extract all text content from the document accurately."
- user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible."
-
- result = await call_agent(
- user_prompt=user_prompt,
- system_prompt=system_prompt,
- files=files,
- model=GEMINI_MODEL_LITE, # Use lite model for parsing
- temperature=0.2
- )
-
- return result.strip()
- except Exception as e:
- logger.error(f"Gemini document parsing error: {e}")
- return ""
-
-def extract_text_from_document(file):
- """Extract text from document using Gemini MCP"""
- file_name = file.name
- file_extension = os.path.splitext(file_name)[1].lower()
-
- # Handle text files directly
- if file_extension == '.txt':
- text = file.read().decode('utf-8')
- return text, len(text.split()), None
-
- # For PDF, Word, and other documents, use Gemini MCP
- # Save file to temporary location for processing
- try:
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
- # Write file content to temp file
- file.seek(0) # Reset file pointer
- tmp_file.write(file.read())
- tmp_file_path = tmp_file.name
-
- # Use Gemini MCP to parse document
- if MCP_AVAILABLE:
- try:
- loop = asyncio.get_event_loop()
- if loop.is_running():
- try:
-                    import nest_asyncio
-                    nest_asyncio.apply()
-                    text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
- except Exception as e:
- logger.error(f"Error in nested async document parsing: {e}")
- text = ""
- else:
- text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
-
- # Clean up temp file
- try:
- os.unlink(tmp_file_path)
-                except OSError:
- pass
-
- if text:
- return text, len(text.split()), None
- else:
- return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP")
- except Exception as e:
- logger.error(f"Gemini MCP document parsing error: {e}")
- # Clean up temp file
- try:
- os.unlink(tmp_file_path)
-                except OSError:
- pass
- return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}")
- else:
- # Clean up temp file
- try:
- os.unlink(tmp_file_path)
-            except OSError:
- pass
- return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.")
- except Exception as e:
- logger.error(f"Error processing document: {e}")
- return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}")
-
-@spaces.GPU(max_duration=120)
-def create_or_update_index(files, request: gr.Request):
- global global_file_info
-
- if not files:
- return "Please provide files.", ""
-
- start_time = time.time()
- user_id = request.session_hash
- save_dir = f"./{user_id}_index"
-    # Configure the LlamaIndex LLM and embedding settings
- llm = get_llm_for_rag()
- embed_model = get_or_create_embed_model()
- Settings.llm = llm
- Settings.embed_model = embed_model
- file_stats = []
- new_documents = []
-
- for file in tqdm(files, desc="Processing files"):
- file_basename = os.path.basename(file.name)
- text, word_count, error = extract_text_from_document(file)
- if error:
- logger.error(f"Error processing file {file_basename}: {str(error)}")
- file_stats.append({
- "name": file_basename,
- "words": 0,
- "status": f"error: {str(error)}"
- })
- continue
-
- doc = LlamaDocument(
- text=text,
- metadata={
- "file_name": file_basename,
- "word_count": word_count,
- "source": "user_upload"
- }
- )
- new_documents.append(doc)
-
- file_stats.append({
- "name": file_basename,
- "words": word_count,
- "status": "processed"
- })
-
- global_file_info[file_basename] = {
- "word_count": word_count,
- "processed_at": time.time()
- }
-
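-    # Three-level hierarchy (2048/512/128-token chunks): leaves are embedded for
-    # retrieval; parents let the AutoMergingRetriever merge matching sibling leaves.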
- node_parser = HierarchicalNodeParser.from_defaults(
- chunk_sizes=[2048, 512, 128],
- chunk_overlap=20
- )
- logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
- new_nodes = node_parser.get_nodes_from_documents(new_documents)
- new_leaf_nodes = get_leaf_nodes(new_nodes)
- new_root_nodes = get_root_nodes(new_nodes)
- logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
-
- if os.path.exists(save_dir):
- logger.info(f"Loading existing index from {save_dir}")
- storage_context = StorageContext.from_defaults(persist_dir=save_dir)
- index = load_index_from_storage(storage_context, settings=Settings)
- docstore = storage_context.docstore
-
- docstore.add_documents(new_nodes)
- for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
- index.insert_nodes([node])
-
- total_docs = len(docstore.docs)
- logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
- else:
- logger.info("Creating new index")
- docstore = SimpleDocumentStore()
- storage_context = StorageContext.from_defaults(docstore=docstore)
- docstore.add_documents(new_nodes)
-
- index = VectorStoreIndex(
- new_leaf_nodes,
- storage_context=storage_context,
- settings=Settings
- )
- total_docs = len(new_documents)
- logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
-
- index.storage_context.persist(persist_dir=save_dir)
-    # Build the HTML summary shown after processing files
-    file_list_html = ""
-    for stat in file_stats:
-        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
-        file_list_html += f"<div style='color:{status_color}'>✓ {stat['name']} - {stat['words']} words</div>"
-    file_list_html += "<br>"
-    processing_time = time.time() - start_time
-    stats_output = "<div>"
-    stats_output += f"<p>✓ Processed {len(files)} files in {processing_time:.2f} seconds</p>"
-    stats_output += f"<p>✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)</p>"
-    stats_output += f"<p>✓ Total documents in index: {total_docs}</p>"
-    stats_output += f"<p>✓ Index saved to: {save_dir}</p>"
-    stats_output += "</div>"
-    output_container = "<div>"
-    output_container += file_list_html
-    output_container += stats_output
-    output_container += "</div>"
-    return f"Successfully indexed {len(files)} files.", output_container
-
-@spaces.GPU(max_duration=120)
-def stream_chat(
- message: str,
- history: list,
- system_prompt: str,
- temperature: float,
- max_new_tokens: int,
- top_p: float,
- top_k: int,
- penalty: float,
- retriever_k: int,
- merge_threshold: float,
- use_rag: bool,
- medical_model: str,
- use_web_search: bool,
- disable_agentic_reasoning: bool,
- show_thoughts: bool,
- request: gr.Request
-):
- if not request:
- yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
- return
-
- # Set up thought capture handler if show_thoughts is enabled
- thought_handler = None
- if show_thoughts:
- thought_handler = ThoughtCaptureHandler()
- thought_handler.setLevel(logging.INFO)
- thought_handler.clear() # Start fresh
- logger.addHandler(thought_handler)
-
- session_start = time.time()
- soft_timeout = 100
- hard_timeout = 118 # stop slightly before HF max duration (120s)
-
- def elapsed():
- return time.time() - session_start
-
- user_id = request.session_hash
- index_dir = f"./{user_id}_index"
- has_rag_index = os.path.exists(index_dir)
-
- # ===== MAC ARCHITECTURE: GEMINI SUPERVISOR + MEDSWIN SPECIALIST =====
- # All internal thoughts are logged, only final answer is displayed
-
- original_lang = detect_language(message)
- original_message = message
- needs_translation = original_lang != "en"
-
- pipeline_diagnostics = {
- "reasoning": None,
- "plan": None,
- "strategy_decisions": [],
- "stage_metrics": {},
- "search": {"strategies": [], "total_results": 0}
- }
-
- def record_stage(stage_name: str, start_time: float):
- pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
-
- translation_stage_start = time.time()
- if needs_translation:
- logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
- message = translate_text(message, target_lang="en", source_lang=original_lang)
- logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
- record_stage("translation", translation_stage_start)
-
- # Determine final modes (respect user settings and availability)
- final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
- final_use_web_search = use_web_search and not disable_agentic_reasoning
-
- plan = None
- if not disable_agentic_reasoning:
- reasoning_stage_start = time.time()
- reasoning = autonomous_reasoning(message, history)
- record_stage("autonomous_reasoning", reasoning_stage_start)
- pipeline_diagnostics["reasoning"] = reasoning
- plan = create_execution_plan(reasoning, message, has_rag_index)
- pipeline_diagnostics["plan"] = plan
- execution_strategy = autonomous_execution_strategy(
- reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
- )
-
- if final_use_rag and not reasoning.get("requires_rag", True):
- final_use_rag = False
- pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
- elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
- pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
-
- if final_use_web_search and not reasoning.get("requires_web_search", False):
- final_use_web_search = False
- pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
- elif not final_use_web_search and reasoning.get("requires_web_search", False):
- if not use_web_search:
- pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
- else:
- pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
- else:
- pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
-
- # ===== STEP 1: GEMINI SUPERVISOR - Break query into sub-topics =====
- if disable_agentic_reasoning:
- logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
- # Simple breakdown for direct mode
- breakdown = {
- "sub_topics": [
- {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
- ],
- "strategy": "Direct answer",
- "exploration_note": "Direct mode - no breakdown"
- }
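-        # Same schema as gemini_supervisor_breakdown() so downstream task execution is uniform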
- else:
- logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
- breakdown = gemini_supervisor_breakdown(message, final_use_rag, final_use_web_search, elapsed(), max_duration=120)
- logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
-
- # ===== STEP 2: GEMINI SUPERVISOR - Handle Search Mode =====
- search_contexts = []
- web_urls = []
- if final_use_web_search:
- search_stage_start = time.time()
- logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
- search_strategies = gemini_supervisor_search_strategies(message, elapsed())
-
- # Execute searches for each strategy
- all_search_results = []
- strategy_jobs = []
- for strategy in search_strategies.get("search_strategies", [])[:4]: # Max 4 strategies
- search_query = strategy.get("strategy", message)
- target_sources = strategy.get("target_sources", 2)
- strategy_jobs.append({
- "query": search_query,
- "target_sources": target_sources,
- "meta": strategy
- })
-
- def execute_search(job):
- job_start = time.time()
- try:
- results = search_web(job["query"], max_results=job["target_sources"])
- duration = time.time() - job_start
- return results, duration, None
- except Exception as exc:
- return [], time.time() - job_start, exc
-
- def record_search_diag(job, duration, results_count, error=None):
- entry = {
- "query": job["query"],
- "target_sources": job["target_sources"],
- "duration": round(duration, 3),
- "results": results_count
- }
- if error:
- entry["error"] = str(error)
- pipeline_diagnostics["search"]["strategies"].append(entry)
-
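-        # Run the strategies in parallel when there is more than one; each is an independent I/O-bound web query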
- if strategy_jobs:
- max_workers = min(len(strategy_jobs), 4)
- if len(strategy_jobs) > 1:
- with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
- for future in concurrent.futures.as_completed(future_map):
- job = future_map[future]
- try:
- results, duration, error = future.result()
- except Exception as exc:
- results, duration, error = [], 0.0, exc
- record_search_diag(job, duration, len(results), error)
- if not error and results:
- all_search_results.extend(results)
- web_urls.extend([r.get('url', '') for r in results if r.get('url')])
- else:
- job = strategy_jobs[0]
- results, duration, error = execute_search(job)
- record_search_diag(job, duration, len(results), error)
- if not error and results:
- all_search_results.extend(results)
- web_urls.extend([r.get('url', '') for r in results if r.get('url')])
- else:
- pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
-
- pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
-
- # Summarize search results with Gemini
- if all_search_results:
- logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
- search_summary = summarize_web_content(all_search_results, message)
- if search_summary:
- search_contexts.append(search_summary)
- logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
- record_stage("web_search", search_stage_start)
-
- # ===== STEP 3: GEMINI SUPERVISOR - Handle RAG Mode =====
- rag_contexts = []
- if final_use_rag and has_rag_index:
- rag_stage_start = time.time()
- if elapsed() >= soft_timeout - 10:
- logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
- final_use_rag = False
- else:
- logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
- embed_model = get_or_create_embed_model()
- Settings.embed_model = embed_model
- storage_context = StorageContext.from_defaults(persist_dir=index_dir)
- index = load_index_from_storage(storage_context, settings=Settings)
- base_retriever = index.as_retriever(similarity_top_k=retriever_k)
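-            # Auto-merging: when enough sibling leaves match, they are replaced by their parent chunk (simple_ratio_thresh sets the cutoff)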
- auto_merging_retriever = AutoMergingRetriever(
- base_retriever,
- storage_context=storage_context,
- simple_ratio_thresh=merge_threshold,
- verbose=False # Reduce logging noise
- )
- merged_nodes = auto_merging_retriever.retrieve(message)
- retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
- logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
-
- # Brainstorm retrieved docs into contexts
- logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
- rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed())
- rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
- logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
- record_stage("rag_retrieval", rag_stage_start)
-
- # ===== STEP 4: MEDSWIN SPECIALIST - Execute tasks sequentially =====
- # Initialize medical model
- medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
-
- # Base system prompt for MedSwin (clean, no internal thoughts)
- base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
-
- # Prepare context for MedSwin (combine RAG and search contexts)
- combined_context = ""
- if rag_contexts:
- combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4]) # Max 4 contexts
- if search_contexts:
- if combined_context:
- combined_context += "\n\n"
- combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)
-
- # Execute MedSwin tasks for each sub-topic
- logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
- medswin_answers = []
-
- updated_history = history + [
- {"role": "user", "content": original_message},
- {"role": "assistant", "content": ""}
- ]
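-    # Emit the user turn immediately with an empty assistant slot; later yields fill it in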
- thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
- yield updated_history, thoughts_text
-
- medswin_stage_start = time.time()
- for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
- if elapsed() >= hard_timeout - 5:
- logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
- break
-
- task_instruction = sub_topic.get("instruction", "")
- topic_name = sub_topic.get("topic", f"Topic {idx}")
- priority = sub_topic.get("priority", "medium")
-
- logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
-
-        # Give each task its matching brainstormed RAG context when several are available
-        task_context = combined_context
-        if len(rag_contexts) > 1 and idx <= len(rag_contexts):
-            task_context = rag_contexts[idx - 1]
-
- # Execute MedSwin task (with GPU tag)
- try:
- task_answer = execute_medswin_task(
- medical_model_obj=medical_model_obj,
- medical_tokenizer=medical_tokenizer,
- task_instruction=task_instruction,
- context=task_context if task_context else "",
- system_prompt_base=base_system_prompt,
- temperature=temperature,
- max_new_tokens=min(max_new_tokens, 800), # Limit per task
- top_p=top_p,
- top_k=top_k,
- penalty=penalty
- )
-
- # Format task answer with topic header
- formatted_answer = f"## {topic_name}\n\n{task_answer}"
- medswin_answers.append(formatted_answer)
- logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
-
- # Stream partial answer as we complete each task
- partial_final = "\n\n".join(medswin_answers)
- updated_history[-1]["content"] = partial_final
- thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
- yield updated_history, thoughts_text
-
- except Exception as e:
- logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
- # Continue with next task
- continue
- record_stage("medswin_tasks", medswin_stage_start)
-
- # ===== STEP 5: GEMINI SUPERVISOR - Synthesize final answer with clear context =====
- logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
- raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers] # Remove headers for synthesis
- synthesis_stage_start = time.time()
- final_answer = gemini_supervisor_synthesize(message, raw_medswin_answers, rag_contexts, search_contexts, breakdown)
- record_stage("synthesis", synthesis_stage_start)
-
- if not final_answer or len(final_answer.strip()) < 50:
- # Fallback to simple concatenation if synthesis fails
- logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
- final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
-
- # Clean final answer - ensure no tables, only Markdown bullets
- if "|" in final_answer and "---" in final_answer:
- logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
- lines = final_answer.split('\n')
- cleaned_lines = []
- for line in lines:
- if '|' in line and '---' not in line:
- # Convert table row to bullet points
- cells = [cell.strip() for cell in line.split('|') if cell.strip()]
- if cells:
- cleaned_lines.append(f"- {' / '.join(cells)}")
- elif '---' not in line:
- cleaned_lines.append(line)
- final_answer = '\n'.join(cleaned_lines)
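-        # e.g. the row "| Aspirin | 81 mg |" becomes "- Aspirin / 81 mg"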
-
- # ===== STEP 6: GEMINI SUPERVISOR - Challenge and enhance answer iteratively =====
- max_challenge_iterations = 2 # Limit iterations to avoid timeout
- challenge_iteration = 0
- challenge_stage_start = time.time()
-
- while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
- challenge_iteration += 1
- logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")
-
- evaluation = gemini_supervisor_challenge(message, final_answer, raw_medswin_answers, rag_contexts, search_contexts)
-
- if evaluation.get("is_optimal", False):
- logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
- break
-
- enhancement_instructions = evaluation.get("enhancement_instructions", "")
- if not enhancement_instructions:
- logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
- break
-
- logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
- enhanced_answer = gemini_supervisor_enhance_answer(
- message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
- )
-
-        if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:  # accept only if not drastically shorter than the current answer
- final_answer = enhanced_answer
- logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
- else:
- logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
- break
- record_stage("challenge_loop", challenge_stage_start)
-
- # ===== STEP 7: Conditional search trigger (only when search mode enabled) =====
- if final_use_web_search and elapsed() < soft_timeout - 10:
- logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
- clarity_stage_start = time.time()
- clarity_check = gemini_supervisor_check_clarity(message, final_answer, final_use_web_search)
- record_stage("clarity_check", clarity_stage_start)
-
- if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
- logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
- additional_search_results = []
- followup_stage_start = time.time()
- for search_query in clarity_check.get("search_queries", [])[:3]: # Limit to 3 additional searches
- if elapsed() >= soft_timeout - 5:
- break
- extra_start = time.time()
- results = search_web(search_query, max_results=2)
- extra_duration = time.time() - extra_start
- pipeline_diagnostics["search"]["strategies"].append({
- "query": search_query,
- "target_sources": 2,
- "duration": round(extra_duration, 3),
- "results": len(results),
- "type": "followup"
- })
- additional_search_results.extend(results)
- web_urls.extend([r.get('url', '') for r in results if r.get('url')])
-
- if additional_search_results:
- pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
- logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
- additional_summary = summarize_web_content(additional_search_results, message)
- if additional_summary:
- # Enhance answer with additional search context
- search_contexts.append(additional_summary)
- logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
- enhanced_with_search = gemini_supervisor_enhance_answer(
- message, final_answer,
- f"Incorporate the following additional information from web search: {additional_summary}",
- raw_medswin_answers, rag_contexts, search_contexts
- )
- if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
- final_answer = enhanced_with_search
- logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
- record_stage("followup_search", followup_stage_start)
-
- citations_text = ""
-
- # ===== STEP 8: Finalize answer (translate, add citations, format) =====
- # Translate back if needed
- if needs_translation and final_answer:
- logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
- final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
-
- # Add citations if web sources were used
- if web_urls:
- unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates
- citation_links = []
- for url in unique_urls[:5]: # Limit to 5 citations
- domain = format_url_as_domain(url)
- if domain:
- citation_links.append(f"[{domain}]({url})")
-
- if citation_links:
- citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
-
- # Add speaker icon
-    speaker_icon = ' 🔊'
- final_answer_with_metadata = final_answer + citations_text + speaker_icon
-
- # Update history with final answer (ONLY final answer, no internal thoughts)
- updated_history[-1]["content"] = final_answer_with_metadata
- thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
- yield updated_history, thoughts_text
-
- # Clean up thought handler
- if thought_handler:
- logger.removeHandler(thought_handler)
-
- # Log completion
- diag_summary = {
- "stage_metrics": pipeline_diagnostics["stage_metrics"],
- "decisions": pipeline_diagnostics["strategy_decisions"],
- "search": pipeline_diagnostics["search"],
- }
- try:
- logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
- except Exception:
- logger.info(f"[MAC] Diagnostics summary (non-serializable)")
- logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
-
-def generate_speech_for_message(text: str):
-    """Generate speech for a message and return the audio file path, or None on failure"""
-    return generate_speech(text) or None
-
-def create_demo():
- with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
- gr.HTML(TITLE)
- gr.HTML(DESCRIPTION)
-
- with gr.Row(elem_classes="main-container"):
- with gr.Column(elem_classes="upload-section"):
- file_upload = gr.File(
- file_count="multiple",
- label="Drag and Drop Files Here",
- file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
- elem_id="file-upload"
- )
- upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
- status_output = gr.Textbox(
- label="Status",
- placeholder="Upload files to start...",
- interactive=False
- )
- file_info_output = gr.HTML(
- label="File Information",
- elem_classes="processing-info"
- )
- upload_button.click(
- fn=create_or_update_index,
- inputs=[file_upload],
- outputs=[status_output, file_info_output]
- )
-
- with gr.Column(elem_classes="chatbot-container"):
- chatbot = gr.Chatbot(
- height=500,
- placeholder="Chat with MedSwin... Type your question below.",
- show_label=False,
- type="messages"
- )
- with gr.Row(elem_classes="input-row"):
- message_input = gr.Textbox(
- placeholder="Type your medical question here...",
- show_label=False,
- container=False,
- lines=1,
- scale=10
- )
- mic_button = gr.Audio(
- sources=["microphone"],
- type="filepath",
- label="",
- show_label=False,
- container=False,
- scale=1
- )
-                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
-
- # Timer display for recording (shown below input row)
- recording_timer = gr.Textbox(
- value="",
- label="",
- show_label=False,
- interactive=False,
- visible=False,
- container=False,
- elem_classes="recording-timer"
- )
-
- # Handle microphone transcription
-            recording_start_time = [None]  # mutable cell so the nested handlers can update it; time is imported at module level
-
- def handle_recording_start():
- """Called when recording starts"""
- recording_start_time[0] = time.time()
- return gr.update(visible=True, value="Recording... 0s")
-
- def handle_recording_stop(audio):
- """Called when recording stops"""
- recording_start_time[0] = None
- if audio is None:
- return gr.update(visible=False, value=""), ""
- transcribed = transcribe_audio(audio)
- return gr.update(visible=False, value=""), transcribed
-
-            # Show a static "Recording..." indicator while the mic is active (no live timer updates)
- mic_button.start_recording(
- fn=handle_recording_start,
- outputs=[recording_timer]
- )
-
- mic_button.stop_recording(
- fn=handle_recording_stop,
- inputs=[mic_button],
- outputs=[recording_timer, message_input]
- )
-
- # TTS component for generating speech from messages
- with gr.Row(visible=False) as tts_row:
- tts_text = gr.Textbox(visible=False)
- tts_audio = gr.Audio(label="Generated Speech", visible=False)
-
- # Function to generate speech when speaker icon is clicked
- def generate_speech_from_chat(history):
- """Extract last assistant message and generate speech"""
- if not history or len(history) == 0:
- return None
- last_msg = history[-1]
- if last_msg.get("role") == "assistant":
-                text = last_msg.get("content", "").replace(" 🔊", "").strip()
- if text:
- audio_path = generate_speech(text)
- return audio_path
- return None
-
- # Add TTS button that appears when assistant responds
-            tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
-
- # Update TTS button visibility and generate speech
- def update_tts_button(history):
- if history and len(history) > 0 and history[-1].get("role") == "assistant":
- return gr.update(visible=True)
- return gr.update(visible=False)
-
- chatbot.change(
- fn=update_tts_button,
- inputs=[chatbot],
- outputs=[tts_button]
- )
-
- tts_button.click(
- fn=generate_speech_from_chat,
- inputs=[chatbot],
- outputs=[tts_audio]
- )
-
- with gr.Accordion("âī¸ Advanced Settings", open=False):
- with gr.Row():
- disable_agentic_reasoning = gr.Checkbox(
- value=False,
- label="Disable agentic reasoning",
- info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
- )
- show_agentic_thought = gr.Button(
- "Show agentic thought",
- size="sm"
- )
- # Scrollable textbox for agentic thoughts (initially hidden)
- agentic_thoughts_box = gr.Textbox(
- label="Agentic Thoughts",
- placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
- lines=8,
- max_lines=15,
- interactive=False,
- visible=False,
- elem_classes="agentic-thoughts"
- )
- with gr.Row():
- use_rag = gr.Checkbox(
- value=False,
- label="Enable Document RAG",
- info="Answer based on uploaded documents (upload required)"
- )
- use_web_search = gr.Checkbox(
- value=False,
- label="Enable Web Search (MCP)",
- info="Fetch knowledge from online medical resources"
- )
-
- medical_model = gr.Radio(
- choices=list(MEDSWIN_MODELS.keys()),
- value=DEFAULT_MEDICAL_MODEL,
- label="Medical Model",
- info="MedSwin TA (default), others download on first use"
- )
-
- system_prompt = gr.Textbox(
- value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
- label="System Prompt",
- lines=3
- )
-
- with gr.Tab("Generation Parameters"):
- temperature = gr.Slider(
- minimum=0,
- maximum=1,
- step=0.1,
- value=0.2,
- label="Temperature"
- )
- max_new_tokens = gr.Slider(
- minimum=512,
- maximum=4096,
- step=128,
- value=2048,
- label="Max New Tokens",
- info="Increased for medical models to prevent early stopping"
- )
- top_p = gr.Slider(
- minimum=0.0,
- maximum=1.0,
- step=0.1,
- value=0.7,
- label="Top P"
- )
- top_k = gr.Slider(
- minimum=1,
- maximum=100,
- step=1,
- value=50,
- label="Top K"
- )
- penalty = gr.Slider(
- minimum=0.0,
- maximum=2.0,
- step=0.1,
- value=1.2,
- label="Repetition Penalty"
- )
-
- with gr.Tab("Retrieval Parameters"):
- retriever_k = gr.Slider(
- minimum=5,
- maximum=30,
- step=1,
- value=15,
- label="Initial Retrieval Size (Top K)"
- )
- merge_threshold = gr.Slider(
- minimum=0.1,
- maximum=0.9,
- step=0.1,
- value=0.5,
- label="Merge Threshold (lower = more merging)"
- )
-
- # Toggle function for showing/hiding agentic thoughts
- show_thoughts_state = gr.State(value=False)
-
- def toggle_thoughts_box(current_state):
- """Toggle visibility of agentic thoughts box"""
- new_state = not current_state
- return gr.update(visible=new_state), new_state
-
- show_agentic_thought.click(
- fn=toggle_thoughts_box,
- inputs=[show_thoughts_state],
- outputs=[agentic_thoughts_box, show_thoughts_state]
- )
-
- submit_button.click(
- fn=stream_chat,
- inputs=[
- message_input,
- chatbot,
- system_prompt,
- temperature,
- max_new_tokens,
- top_p,
- top_k,
- penalty,
- retriever_k,
- merge_threshold,
- use_rag,
- medical_model,
- use_web_search,
- disable_agentic_reasoning,
- show_thoughts_state
- ],
- outputs=[chatbot, agentic_thoughts_box]
- )
-
- message_input.submit(
- fn=stream_chat,
- inputs=[
- message_input,
- chatbot,
- system_prompt,
- temperature,
- max_new_tokens,
- top_p,
- top_k,
- penalty,
- retriever_k,
- merge_threshold,
- use_rag,
- medical_model,
- use_web_search,
- disable_agentic_reasoning,
- show_thoughts_state
- ],
- outputs=[chatbot, agentic_thoughts_box]
- )
-
- return demo
+from logger import logger
+from config import DEFAULT_MEDICAL_MODEL
+import config
+from models import initialize_medical_model, initialize_tts_model
+from mcp import MCP_AVAILABLE
+from ui import create_demo
if __name__ == "__main__":
# Preload models on startup
@@ -2973,7 +17,7 @@ if __name__ == "__main__":
logger.info("Preloading TTS model...")
try:
initialize_tts_model()
- if global_tts_model is not None:
+ if config.global_tts_model is not None:
logger.info("TTS model preloaded successfully!")
else:
logger.warning("TTS model not available - will use MCP or disable voice generation")
@@ -2989,4 +33,4 @@ if __name__ == "__main__":
logger.info("Model preloading complete!")
demo = create_demo()
- demo.launch()
\ No newline at end of file
+ demo.launch()