Spaces: Running on Zero
Commit · 52b4ed7
Parent(s): b720259
Refactor app.py into modular files for better scalability
- app.py +0 -0
- config.py +148 -0
- indexing.py +236 -0
- logger.py +49 -0
- mcp.py +194 -0
- models.py +77 -0
- pipeline.py +438 -0
- reasoning.py +178 -0
- search.py +252 -0
- supervisor.py +717 -0
- ui.py +293 -0
- utils.py +109 -0
- voice.py +165 -0
app.py
CHANGED
The diff for this file is too large to render. See raw diff.
config.py
ADDED
@@ -0,0 +1,148 @@
"""Configuration constants and global storage"""
import os
import threading

# Model configurations
MEDSWIN_MODELS = {
    "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
    "MedSwin KD": "MedSwin/MedSwin-7B-KD",
    "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
}
DEFAULT_MEDICAL_MODEL = "MedSwin TA"
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
TTS_MODEL = "maya-research/maya1"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")

# Gemini MCP configuration
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite")

# MCP server configuration
script_dir = os.path.dirname(os.path.abspath(__file__))
agent_path = os.path.join(script_dir, "agent.py")
MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python")
MCP_SERVER_ARGS = os.environ.get("MCP_SERVER_ARGS", agent_path).split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path]
MCP_TOOLS_CACHE_TTL = int(os.environ.get("MCP_TOOLS_CACHE_TTL", "60"))

# Global model storage
global_medical_models = {}
global_medical_tokenizers = {}
global_file_info = {}
global_tts_model = None
global_embed_model = None

# MCP client storage
global_mcp_session = None
global_mcp_stdio_ctx = None
global_mcp_lock = threading.Lock()
global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}

# UI constants
TITLE = "<h1><center>🩺 MedLLM Agent - Medical RAG & Web Search System</center></h1>"
DESCRIPTION = """
<center>
<p><strong>Advanced Medical AI Assistant</strong> powered by MedSwin models</p>
<p>📄 <strong>Document RAG:</strong> Answer based on uploaded medical documents</p>
<p>🌐 <strong>Web Search:</strong> Fetch knowledge from reliable online medical resources</p>
<p>🌍 <strong>Multi-language:</strong> Automatic translation for non-English queries</p>
<p>Upload PDF or text files to get started!</p>
</center>
"""
CSS = """
.upload-section {
    max-width: 400px;
    margin: 0 auto;
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
.recording-timer {
    font-size: 12px;
    color: #666;
    text-align: center;
    margin-top: 5px;
}
.feature-badge {
    display: inline-block;
    padding: 3px 8px;
    margin: 2px;
    border-radius: 12px;
    font-size: 11px;
    font-weight: bold;
}
.badge-rag {
    background: #e3f2fd;
    color: #1976d2;
}
.badge-web {
    background: #f3e5f5;
    color: #7b1fa2;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 1;
        max-width: 300px;
    }
    .chatbot-container {
        flex: 2;
        margin-top: 0;
    }
}
"""
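The other modules share mutable state (loaded models, the MCP session, caches) through these module-level globals. A minimal illustration, not part of this commit, of why consumers write `import config` and use attribute access rather than `from config import ...` (it assumes HF_TOKEN is set, since config.py raises otherwise):

# Illustrative sketch only: rebinding through the module object is visible to
# every importer, whereas `from config import global_mcp_session` would copy
# the binding once at import time and never see later updates.
import config

def remember_session(session):
    # All other modules reading config.global_mcp_session now see this value.
    config.global_mcp_session = session

def current_session():
    return config.global_mcp_session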
indexing.py
ADDED
@@ -0,0 +1,236 @@
"""Document parsing and indexing functions"""
import os
import base64
import asyncio
import tempfile
import time
import gradio as gr
import spaces
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.core.storage.docstore import SimpleDocumentStore
from tqdm import tqdm
from logger import logger
from mcp import MCP_AVAILABLE, call_agent
import config
from models import get_llm_for_rag, get_or_create_embed_model

try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None


async def parse_document_gemini(file_path: str, file_extension: str) -> str:
    """Parse document using Gemini MCP"""
    if not MCP_AVAILABLE:
        return ""

    try:
        with open(file_path, 'rb') as f:
            file_content = base64.b64encode(f.read()).decode('utf-8')

        mime_type_map = {
            '.pdf': 'application/pdf',
            '.doc': 'application/msword',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.txt': 'text/plain',
            '.md': 'text/markdown',
            '.json': 'application/json',
            '.xml': 'application/xml',
            '.csv': 'text/csv'
        }
        mime_type = mime_type_map.get(file_extension, 'application/octet-stream')

        files = [{
            "content": file_content,
            "type": mime_type
        }]

        system_prompt = "Extract all text content from the document accurately."
        user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible."

        result = await call_agent(
            user_prompt=user_prompt,
            system_prompt=system_prompt,
            files=files,
            model=config.GEMINI_MODEL_LITE,
            temperature=0.2
        )

        return result.strip()
    except Exception as e:
        logger.error(f"Gemini document parsing error: {e}")
        return ""


def extract_text_from_document(file):
    """Extract text from document using Gemini MCP"""
    file_name = file.name
    file_extension = os.path.splitext(file_name)[1].lower()

    if file_extension == '.txt':
        text = file.read().decode('utf-8')
        return text, len(text.split()), None

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            file.seek(0)
            tmp_file.write(file.read())
            tmp_file_path = tmp_file.name

        if MCP_AVAILABLE:
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    if nest_asyncio:
                        text = nest_asyncio.run(parse_document_gemini(tmp_file_path, file_extension))
                    else:
                        logger.error("Error in nested async document parsing: nest_asyncio not available")
                        text = ""
                else:
                    text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))

                try:
                    os.unlink(tmp_file_path)
                except:
                    pass

                if text:
                    return text, len(text.split()), None
                else:
                    return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP")
            except Exception as e:
                logger.error(f"Gemini MCP document parsing error: {e}")
                try:
                    os.unlink(tmp_file_path)
                except:
                    pass
                return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}")
        else:
            try:
                os.unlink(tmp_file_path)
            except:
                pass
            return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.")
    except Exception as e:
        logger.error(f"Error processing document: {e}")
        return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}")


@spaces.GPU(max_duration=120)
def create_or_update_index(files, request: gr.Request):
    """Create or update RAG index from uploaded files"""
    if not files:
        return "Please provide files.", ""

    start_time = time.time()
    user_id = request.session_hash
    save_dir = f"./{user_id}_index"

    llm = get_llm_for_rag()
    embed_model = get_or_create_embed_model()
    Settings.llm = llm
    Settings.embed_model = embed_model
    file_stats = []
    new_documents = []

    for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(file.name)
        text, word_count, error = extract_text_from_document(file)
        if error:
            logger.error(f"Error processing file {file_basename}: {str(error)}")
            file_stats.append({
                "name": file_basename,
                "words": 0,
                "status": f"error: {str(error)}"
            })
            continue

        doc = LlamaDocument(
            text=text,
            metadata={
                "file_name": file_basename,
                "word_count": word_count,
                "source": "user_upload"
            }
        )
        new_documents.append(doc)

        file_stats.append({
            "name": file_basename,
            "words": word_count,
            "status": "processed"
        })

        config.global_file_info[file_basename] = {
            "word_count": word_count,
            "processed_at": time.time()
        }

    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
    logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
    new_nodes = node_parser.get_nodes_from_documents(new_documents)
    new_leaf_nodes = get_leaf_nodes(new_nodes)
    new_root_nodes = get_root_nodes(new_nodes)
    logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")

    if os.path.exists(save_dir):
        logger.info(f"Loading existing index from {save_dir}")
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = load_index_from_storage(storage_context, settings=Settings)
        docstore = storage_context.docstore

        docstore.add_documents(new_nodes)
        for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
            index.insert_nodes([node])

        total_docs = len(docstore.docs)
        logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
    else:
        logger.info("Creating new index")
        docstore = SimpleDocumentStore()
        storage_context = StorageContext.from_defaults(docstore=docstore)
        docstore.add_documents(new_nodes)

        index = VectorStoreIndex(
            new_leaf_nodes,
            storage_context=storage_context,
            settings=Settings
        )
        total_docs = len(new_documents)
        logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")

    index.storage_context.persist(persist_dir=save_dir)

    file_list_html = "<div class='file-list'>"
    for stat in file_stats:
        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
        file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
    file_list_html += "</div>"
    processing_time = time.time() - start_time
    stats_output = f"<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = f"<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
    return f"Successfully indexed {len(files)} files.", output_container
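The index persisted here is hierarchical (2048/512/128-token chunks with leaf nodes vectorized and all nodes in the docstore), which is what enables auto-merging retrieval at query time. A minimal retrieval sketch, not part of this commit; the persist path, query, similarity_top_k and simple_ratio_thresh values are illustrative, and pipeline.py performs the same steps with user-configurable parameters:

# Sketch: query an index previously persisted by create_or_update_index.
from llama_index.core import StorageContext, load_index_from_storage, Settings
from llama_index.core.retrievers import AutoMergingRetriever
from models import get_or_create_embed_model

Settings.embed_model = get_or_create_embed_model()
storage_context = StorageContext.from_defaults(persist_dir="./example_index")  # hypothetical path
index = load_index_from_storage(storage_context, settings=Settings)
retriever = AutoMergingRetriever(
    index.as_retriever(similarity_top_k=6),   # illustrative value
    storage_context=storage_context,
    simple_ratio_thresh=0.5,                  # illustrative value
    verbose=False,
)
nodes = retriever.retrieve("What are the first-line treatments?")  # hypothetical query
context_text = "\n\n".join(n.node.text for n in nodes)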
logger.py
ADDED
@@ -0,0 +1,49 @@
"""Logging configuration and custom handlers"""
import logging
import threading
from transformers import logging as hf_logging

# Set logging to INFO level for cleaner output
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Custom logger handler to capture agentic thoughts
class ThoughtCaptureHandler(logging.Handler):
    """Custom handler to capture internal thoughts from MedSwin and supervisor"""
    def __init__(self):
        super().__init__()
        self.thoughts = []
        self.lock = threading.Lock()

    def emit(self, record):
        """Capture log messages that contain agentic thoughts"""
        try:
            msg = self.format(record)
            # Only capture messages from GEMINI SUPERVISOR or MEDSWIN
            if "[GEMINI SUPERVISOR]" in msg or "[MEDSWIN]" in msg or "[MAC]" in msg:
                # Remove timestamp and logger name for cleaner display
                parts = msg.split(" - ", 3)
                if len(parts) >= 4:
                    clean_msg = parts[-1]
                else:
                    clean_msg = msg
                with self.lock:
                    self.thoughts.append(clean_msg)
        except Exception:
            pass

    def get_thoughts(self):
        """Get all captured thoughts as a formatted string"""
        with self.lock:
            return "\n".join(self.thoughts)

    def clear(self):
        """Clear captured thoughts"""
        with self.lock:
            self.thoughts = []

# Set MCP client logging to WARNING to reduce noise
mcp_client_logger = logging.getLogger("mcp.client")
mcp_client_logger.setLevel(logging.WARNING)
hf_logging.set_verbosity_error()
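A short usage sketch (mirroring how pipeline.py drives this handler): attach it for one request, read back the captured supervisor/MedSwin lines, then detach. The example messages are illustrative only:

import logging
from logger import logger, ThoughtCaptureHandler

handler = ThoughtCaptureHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)

logger.info("[GEMINI SUPERVISOR] example thought")  # captured by the tag filter
logger.info("unrelated message")                    # ignored by the tag filter

print(handler.get_thoughts())
logger.removeHandler(handler)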
mcp.py
ADDED
@@ -0,0 +1,194 @@
"""MCP session management and tool caching"""
import os
import time
import asyncio
from logger import logger
import config

# MCP imports
MCP_CLIENT_INFO = None
try:
    from mcp import ClientSession, StdioServerParameters
    from mcp import types as mcp_types
    from mcp.client.stdio import stdio_client
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass
    MCP_AVAILABLE = True
    MCP_CLIENT_INFO = mcp_types.Implementation(
        name="MedLLM-Agent",
        version=os.environ.get("SPACE_VERSION", "local"),
    )
except ImportError as e:
    logger.warning(f"MCP SDK not available: {e}")
    MCP_AVAILABLE = False
    MCP_CLIENT_INFO = None


async def get_mcp_session():
    """Get or create MCP client session with proper context management"""
    if not MCP_AVAILABLE:
        logger.warning("MCP not available - SDK not installed")
        return None

    if config.global_mcp_session is not None:
        return config.global_mcp_session

    try:
        mcp_env = os.environ.copy()
        if config.GEMINI_API_KEY:
            mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
        else:
            logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")

        if os.environ.get("GEMINI_MODEL"):
            mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
        if os.environ.get("GEMINI_TIMEOUT"):
            mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
        if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
            mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
        if os.environ.get("GEMINI_TEMPERATURE"):
            mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")

        logger.info("Creating MCP client session...")

        server_params = StdioServerParameters(
            command=config.MCP_SERVER_COMMAND,
            args=config.MCP_SERVER_ARGS,
            env=mcp_env
        )

        stdio_ctx = stdio_client(server_params)
        read, write = await stdio_ctx.__aenter__()

        session = ClientSession(
            read,
            write,
            client_info=MCP_CLIENT_INFO,
        )

        try:
            await session.__aenter__()
            init_result = await session.initialize()
            server_info = getattr(init_result, "serverInfo", None)
            server_name = getattr(server_info, "name", "unknown")
            server_version = getattr(server_info, "version", "unknown")
            logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})")
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
            try:
                await session.__aexit__(None, None, None)
            except Exception:
                pass
            try:
                await stdio_ctx.__aexit__(None, None, None)
            except Exception:
                pass
            return None

        config.global_mcp_session = session
        config.global_mcp_stdio_ctx = stdio_ctx
        logger.info("✅ MCP client session created successfully")
        return session
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
        config.global_mcp_session = None
        config.global_mcp_stdio_ctx = None
        return None


def invalidate_mcp_tools_cache():
    """Invalidate cached MCP tool metadata"""
    config.global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}


async def get_cached_mcp_tools(force_refresh: bool = False):
    """Return cached MCP tools list to avoid repeated list_tools calls"""
    if not MCP_AVAILABLE:
        return []

    now = time.time()
    if (
        not force_refresh
        and config.global_mcp_tools_cache["tools"]
        and now - config.global_mcp_tools_cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL
    ):
        return config.global_mcp_tools_cache["tools"]

    session = await get_mcp_session()
    if session is None:
        return []

    try:
        tools_resp = await session.list_tools()
        tools_list = list(getattr(tools_resp, "tools", []) or [])
        config.global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
        return tools_list
    except Exception as e:
        logger.error(f"Failed to refresh MCP tools: {e}")
        invalidate_mcp_tools_cache()
        return []


async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call Gemini MCP generate_content tool"""
    if not MCP_AVAILABLE:
        logger.warning("MCP not available for Gemini call")
        return ""

    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("Failed to get MCP session for Gemini call")
            return ""

        tools = await get_cached_mcp_tools()
        if not tools:
            tools = await get_cached_mcp_tools(force_refresh=True)
        if not tools:
            logger.error("Unable to obtain MCP tool catalog for Gemini calls")
            return ""

        generate_tool = None
        for tool in tools:
            if tool.name == "generate_content" or "generate_content" in tool.name.lower():
                generate_tool = tool
                logger.info(f"Found Gemini MCP tool: {tool.name}")
                break

        if not generate_tool:
            logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
            invalidate_mcp_tools_cache()
            return ""

        arguments = {
            "user_prompt": user_prompt
        }
        if system_prompt:
            arguments["system_prompt"] = system_prompt
        if files:
            arguments["files"] = files
        if model:
            arguments["model"] = model
        if temperature is not None:
            arguments["temperature"] = temperature

        result = await session.call_tool(generate_tool.name, arguments=arguments)

        if hasattr(result, 'content') and result.content:
            for item in result.content:
                if hasattr(item, 'text'):
                    response_text = item.text.strip()
                    return response_text
        logger.warning("⚠️ Gemini MCP returned empty or invalid result")
        return ""
    except Exception as e:
        logger.error(f"Gemini MCP call error: {e}")
        return ""
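call_agent is async; synchronous callers such as indexing.py bridge into it through the event loop. A minimal sketch, not part of this commit, for a plain-script context where no loop is already running (inside running callbacks indexing.py instead falls back to its nest_asyncio path):

# Sketch: drive the async MCP helper from synchronous code.
import asyncio
from mcp import MCP_AVAILABLE, call_agent

def ask_gemini(prompt: str) -> str:
    if not MCP_AVAILABLE:
        return ""
    # asyncio.run is sufficient when no event loop is running in this thread.
    return asyncio.run(call_agent(user_prompt=prompt, temperature=0.2))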
models.py
ADDED
@@ -0,0 +1,77 @@
"""Model initialization and management"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from logger import logger
import config

try:
    from TTS.api import TTS
    TTS_AVAILABLE = True
except ImportError:
    TTS_AVAILABLE = False
    TTS = None


def initialize_medical_model(model_name: str):
    """Initialize medical model (MedSwin) - download on demand"""
    if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
        logger.info(f"Initializing medical model: {model_name}...")
        model_path = config.MEDSWIN_MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            token=config.HF_TOKEN,
            torch_dtype=torch.float16
        )
        config.global_medical_models[model_name] = model
        config.global_medical_tokenizers[model_name] = tokenizer
        logger.info(f"Medical model {model_name} initialized successfully")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]


def initialize_tts_model():
    """Initialize TTS model for text-to-speech"""
    if not TTS_AVAILABLE:
        logger.warning("TTS library not installed. TTS features will be disabled.")
        return None
    if config.global_tts_model is None:
        try:
            logger.info("Initializing TTS model for voice generation...")
            config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
            logger.info("TTS model initialized successfully")
        except Exception as e:
            logger.warning(f"TTS model initialization failed: {e}")
            logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
            config.global_tts_model = None
    return config.global_tts_model


def get_or_create_embed_model():
    """Reuse embedding model to avoid reloading weights each request"""
    if config.global_embed_model is None:
        logger.info("Initializing shared embedding model for RAG retrieval...")
        config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model


def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Get LLM for RAG indexing (uses medical model)"""
    medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)

    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )
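A minimal generation sketch, not part of this commit, using the cached model and tokenizer returned above; the prompt and sampling values are illustrative only (the actual prompting for sub-tasks lives in supervisor.py's execute_medswin_task, not shown here):

# Sketch: one-off generation with the cached MedSwin model.
import torch
import config
from models import initialize_medical_model

model, tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
inputs = tokenizer("What are common symptoms of anemia?", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))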
pipeline.py
ADDED
@@ -0,0 +1,438 @@
"""Main chat pipeline - stream_chat function"""
import os
import json
import time
import logging
import concurrent.futures
import gradio as gr
import spaces
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core import Settings
from llama_index.core.retrievers import AutoMergingRetriever
from logger import logger, ThoughtCaptureHandler
from models import initialize_medical_model, get_or_create_embed_model
from utils import detect_language, translate_text, format_url_as_domain
from search import search_web, summarize_web_content
from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy
from supervisor import (
    gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
    gemini_supervisor_rag_brainstorm, execute_medswin_task,
    gemini_supervisor_synthesize, gemini_supervisor_challenge,
    gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity
)


@spaces.GPU(max_duration=120)
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    use_rag: bool,
    medical_model: str,
    use_web_search: bool,
    disable_agentic_reasoning: bool,
    show_thoughts: bool,
    request: gr.Request
):
    """Main chat pipeline implementing MAC architecture"""
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
        return

    thought_handler = None
    if show_thoughts:
        thought_handler = ThoughtCaptureHandler()
        thought_handler.setLevel(logging.INFO)
        thought_handler.clear()
        logger.addHandler(thought_handler)

    session_start = time.time()
    soft_timeout = 100
    hard_timeout = 118

    def elapsed():
        return time.time() - session_start

    user_id = request.session_hash
    index_dir = f"./{user_id}_index"
    has_rag_index = os.path.exists(index_dir)

    original_lang = detect_language(message)
    original_message = message
    needs_translation = original_lang != "en"

    pipeline_diagnostics = {
        "reasoning": None,
        "plan": None,
        "strategy_decisions": [],
        "stage_metrics": {},
        "search": {"strategies": [], "total_results": 0}
    }

    def record_stage(stage_name: str, start_time: float):
        pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)

    translation_stage_start = time.time()
    if needs_translation:
        logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
        message = translate_text(message, target_lang="en", source_lang=original_lang)
        logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
    record_stage("translation", translation_stage_start)

    final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
    final_use_web_search = use_web_search and not disable_agentic_reasoning

    plan = None
    if not disable_agentic_reasoning:
        reasoning_stage_start = time.time()
        reasoning = autonomous_reasoning(message, history)
        record_stage("autonomous_reasoning", reasoning_stage_start)
        pipeline_diagnostics["reasoning"] = reasoning
        plan = create_execution_plan(reasoning, message, has_rag_index)
        pipeline_diagnostics["plan"] = plan
        execution_strategy = autonomous_execution_strategy(
            reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
        )

        if final_use_rag and not reasoning.get("requires_rag", True):
            final_use_rag = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
        elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
            pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")

        if final_use_web_search and not reasoning.get("requires_web_search", False):
            final_use_web_search = False
            pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
        elif not final_use_web_search and reasoning.get("requires_web_search", False):
            if not use_web_search:
                pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
            else:
                pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
    else:
        pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")

    if disable_agentic_reasoning:
        logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
        breakdown = {
            "sub_topics": [
                {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
            ],
            "strategy": "Direct answer",
            "exploration_note": "Direct mode - no breakdown"
        }
    else:
        logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
        breakdown = gemini_supervisor_breakdown(message, final_use_rag, final_use_web_search, elapsed(), max_duration=120)
        logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")

    search_contexts = []
    web_urls = []
    if final_use_web_search:
        search_stage_start = time.time()
        logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
        search_strategies = gemini_supervisor_search_strategies(message, elapsed())

        all_search_results = []
        strategy_jobs = []
        for strategy in search_strategies.get("search_strategies", [])[:4]:
            search_query = strategy.get("strategy", message)
            target_sources = strategy.get("target_sources", 2)
            strategy_jobs.append({
                "query": search_query,
                "target_sources": target_sources,
                "meta": strategy
            })

        def execute_search(job):
            job_start = time.time()
            try:
                results = search_web(job["query"], max_results=job["target_sources"])
                duration = time.time() - job_start
                return results, duration, None
            except Exception as exc:
                return [], time.time() - job_start, exc

        def record_search_diag(job, duration, results_count, error=None):
            entry = {
                "query": job["query"],
                "target_sources": job["target_sources"],
                "duration": round(duration, 3),
                "results": results_count
            }
            if error:
                entry["error"] = str(error)
            pipeline_diagnostics["search"]["strategies"].append(entry)

        if strategy_jobs:
            max_workers = min(len(strategy_jobs), 4)
            if len(strategy_jobs) > 1:
                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
                    for future in concurrent.futures.as_completed(future_map):
                        job = future_map[future]
                        try:
                            results, duration, error = future.result()
                        except Exception as exc:
                            results, duration, error = [], 0.0, exc
                        record_search_diag(job, duration, len(results), error)
                        if not error and results:
                            all_search_results.extend(results)
                            web_urls.extend([r.get('url', '') for r in results if r.get('url')])
            else:
                job = strategy_jobs[0]
                results, duration, error = execute_search(job)
                record_search_diag(job, duration, len(results), error)
                if not error and results:
                    all_search_results.extend(results)
                    web_urls.extend([r.get('url', '') for r in results if r.get('url')])
        else:
            pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")

        pipeline_diagnostics["search"]["total_results"] = len(all_search_results)

        if all_search_results:
            logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
            search_summary = summarize_web_content(all_search_results, message)
            if search_summary:
                search_contexts.append(search_summary)
                logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
        record_stage("web_search", search_stage_start)

    rag_contexts = []
    if final_use_rag and has_rag_index:
        rag_stage_start = time.time()
        if elapsed() >= soft_timeout - 10:
            logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
            final_use_rag = False
        else:
            logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
            embed_model = get_or_create_embed_model()
            Settings.embed_model = embed_model
            storage_context = StorageContext.from_defaults(persist_dir=index_dir)
            index = load_index_from_storage(storage_context, settings=Settings)
            base_retriever = index.as_retriever(similarity_top_k=retriever_k)
            auto_merging_retriever = AutoMergingRetriever(
                base_retriever,
                storage_context=storage_context,
                simple_ratio_thresh=merge_threshold,
                verbose=False
            )
            merged_nodes = auto_merging_retriever.retrieve(message)
            retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
            logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")

            logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
            rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed())
            rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
            logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
        record_stage("rag_retrieval", rag_stage_start)

    medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)

    base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."

    combined_context = ""
    if rag_contexts:
        combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4])
    if search_contexts:
        if combined_context:
            combined_context += "\n\n"
        combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)

    logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
    medswin_answers = []

    updated_history = history + [
        {"role": "user", "content": original_message},
        {"role": "assistant", "content": ""}
    ]
    thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
    yield updated_history, thoughts_text

    medswin_stage_start = time.time()
    for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
        if elapsed() >= hard_timeout - 5:
            logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
            break

        task_instruction = sub_topic.get("instruction", "")
        topic_name = sub_topic.get("topic", f"Topic {idx}")
        priority = sub_topic.get("priority", "medium")

        logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")

        task_context = combined_context
        if len(rag_contexts) > 1 and idx <= len(rag_contexts):
            task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context

        try:
            task_answer = execute_medswin_task(
                medical_model_obj=medical_model_obj,
                medical_tokenizer=medical_tokenizer,
                task_instruction=task_instruction,
                context=task_context if task_context else "",
                system_prompt_base=base_system_prompt,
                temperature=temperature,
                max_new_tokens=min(max_new_tokens, 800),
                top_p=top_p,
                top_k=top_k,
                penalty=penalty
            )

            formatted_answer = f"## {topic_name}\n\n{task_answer}"
            medswin_answers.append(formatted_answer)
            logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")

            partial_final = "\n\n".join(medswin_answers)
            updated_history[-1]["content"] = partial_final
            thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
            yield updated_history, thoughts_text

        except Exception as e:
            logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
            continue
    record_stage("medswin_tasks", medswin_stage_start)

    logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
    raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]
    synthesis_stage_start = time.time()
    final_answer = gemini_supervisor_synthesize(message, raw_medswin_answers, rag_contexts, search_contexts, breakdown)
    record_stage("synthesis", synthesis_stage_start)

    if not final_answer or len(final_answer.strip()) < 50:
        logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
        final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."

    if "|" in final_answer and "---" in final_answer:
        logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
        lines = final_answer.split('\n')
        cleaned_lines = []
        for line in lines:
            if '|' in line and '---' not in line:
                cells = [cell.strip() for cell in line.split('|') if cell.strip()]
                if cells:
                    cleaned_lines.append(f"- {' / '.join(cells)}")
            elif '---' not in line:
                cleaned_lines.append(line)
        final_answer = '\n'.join(cleaned_lines)

    max_challenge_iterations = 2
    challenge_iteration = 0
    challenge_stage_start = time.time()

    while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
        challenge_iteration += 1
        logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")

        evaluation = gemini_supervisor_challenge(message, final_answer, raw_medswin_answers, rag_contexts, search_contexts)

        if evaluation.get("is_optimal", False):
            logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
            break

        enhancement_instructions = evaluation.get("enhancement_instructions", "")
        if not enhancement_instructions:
            logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
            break

        logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
        enhanced_answer = gemini_supervisor_enhance_answer(
            message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
        )

        if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:
            final_answer = enhanced_answer
            logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
        else:
            logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
            break
    record_stage("challenge_loop", challenge_stage_start)

    if final_use_web_search and elapsed() < soft_timeout - 10:
        logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
        clarity_stage_start = time.time()
        clarity_check = gemini_supervisor_check_clarity(message, final_answer, final_use_web_search)
        record_stage("clarity_check", clarity_stage_start)

        if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
            logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
            additional_search_results = []
            followup_stage_start = time.time()
            for search_query in clarity_check.get("search_queries", [])[:3]:
                if elapsed() >= soft_timeout - 5:
                    break
                extra_start = time.time()
                results = search_web(search_query, max_results=2)
                extra_duration = time.time() - extra_start
                pipeline_diagnostics["search"]["strategies"].append({
                    "query": search_query,
                    "target_sources": 2,
                    "duration": round(extra_duration, 3),
                    "results": len(results),
                    "type": "followup"
                })
                additional_search_results.extend(results)
                web_urls.extend([r.get('url', '') for r in results if r.get('url')])

            if additional_search_results:
                pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
                logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
                additional_summary = summarize_web_content(additional_search_results, message)
                if additional_summary:
                    search_contexts.append(additional_summary)
                    logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
                    enhanced_with_search = gemini_supervisor_enhance_answer(
                        message, final_answer,
                        f"Incorporate the following additional information from web search: {additional_summary}",
                        raw_medswin_answers, rag_contexts, search_contexts
                    )
                    if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
                        final_answer = enhanced_with_search
                        logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
            record_stage("followup_search", followup_stage_start)

    citations_text = ""

    if needs_translation and final_answer:
        logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
        final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")

    if web_urls:
        unique_urls = list(dict.fromkeys(web_urls))
        citation_links = []
        for url in unique_urls[:5]:
            domain = format_url_as_domain(url)
            if domain:
                citation_links.append(f"[{domain}]({url})")

        if citation_links:
            citations_text = "\n\n**Sources:** " + ", ".join(citation_links)

    speaker_icon = ' 🔊'
    final_answer_with_metadata = final_answer + citations_text + speaker_icon

    updated_history[-1]["content"] = final_answer_with_metadata
    thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
    yield updated_history, thoughts_text

    if thought_handler:
        logger.removeHandler(thought_handler)

    diag_summary = {
        "stage_metrics": pipeline_diagnostics["stage_metrics"],
        "decisions": pipeline_diagnostics["strategy_decisions"],
        "search": pipeline_diagnostics["search"],
    }
    try:
        logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
    except Exception:
        logger.info(f"[MAC] Diagnostics summary (non-serializable)")
    logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
reasoning.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Autonomous reasoning and execution planning"""
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
from logger import logger
|
| 5 |
+
from mcp import MCP_AVAILABLE, call_agent
|
| 6 |
+
from config import GEMINI_MODEL
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import nest_asyncio
|
| 10 |
+
except ImportError:
|
| 11 |
+
nest_asyncio = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
async def autonomous_reasoning_gemini(query: str) -> dict:
|
| 15 |
+
"""Autonomous reasoning using Gemini MCP"""
|
| 16 |
+
reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
|
| 17 |
+
Query: "{query}"
|
| 18 |
+
Analyze:
|
| 19 |
+
1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
|
| 20 |
+
2. Complexity: (simple, moderate, complex, multi_faceted)
|
| 21 |
+
3. Information Needs: What specific information is required?
|
| 22 |
+
4. Requires RAG: (yes/no) - Does this need document context?
|
| 23 |
+
5. Requires Web Search: (yes/no) - Does this need current/updated information?
|
| 24 |
+
6. Sub-questions: Break down into key sub-questions if complex
|
| 25 |
+
Respond in JSON format:
|
| 26 |
+
{{
|
| 27 |
+
"query_type": "...",
|
| 28 |
+
"complexity": "...",
|
| 29 |
+
"information_needs": ["..."],
|
| 30 |
+
"requires_rag": true/false,
|
| 31 |
+
"requires_web_search": true/false,
|
| 32 |
+
"sub_questions": ["..."]
|
| 33 |
+
}}"""
|
| 34 |
+
|
| 35 |
+
system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses."
|
| 36 |
+
|
| 37 |
+
response = await call_agent(
|
| 38 |
+
user_prompt=reasoning_prompt,
|
| 39 |
+
system_prompt=system_prompt,
|
| 40 |
+
model=GEMINI_MODEL,
|
| 41 |
+
temperature=0.3
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
json_start = response.find('{')
|
| 46 |
+
json_end = response.rfind('}') + 1
|
| 47 |
+
if json_start >= 0 and json_end > json_start:
|
| 48 |
+
reasoning = json.loads(response[json_start:json_end])
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError("No JSON found")
|
| 51 |
+
except:
|
| 52 |
+
reasoning = {
|
| 53 |
+
"query_type": "general_info",
|
| 54 |
+
"complexity": "moderate",
|
| 55 |
+
"information_needs": ["medical information"],
|
| 56 |
+
"requires_rag": True,
|
| 57 |
+
"requires_web_search": False,
|
| 58 |
+
"sub_questions": [query]
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
logger.info(f"Reasoning analysis: {reasoning}")
|
| 62 |
+
return reasoning
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def autonomous_reasoning(query: str, history: list) -> dict:
|
| 66 |
+
"""Autonomous reasoning: Analyze query complexity, intent, and information needs"""
|
| 67 |
+
if not MCP_AVAILABLE:
|
| 68 |
+
logger.warning("⚠️ Gemini MCP not available for reasoning, using fallback")
|
| 69 |
+
return {
|
| 70 |
+
"query_type": "general_info",
|
| 71 |
+
"complexity": "moderate",
|
| 72 |
+
"information_needs": ["medical information"],
|
| 73 |
+
"requires_rag": True,
|
| 74 |
+
"requires_web_search": False,
|
| 75 |
+
"sub_questions": [query]
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
loop = asyncio.get_event_loop()
|
| 80 |
+
if loop.is_running():
|
| 81 |
+
if nest_asyncio:
|
| 82 |
+
reasoning = nest_asyncio.run(autonomous_reasoning_gemini(query))
|
| 83 |
+
return reasoning
|
| 84 |
+
else:
|
| 85 |
+
logger.error("Error in nested async reasoning: nest_asyncio not available")
|
| 86 |
+
else:
|
| 87 |
+
reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
|
| 88 |
+
return reasoning
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"Gemini MCP reasoning error: {e}")
|
| 91 |
+
|
| 92 |
+
logger.warning("⚠️ Falling back to default reasoning")
|
| 93 |
+
return {
|
| 94 |
+
"query_type": "general_info",
|
| 95 |
+
"complexity": "moderate",
|
| 96 |
+
"information_needs": ["medical information"],
|
| 97 |
+
"requires_rag": True,
|
| 98 |
+
"requires_web_search": False,
|
| 99 |
+
"sub_questions": [query]
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict:
|
| 104 |
+
"""Planning: Create multi-step execution plan based on reasoning analysis"""
|
| 105 |
+
plan = {
|
| 106 |
+
"steps": [],
|
| 107 |
+
"strategy": "sequential",
|
| 108 |
+
"iterations": 1
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
if reasoning["complexity"] in ["complex", "multi_faceted"]:
|
| 112 |
+
plan["strategy"] = "iterative"
|
| 113 |
+
plan["iterations"] = 2
|
| 114 |
+
|
| 115 |
+
plan["steps"].append({
|
| 116 |
+
"step": 1,
|
| 117 |
+
"action": "detect_language",
|
| 118 |
+
"description": "Detect query language and translate if needed"
|
| 119 |
+
})
|
| 120 |
+
|
| 121 |
+
if reasoning.get("requires_rag", True) and has_rag_index:
|
| 122 |
+
plan["steps"].append({
|
| 123 |
+
"step": 2,
|
| 124 |
+
"action": "rag_retrieval",
|
| 125 |
+
"description": "Retrieve relevant document context",
|
| 126 |
+
"parameters": {"top_k": 15, "merge_threshold": 0.5}
|
| 127 |
+
})
|
| 128 |
+
|
| 129 |
+
if reasoning.get("requires_web_search", False):
|
| 130 |
+
plan["steps"].append({
|
| 131 |
+
"step": 3,
|
| 132 |
+
"action": "web_search",
|
| 133 |
+
"description": "Search web for current/updated information",
|
| 134 |
+
"parameters": {"max_results": 5}
|
| 135 |
+
})
|
| 136 |
+
|
| 137 |
+
if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1:
|
| 138 |
+
plan["steps"].append({
|
| 139 |
+
"step": 4,
|
| 140 |
+
"action": "multi_step_reasoning",
|
| 141 |
+
"description": "Process sub-questions iteratively",
|
| 142 |
+
"sub_questions": reasoning["sub_questions"]
|
| 143 |
+
})
|
| 144 |
+
|
| 145 |
+
plan["steps"].append({
|
| 146 |
+
"step": len(plan["steps"]) + 1,
|
| 147 |
+
"action": "synthesize_answer",
|
| 148 |
+
"description": "Generate comprehensive answer from all sources"
|
| 149 |
+
})
|
| 150 |
+
|
| 151 |
+
if reasoning["complexity"] in ["complex", "multi_faceted"]:
|
| 152 |
+
plan["steps"].append({
|
| 153 |
+
"step": len(plan["steps"]) + 1,
|
| 154 |
+
"action": "self_reflection",
|
| 155 |
+
"description": "Evaluate answer quality and completeness"
|
| 156 |
+
})
|
| 157 |
+
|
| 158 |
+
logger.info(f"Execution plan created: {len(plan['steps'])} steps")
|
| 159 |
+
return plan
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict:
|
| 163 |
+
"""Autonomous execution: Make decisions on information gathering strategy"""
|
| 164 |
+
strategy = {
|
| 165 |
+
"use_rag": use_rag,
|
| 166 |
+
"use_web_search": use_web_search,
|
| 167 |
+
"reasoning_override": False,
|
| 168 |
+
"rationale": ""
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
if reasoning.get("requires_web_search", False) and not use_web_search:
|
| 172 |
+
strategy["rationale"] = "Reasoning suggests web search for current information, but the user kept it disabled."
|
| 173 |
+
|
| 174 |
+
if strategy["rationale"]:
|
| 175 |
+
logger.info(f"Autonomous reasoning note: {strategy['rationale']}")
|
| 176 |
+
|
| 177 |
+
return strategy
|
| 178 |
+
|
search.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Web search functions"""
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import concurrent.futures
|
| 5 |
+
from logger import logger
|
| 6 |
+
from mcp import MCP_AVAILABLE, get_mcp_session, get_cached_mcp_tools, call_agent
|
| 7 |
+
from config import GEMINI_MODEL
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import nest_asyncio
|
| 11 |
+
except ImportError:
|
| 12 |
+
nest_asyncio = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
|
| 16 |
+
"""Search web using MCP web search tool (e.g., DuckDuckGo MCP server)"""
|
| 17 |
+
if not MCP_AVAILABLE:
|
| 18 |
+
return []
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
tools = await get_cached_mcp_tools()
|
| 22 |
+
if not tools:
|
| 23 |
+
return []
|
| 24 |
+
|
| 25 |
+
search_tool = None
|
| 26 |
+
for tool in tools:
|
| 27 |
+
tool_name_lower = tool.name.lower()
|
| 28 |
+
if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
|
| 29 |
+
search_tool = tool
|
| 30 |
+
logger.info(f"Found web search MCP tool: {tool.name}")
|
| 31 |
+
break
|
| 32 |
+
|
| 33 |
+
if not search_tool:
|
| 34 |
+
tools = await get_cached_mcp_tools(force_refresh=True)
|
| 35 |
+
for tool in tools:
|
| 36 |
+
tool_name_lower = tool.name.lower()
|
| 37 |
+
if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
|
| 38 |
+
search_tool = tool
|
| 39 |
+
logger.info(f"Found web search MCP tool after refresh: {tool.name}")
|
| 40 |
+
break
|
| 41 |
+
|
| 42 |
+
if search_tool:
|
| 43 |
+
try:
|
| 44 |
+
session = await get_mcp_session()
|
| 45 |
+
if session is None:
|
| 46 |
+
return []
|
| 47 |
+
|
| 48 |
+
result = await session.call_tool(
|
| 49 |
+
search_tool.name,
|
| 50 |
+
arguments={"query": query, "max_results": max_results}
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
web_content = []
|
| 54 |
+
if hasattr(result, 'content') and result.content:
|
| 55 |
+
for item in result.content:
|
| 56 |
+
if hasattr(item, 'text'):
|
| 57 |
+
try:
|
| 58 |
+
data = json.loads(item.text)
|
| 59 |
+
if isinstance(data, list):
|
| 60 |
+
for entry in data[:max_results]:
|
| 61 |
+
web_content.append({
|
| 62 |
+
'title': entry.get('title', ''),
|
| 63 |
+
'url': entry.get('url', entry.get('href', '')),
|
| 64 |
+
'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
|
| 65 |
+
})
|
| 66 |
+
elif isinstance(data, dict):
|
| 67 |
+
if 'results' in data:
|
| 68 |
+
for entry in data['results'][:max_results]:
|
| 69 |
+
web_content.append({
|
| 70 |
+
'title': entry.get('title', ''),
|
| 71 |
+
'url': entry.get('url', entry.get('href', '')),
|
| 72 |
+
'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
|
| 73 |
+
})
|
| 74 |
+
else:
|
| 75 |
+
web_content.append({
|
| 76 |
+
'title': data.get('title', ''),
|
| 77 |
+
'url': data.get('url', data.get('href', '')),
|
| 78 |
+
'content': data.get('body', data.get('snippet', data.get('content', '')))
|
| 79 |
+
})
|
| 80 |
+
except json.JSONDecodeError:
|
| 81 |
+
web_content.append({
|
| 82 |
+
'title': '',
|
| 83 |
+
'url': '',
|
| 84 |
+
'content': item.text[:1000]
|
| 85 |
+
})
|
| 86 |
+
|
| 87 |
+
if web_content:
|
| 88 |
+
return web_content
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"Error calling web search MCP tool: {e}")
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
logger.debug("No MCP web search tool discovered in current catalog")
|
| 94 |
+
return []
|
| 95 |
+
except Exception as e:
|
| 96 |
+
logger.error(f"Web search MCP tool error: {e}")
|
| 97 |
+
return []
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
async def search_web_mcp(query: str, max_results: int = 5) -> list:
|
| 101 |
+
"""Search web using MCP tools - tries web search MCP tool first, then falls back to direct search"""
|
| 102 |
+
results = await search_web_mcp_tool(query, max_results)
|
| 103 |
+
if results:
|
| 104 |
+
logger.info(f"✅ Web search via MCP tool: found {len(results)} results")
|
| 105 |
+
return results
|
| 106 |
+
|
| 107 |
+
logger.info("ℹ️ [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)")
|
| 108 |
+
return search_web_fallback(query, max_results)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def search_web_fallback(query: str, max_results: int = 5) -> list:
|
| 112 |
+
"""Fallback web search using DuckDuckGo directly (when MCP is not available)"""
|
| 113 |
+
logger.info(f"🔍 [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...")
|
| 114 |
+
try:
|
| 115 |
+
from ddgs import DDGS
|
| 116 |
+
import requests
|
| 117 |
+
from bs4 import BeautifulSoup
|
| 118 |
+
except ImportError:
|
| 119 |
+
logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available")
|
| 120 |
+
return []
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
with DDGS() as ddgs:
|
| 124 |
+
results = list(ddgs.text(query, max_results=max_results))
|
| 125 |
+
web_content = []
|
| 126 |
+
for result in results:
|
| 127 |
+
try:
|
| 128 |
+
url = result.get('href', '')
|
| 129 |
+
title = result.get('title', '')
|
| 130 |
+
snippet = result.get('body', '')
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'})
|
| 134 |
+
if response.status_code == 200:
|
| 135 |
+
soup = BeautifulSoup(response.content, 'html.parser')
|
| 136 |
+
for script in soup(["script", "style"]):
|
| 137 |
+
script.decompose()
|
| 138 |
+
text = soup.get_text()
|
| 139 |
+
lines = (line.strip() for line in text.splitlines())
|
| 140 |
+
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
| 141 |
+
text = ' '.join(chunk for chunk in chunks if chunk)
|
| 142 |
+
if len(text) > 1000:
|
| 143 |
+
text = text[:1000] + "..."
|
| 144 |
+
web_content.append({
|
| 145 |
+
'title': title,
|
| 146 |
+
'url': url,
|
| 147 |
+
'content': snippet + "\n" + text[:500] if text else snippet
|
| 148 |
+
})
|
| 149 |
+
else:
|
| 150 |
+
web_content.append({
|
| 151 |
+
'title': title,
|
| 152 |
+
'url': url,
|
| 153 |
+
'content': snippet
|
| 154 |
+
})
|
| 155 |
+
except:
|
| 156 |
+
web_content.append({
|
| 157 |
+
'title': title,
|
| 158 |
+
'url': url,
|
| 159 |
+
'content': snippet
|
| 160 |
+
})
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.error(f"Error processing search result: {e}")
|
| 163 |
+
continue
|
| 164 |
+
logger.info(f"✅ [Direct API] Web search completed: {len(web_content)} results")
|
| 165 |
+
return web_content
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.error(f"❌ [Direct API] Web search error: {e}")
|
| 168 |
+
return []
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def search_web(query: str, max_results: int = 5) -> list:
|
| 172 |
+
"""Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs"""
|
| 173 |
+
if MCP_AVAILABLE:
|
| 174 |
+
try:
|
| 175 |
+
try:
|
| 176 |
+
loop = asyncio.get_event_loop()
|
| 177 |
+
except RuntimeError:
|
| 178 |
+
loop = asyncio.new_event_loop()
|
| 179 |
+
asyncio.set_event_loop(loop)
|
| 180 |
+
|
| 181 |
+
if loop.is_running():
|
| 182 |
+
if nest_asyncio:
|
| 183 |
+
results = nest_asyncio.run(search_web_mcp(query, max_results))
|
| 184 |
+
if results:
|
| 185 |
+
return results
|
| 186 |
+
else:
|
| 187 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 188 |
+
future = executor.submit(asyncio.run, search_web_mcp(query, max_results))
|
| 189 |
+
results = future.result(timeout=30)
|
| 190 |
+
if results:
|
| 191 |
+
return results
|
| 192 |
+
else:
|
| 193 |
+
results = loop.run_until_complete(search_web_mcp(query, max_results))
|
| 194 |
+
if results:
|
| 195 |
+
return results
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Error running async MCP search: {e}")
|
| 198 |
+
|
| 199 |
+
logger.info("ℹ️ [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)")
|
| 200 |
+
return search_web_fallback(query, max_results)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
async def summarize_web_content_gemini(content_list: list, query: str) -> str:
|
| 204 |
+
"""Summarize web search results using Gemini MCP"""
|
| 205 |
+
combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
|
| 206 |
+
|
| 207 |
+
user_prompt = f"""Summarize the following web search results related to the query: "{query}"
|
| 208 |
+
Extract key medical information, facts, and insights. Be concise and focus on reliable information.
|
| 209 |
+
Search Results:
|
| 210 |
+
{combined_content}
|
| 211 |
+
Summary:"""
|
| 212 |
+
|
| 213 |
+
system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately."
|
| 214 |
+
|
| 215 |
+
result = await call_agent(
|
| 216 |
+
user_prompt=user_prompt,
|
| 217 |
+
system_prompt=system_prompt,
|
| 218 |
+
model=GEMINI_MODEL,
|
| 219 |
+
temperature=0.5
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
return result.strip()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def summarize_web_content(content_list: list, query: str) -> str:
|
| 226 |
+
"""Summarize web search results using Gemini MCP"""
|
| 227 |
+
if not MCP_AVAILABLE:
|
| 228 |
+
logger.warning("Gemini MCP not available for summarization")
|
| 229 |
+
if content_list:
|
| 230 |
+
return content_list[0].get('content', '')[:500]
|
| 231 |
+
return ""
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
loop = asyncio.get_event_loop()
|
| 235 |
+
if loop.is_running():
|
| 236 |
+
if nest_asyncio:
|
| 237 |
+
summary = nest_asyncio.run(summarize_web_content_gemini(content_list, query))
|
| 238 |
+
if summary:
|
| 239 |
+
return summary
|
| 240 |
+
else:
|
| 241 |
+
logger.error("Error in nested async summarization: nest_asyncio not available")
|
| 242 |
+
else:
|
| 243 |
+
summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query))
|
| 244 |
+
if summary:
|
| 245 |
+
return summary
|
| 246 |
+
except Exception as e:
|
| 247 |
+
logger.error(f"Gemini MCP summarization error: {e}")
|
| 248 |
+
|
| 249 |
+
if content_list:
|
| 250 |
+
return content_list[0].get('content', '')[:500]
|
| 251 |
+
return ""
|
| 252 |
+
|
supervisor.py
ADDED
|
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gemini Supervisor functions for MAC architecture"""
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import torch
|
| 5 |
+
import spaces
|
| 6 |
+
from logger import logger
|
| 7 |
+
from mcp import MCP_AVAILABLE, call_agent
|
| 8 |
+
from config import GEMINI_MODEL, GEMINI_MODEL_LITE
|
| 9 |
+
from utils import format_prompt_manually
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import nest_asyncio
|
| 13 |
+
except ImportError:
|
| 14 |
+
nest_asyncio = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
async def gemini_supervisor_breakdown_async(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
|
| 18 |
+
"""Gemini Supervisor: Break user query into sub-topics"""
|
| 19 |
+
remaining_time = max(15, max_duration - time_elapsed)
|
| 20 |
+
|
| 21 |
+
mode_description = []
|
| 22 |
+
if use_rag:
|
| 23 |
+
mode_description.append("RAG mode enabled - will use retrieved documents")
|
| 24 |
+
if use_web_search:
|
| 25 |
+
mode_description.append("Web search mode enabled - will search online sources")
|
| 26 |
+
if not mode_description:
|
| 27 |
+
mode_description.append("Direct answer mode - no additional context")
|
| 28 |
+
|
| 29 |
+
estimated_time_per_task = 8
|
| 30 |
+
max_topics_by_time = max(2, int((remaining_time - 20) / estimated_time_per_task))
|
| 31 |
+
max_topics = min(max_topics_by_time, 10)
|
| 32 |
+
|
| 33 |
+
prompt = f"""You are a supervisor agent coordinating with a MedSwin medical specialist model.
|
| 34 |
+
Break the following medical query into focused sub-topics that MedSwin can answer sequentially.
|
| 35 |
+
Explore different potential approaches to comprehensively address the topic.
|
| 36 |
+
|
| 37 |
+
Query: "{query}"
|
| 38 |
+
Mode: {', '.join(mode_description)}
|
| 39 |
+
Time Remaining: ~{remaining_time:.1f}s
|
| 40 |
+
Maximum Topics: {max_topics} (adjust based on complexity - use as many as needed for thorough coverage)
|
| 41 |
+
|
| 42 |
+
Return ONLY valid JSON (no markdown, no tables, no explanations):
|
| 43 |
+
{{
|
| 44 |
+
"sub_topics": [
|
| 45 |
+
{{
|
| 46 |
+
"id": 1,
|
| 47 |
+
"topic": "concise topic name",
|
| 48 |
+
"instruction": "specific directive for MedSwin to answer this topic",
|
| 49 |
+
"expected_tokens": 200,
|
| 50 |
+
"priority": "high|medium|low",
|
| 51 |
+
"approach": "brief description of approach/angle for this topic"
|
| 52 |
+
}},
|
| 53 |
+
...
|
| 54 |
+
],
|
| 55 |
+
"strategy": "brief strategy description explaining the breakdown approach",
|
| 56 |
+
"exploration_note": "brief note on different approaches explored"
|
| 57 |
+
}}
|
| 58 |
+
|
| 59 |
+
Guidelines:
|
| 60 |
+
- Break down the query into as many subtasks as needed for comprehensive coverage
|
| 61 |
+
- Explore different angles/approaches (e.g., clinical, diagnostic, treatment, prevention, research perspectives)
|
| 62 |
+
- Each topic should be focused and answerable in ~200 tokens by MedSwin
|
| 63 |
+
- Prioritize topics by importance (high priority first)
|
| 64 |
+
- Don't limit yourself to 4 topics - use more if the query is complex or multi-faceted"""
|
| 65 |
+
|
| 66 |
+
system_prompt = "You are a medical query supervisor. Break queries into structured JSON sub-topics, exploring different approaches. Return ONLY valid JSON."
|
| 67 |
+
|
| 68 |
+
response = await call_agent(
|
| 69 |
+
user_prompt=prompt,
|
| 70 |
+
system_prompt=system_prompt,
|
| 71 |
+
model=GEMINI_MODEL,
|
| 72 |
+
temperature=0.3
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
json_start = response.find('{')
|
| 77 |
+
json_end = response.rfind('}') + 1
|
| 78 |
+
if json_start >= 0 and json_end > json_start:
|
| 79 |
+
breakdown = json.loads(response[json_start:json_end])
|
| 80 |
+
logger.info(f"[GEMINI SUPERVISOR] Query broken into {len(breakdown.get('sub_topics', []))} sub-topics")
|
| 81 |
+
return breakdown
|
| 82 |
+
else:
|
| 83 |
+
raise ValueError("Supervisor JSON not found")
|
| 84 |
+
except Exception as exc:
|
| 85 |
+
logger.error(f"[GEMINI SUPERVISOR] Breakdown parsing failed: {exc}")
|
| 86 |
+
breakdown = {
|
| 87 |
+
"sub_topics": [
|
| 88 |
+
{"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
|
| 89 |
+
{"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
|
| 90 |
+
],
|
| 91 |
+
"strategy": "Sequential answer with key points",
|
| 92 |
+
"exploration_note": "Fallback breakdown - basic coverage"
|
| 93 |
+
}
|
| 94 |
+
logger.warning(f"[GEMINI SUPERVISOR] Using fallback breakdown")
|
| 95 |
+
return breakdown
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
async def gemini_supervisor_search_strategies_async(query: str, time_elapsed: float) -> dict:
|
| 99 |
+
"""Gemini Supervisor: In search mode, break query into 1-4 searching strategies"""
|
| 100 |
+
prompt = f"""You are supervising web search for a medical query.
|
| 101 |
+
Break this query into 1-4 focused search strategies (each targeting 1-2 sources).
|
| 102 |
+
|
| 103 |
+
Query: "{query}"
|
| 104 |
+
|
| 105 |
+
Return ONLY valid JSON:
|
| 106 |
+
{{
|
| 107 |
+
"search_strategies": [
|
| 108 |
+
{{
|
| 109 |
+
"id": 1,
|
| 110 |
+
"strategy": "search query string",
|
| 111 |
+
"target_sources": 1,
|
| 112 |
+
"focus": "what to search for"
|
| 113 |
+
}},
|
| 114 |
+
...
|
| 115 |
+
],
|
| 116 |
+
"max_strategies": 4
|
| 117 |
+
}}
|
| 118 |
+
|
| 119 |
+
Keep strategies focused and avoid overlap."""
|
| 120 |
+
|
| 121 |
+
system_prompt = "You are a search strategy supervisor. Create focused search queries. Return ONLY valid JSON."
|
| 122 |
+
|
| 123 |
+
response = await call_agent(
|
| 124 |
+
user_prompt=prompt,
|
| 125 |
+
system_prompt=system_prompt,
|
| 126 |
+
model=GEMINI_MODEL_LITE,
|
| 127 |
+
temperature=0.2
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
json_start = response.find('{')
|
| 132 |
+
json_end = response.rfind('}') + 1
|
| 133 |
+
if json_start >= 0 and json_end > json_start:
|
| 134 |
+
strategies = json.loads(response[json_start:json_end])
|
| 135 |
+
logger.info(f"[GEMINI SUPERVISOR] Created {len(strategies.get('search_strategies', []))} search strategies")
|
| 136 |
+
return strategies
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError("Search strategies JSON not found")
|
| 139 |
+
except Exception as exc:
|
| 140 |
+
logger.error(f"[GEMINI SUPERVISOR] Search strategies parsing failed: {exc}")
|
| 141 |
+
return {
|
| 142 |
+
"search_strategies": [
|
| 143 |
+
{"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
|
| 144 |
+
],
|
| 145 |
+
"max_strategies": 1
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
|
| 150 |
+
"""Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts"""
|
| 151 |
+
max_doc_length = 3000
|
| 152 |
+
if len(retrieved_docs) > max_doc_length:
|
| 153 |
+
retrieved_docs = retrieved_docs[:max_doc_length] + "..."
|
| 154 |
+
|
| 155 |
+
prompt = f"""You are supervising RAG context preparation for a medical query.
|
| 156 |
+
Brainstorm the retrieved documents into 1-4 concise, focused contexts that MedSwin can use.
|
| 157 |
+
|
| 158 |
+
Query: "{query}"
|
| 159 |
+
Retrieved Documents:
|
| 160 |
+
{retrieved_docs}
|
| 161 |
+
|
| 162 |
+
Return ONLY valid JSON:
|
| 163 |
+
{{
|
| 164 |
+
"contexts": [
|
| 165 |
+
{{
|
| 166 |
+
"id": 1,
|
| 167 |
+
"context": "concise summary of relevant information (keep under 500 chars)",
|
| 168 |
+
"focus": "what this context covers",
|
| 169 |
+
"relevance": "high|medium|low"
|
| 170 |
+
}},
|
| 171 |
+
...
|
| 172 |
+
],
|
| 173 |
+
"max_contexts": 4
|
| 174 |
+
}}
|
| 175 |
+
|
| 176 |
+
Keep contexts brief and factual. Avoid redundancy."""
|
| 177 |
+
|
| 178 |
+
system_prompt = "You are a RAG context supervisor. Summarize documents into concise contexts. Return ONLY valid JSON."
|
| 179 |
+
|
| 180 |
+
response = await call_agent(
|
| 181 |
+
user_prompt=prompt,
|
| 182 |
+
system_prompt=system_prompt,
|
| 183 |
+
model=GEMINI_MODEL_LITE,
|
| 184 |
+
temperature=0.2
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
try:
|
| 188 |
+
json_start = response.find('{')
|
| 189 |
+
json_end = response.rfind('}') + 1
|
| 190 |
+
if json_start >= 0 and json_end > json_start:
|
| 191 |
+
contexts = json.loads(response[json_start:json_end])
|
| 192 |
+
logger.info(f"[GEMINI SUPERVISOR] Brainstormed {len(contexts.get('contexts', []))} RAG contexts")
|
| 193 |
+
return contexts
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError("RAG contexts JSON not found")
|
| 196 |
+
except Exception as exc:
|
| 197 |
+
logger.error(f"[GEMINI SUPERVISOR] RAG brainstorming parsing failed: {exc}")
|
| 198 |
+
return {
|
| 199 |
+
"contexts": [
|
| 200 |
+
{"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
|
| 201 |
+
],
|
| 202 |
+
"max_contexts": 1
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
|
| 207 |
+
"""Wrapper to obtain supervisor breakdown synchronously"""
|
| 208 |
+
if not MCP_AVAILABLE:
|
| 209 |
+
logger.warning("[GEMINI SUPERVISOR] MCP unavailable, using fallback breakdown")
|
| 210 |
+
return {
|
| 211 |
+
"sub_topics": [
|
| 212 |
+
{"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
|
| 213 |
+
{"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
|
| 214 |
+
],
|
| 215 |
+
"strategy": "Sequential answer with key points",
|
| 216 |
+
"exploration_note": "Fallback breakdown - basic coverage"
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
loop = asyncio.get_event_loop()
|
| 221 |
+
if loop.is_running():
|
| 222 |
+
if nest_asyncio:
|
| 223 |
+
return nest_asyncio.run(
|
| 224 |
+
gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
logger.error("[GEMINI SUPERVISOR] Nested breakdown execution failed: nest_asyncio not available")
|
| 228 |
+
return loop.run_until_complete(
|
| 229 |
+
gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
|
| 230 |
+
)
|
| 231 |
+
except Exception as exc:
|
| 232 |
+
logger.error(f"[GEMINI SUPERVISOR] Breakdown request failed: {exc}")
|
| 233 |
+
return {
|
| 234 |
+
"sub_topics": [
|
| 235 |
+
{"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
|
| 236 |
+
],
|
| 237 |
+
"strategy": "Direct answer",
|
| 238 |
+
"exploration_note": "Fallback breakdown - single topic"
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def gemini_supervisor_search_strategies(query: str, time_elapsed: float) -> dict:
|
| 243 |
+
"""Wrapper to obtain search strategies synchronously"""
|
| 244 |
+
if not MCP_AVAILABLE:
|
| 245 |
+
logger.warning("[GEMINI SUPERVISOR] MCP unavailable for search strategies")
|
| 246 |
+
return {
|
| 247 |
+
"search_strategies": [
|
| 248 |
+
{"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
|
| 249 |
+
],
|
| 250 |
+
"max_strategies": 1
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
try:
|
| 254 |
+
loop = asyncio.get_event_loop()
|
| 255 |
+
if loop.is_running():
|
| 256 |
+
if nest_asyncio:
|
| 257 |
+
return nest_asyncio.run(gemini_supervisor_search_strategies_async(query, time_elapsed))
|
| 258 |
+
else:
|
| 259 |
+
logger.error("[GEMINI SUPERVISOR] Nested search strategies execution failed: nest_asyncio not available")
|
| 260 |
+
return loop.run_until_complete(gemini_supervisor_search_strategies_async(query, time_elapsed))
|
| 261 |
+
except Exception as exc:
|
| 262 |
+
logger.error(f"[GEMINI SUPERVISOR] Search strategies request failed: {exc}")
|
| 263 |
+
return {
|
| 264 |
+
"search_strategies": [
|
| 265 |
+
{"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
|
| 266 |
+
],
|
| 267 |
+
"max_strategies": 1
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
|
| 272 |
+
"""Wrapper to obtain RAG brainstorm synchronously"""
|
| 273 |
+
if not MCP_AVAILABLE:
|
| 274 |
+
logger.warning("[GEMINI SUPERVISOR] MCP unavailable for RAG brainstorm")
|
| 275 |
+
return {
|
| 276 |
+
"contexts": [
|
| 277 |
+
{"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
|
| 278 |
+
],
|
| 279 |
+
"max_contexts": 1
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
loop = asyncio.get_event_loop()
|
| 284 |
+
if loop.is_running():
|
| 285 |
+
if nest_asyncio:
|
| 286 |
+
return nest_asyncio.run(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
|
| 287 |
+
else:
|
| 288 |
+
logger.error("[GEMINI SUPERVISOR] Nested RAG brainstorm execution failed: nest_asyncio not available")
|
| 289 |
+
return loop.run_until_complete(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
|
| 290 |
+
except Exception as exc:
|
| 291 |
+
logger.error(f"[GEMINI SUPERVISOR] RAG brainstorm request failed: {exc}")
|
| 292 |
+
return {
|
| 293 |
+
"contexts": [
|
| 294 |
+
{"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
|
| 295 |
+
],
|
| 296 |
+
"max_contexts": 1
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@spaces.GPU(max_duration=120)
|
| 301 |
+
def execute_medswin_task(
|
| 302 |
+
medical_model_obj,
|
| 303 |
+
medical_tokenizer,
|
| 304 |
+
task_instruction: str,
|
| 305 |
+
context: str,
|
| 306 |
+
system_prompt_base: str,
|
| 307 |
+
temperature: float,
|
| 308 |
+
max_new_tokens: int,
|
| 309 |
+
top_p: float,
|
| 310 |
+
top_k: int,
|
| 311 |
+
penalty: float
|
| 312 |
+
) -> str:
|
| 313 |
+
"""MedSwin Specialist: Execute a single task assigned by Gemini Supervisor"""
|
| 314 |
+
if context:
|
| 315 |
+
full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
|
| 316 |
+
else:
|
| 317 |
+
full_prompt = f"{system_prompt_base}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
|
| 318 |
+
|
| 319 |
+
messages = [{"role": "system", "content": full_prompt}]
|
| 320 |
+
|
| 321 |
+
if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
|
| 322 |
+
try:
|
| 323 |
+
prompt = medical_tokenizer.apply_chat_template(
|
| 324 |
+
messages,
|
| 325 |
+
tokenize=False,
|
| 326 |
+
add_generation_prompt=True
|
| 327 |
+
)
|
| 328 |
+
except Exception as e:
|
| 329 |
+
logger.warning(f"[MEDSWIN] Chat template failed, using manual formatting: {e}")
|
| 330 |
+
prompt = format_prompt_manually(messages, medical_tokenizer)
|
| 331 |
+
else:
|
| 332 |
+
prompt = format_prompt_manually(messages, medical_tokenizer)
|
| 333 |
+
|
| 334 |
+
inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
|
| 335 |
+
|
| 336 |
+
eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
|
| 337 |
+
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
outputs = medical_model_obj.generate(
|
| 340 |
+
**inputs,
|
| 341 |
+
max_new_tokens=min(max_new_tokens, 800),
|
| 342 |
+
temperature=temperature,
|
| 343 |
+
top_p=top_p,
|
| 344 |
+
top_k=top_k,
|
| 345 |
+
repetition_penalty=penalty,
|
| 346 |
+
do_sample=True,
|
| 347 |
+
eos_token_id=eos_token_id,
|
| 348 |
+
pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
response = medical_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
| 352 |
+
|
| 353 |
+
response = response.strip()
|
| 354 |
+
if "|" in response and "---" in response:
|
| 355 |
+
logger.warning("[MEDSWIN] Detected table format, converting to Markdown bullets")
|
| 356 |
+
lines = [line.strip() for line in response.split('\n') if line.strip() and not line.strip().startswith('|') and '---' not in line]
|
| 357 |
+
response = '\n'.join([f"- {line}" if not line.startswith('-') else line for line in lines])
|
| 358 |
+
|
| 359 |
+
logger.info(f"[MEDSWIN] Task completed: {len(response)} chars generated")
|
| 360 |
+
return response
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
async def gemini_supervisor_synthesize_async(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
|
| 364 |
+
"""Gemini Supervisor: Synthesize final answer from all MedSwin responses"""
|
| 365 |
+
context_summary = ""
|
| 366 |
+
if rag_contexts:
|
| 367 |
+
context_summary += f"Document Context Available: {len(rag_contexts)} context(s) from uploaded documents.\n"
|
| 368 |
+
if search_contexts:
|
| 369 |
+
context_summary += f"Web Search Context Available: {len(search_contexts)} search result(s).\n"
|
| 370 |
+
|
| 371 |
+
all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
|
| 372 |
+
|
| 373 |
+
prompt = f"""You are a supervisor agent synthesizing a comprehensive medical answer from multiple specialist responses.
|
| 374 |
+
|
| 375 |
+
Original Query: "{query}"
|
| 376 |
+
|
| 377 |
+
Context Available:
|
| 378 |
+
{context_summary}
|
| 379 |
+
|
| 380 |
+
MedSwin Specialist Responses (from {len(medswin_answers)} sub-topics):
|
| 381 |
+
{all_answers_text}
|
| 382 |
+
|
| 383 |
+
Your task:
|
| 384 |
+
1. Synthesize all responses into a coherent, comprehensive final answer
|
| 385 |
+
2. Integrate information from all sub-topics seamlessly
|
| 386 |
+
3. Ensure the answer directly addresses the original query
|
| 387 |
+
4. Maintain clinical accuracy and clarity
|
| 388 |
+
5. Use clear structure with appropriate headings and bullet points
|
| 389 |
+
6. Remove redundancy and contradictions
|
| 390 |
+
7. Ensure all important points from MedSwin responses are included
|
| 391 |
+
|
| 392 |
+
Return the final synthesized answer in Markdown format. Do not add meta-commentary or explanations - just provide the final answer."""
|
| 393 |
+
|
| 394 |
+
system_prompt = "You are a medical answer synthesis supervisor. Create comprehensive, well-structured final answers from multiple specialist responses."
|
| 395 |
+
|
| 396 |
+
result = await call_agent(
|
| 397 |
+
user_prompt=prompt,
|
| 398 |
+
system_prompt=system_prompt,
|
| 399 |
+
model=GEMINI_MODEL,
|
| 400 |
+
temperature=0.3
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
return result.strip()
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
async def gemini_supervisor_challenge_async(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
|
| 407 |
+
"""Gemini Supervisor: Challenge and evaluate the current answer"""
|
| 408 |
+
context_info = ""
|
| 409 |
+
if rag_contexts:
|
| 410 |
+
context_info += f"Document contexts: {len(rag_contexts)} available.\n"
|
| 411 |
+
if search_contexts:
|
| 412 |
+
context_info += f"Search contexts: {len(search_contexts)} available.\n"
|
| 413 |
+
|
| 414 |
+
all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
|
| 415 |
+
|
| 416 |
+
prompt = f"""You are a supervisor agent evaluating and challenging a medical answer for quality and completeness.
|
| 417 |
+
|
| 418 |
+
Original Query: "{query}"
|
| 419 |
+
|
| 420 |
+
Available Context:
|
| 421 |
+
{context_info}
|
| 422 |
+
|
| 423 |
+
MedSwin Specialist Responses:
|
| 424 |
+
{all_answers_text}
|
| 425 |
+
|
| 426 |
+
Current Synthesized Answer:
|
| 427 |
+
{current_answer[:2000]}
|
| 428 |
+
|
| 429 |
+
Evaluate this answer and provide:
|
| 430 |
+
1. Completeness: Does it fully address the query? What's missing?
|
| 431 |
+
2. Accuracy: Are there any inaccuracies or contradictions?
|
| 432 |
+
3. Clarity: Is it well-structured and clear?
|
| 433 |
+
4. Context Usage: Are document/search contexts properly utilized?
|
| 434 |
+
5. Improvement Suggestions: Specific ways to enhance the answer
|
| 435 |
+
|
| 436 |
+
Return ONLY valid JSON:
|
| 437 |
+
{{
|
| 438 |
+
"is_optimal": true/false,
|
| 439 |
+
"completeness_score": 0-10,
|
| 440 |
+
"accuracy_score": 0-10,
|
| 441 |
+
"clarity_score": 0-10,
|
| 442 |
+
"missing_aspects": ["..."],
|
| 443 |
+
"inaccuracies": ["..."],
|
| 444 |
+
"improvement_suggestions": ["..."],
|
| 445 |
+
"needs_more_context": true/false,
|
| 446 |
+
"enhancement_instructions": "specific instructions for improving the answer"
|
| 447 |
+
}}"""
|
| 448 |
+
|
| 449 |
+
system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback in JSON format. Return ONLY valid JSON."
|
| 450 |
+
|
| 451 |
+
response = await call_agent(
|
| 452 |
+
user_prompt=prompt,
|
| 453 |
+
system_prompt=system_prompt,
|
| 454 |
+
model=GEMINI_MODEL,
|
| 455 |
+
temperature=0.3
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
try:
|
| 459 |
+
json_start = response.find('{')
|
| 460 |
+
json_end = response.rfind('}') + 1
|
| 461 |
+
if json_start >= 0 and json_end > json_start:
|
| 462 |
+
evaluation = json.loads(response[json_start:json_end])
|
| 463 |
+
logger.info(f"[GEMINI SUPERVISOR] Challenge evaluation: optimal={evaluation.get('is_optimal', False)}, scores={evaluation.get('completeness_score', 'N/A')}/{evaluation.get('accuracy_score', 'N/A')}/{evaluation.get('clarity_score', 'N/A')}")
|
| 464 |
+
return evaluation
|
| 465 |
+
else:
|
| 466 |
+
raise ValueError("Evaluation JSON not found")
|
| 467 |
+
except Exception as exc:
|
| 468 |
+
logger.error(f"[GEMINI SUPERVISOR] Challenge evaluation parsing failed: {exc}")
|
| 469 |
+
return {
|
| 470 |
+
"is_optimal": True,
|
| 471 |
+
"completeness_score": 7,
|
| 472 |
+
"accuracy_score": 7,
|
| 473 |
+
"clarity_score": 7,
|
| 474 |
+
"missing_aspects": [],
|
| 475 |
+
"inaccuracies": [],
|
| 476 |
+
"improvement_suggestions": [],
|
| 477 |
+
"needs_more_context": False,
|
| 478 |
+
"enhancement_instructions": ""
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
async def gemini_supervisor_enhance_answer_async(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
|
| 483 |
+
"""Gemini Supervisor: Enhance the answer based on challenge feedback"""
|
| 484 |
+
context_info = ""
|
| 485 |
+
if rag_contexts:
|
| 486 |
+
context_info += f"Document contexts: {len(rag_contexts)} available.\n"
|
| 487 |
+
if search_contexts:
|
| 488 |
+
context_info += f"Search contexts: {len(search_contexts)} available.\n"
|
| 489 |
+
|
| 490 |
+
all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
|
| 491 |
+
|
| 492 |
+
prompt = f"""You are a supervisor agent enhancing a medical answer based on evaluation feedback.
|
| 493 |
+
|
| 494 |
+
Original Query: "{query}"
|
| 495 |
+
|
| 496 |
+
Available Context:
|
| 497 |
+
{context_info}
|
| 498 |
+
|
| 499 |
+
MedSwin Specialist Responses:
|
| 500 |
+
{all_answers_text}
|
| 501 |
+
|
| 502 |
+
Current Answer (to enhance):
|
| 503 |
+
{current_answer}
|
| 504 |
+
|
| 505 |
+
Enhancement Instructions:
|
| 506 |
+
{enhancement_instructions}
|
| 507 |
+
|
| 508 |
+
Create an enhanced version of the answer that:
|
| 509 |
+
1. Addresses all improvement suggestions
|
| 510 |
+
2. Fills in missing aspects
|
| 511 |
+
3. Corrects any inaccuracies
|
| 512 |
+
4. Improves clarity and structure
|
| 513 |
+
5. Better utilizes available context
|
| 514 |
+
6. Maintains all valuable information from the current answer
|
| 515 |
+
|
| 516 |
+
Return the enhanced answer in Markdown format. Do not add meta-commentary."""
|
| 517 |
+
|
| 518 |
+
system_prompt = "You are a medical answer enhancement supervisor. Improve answers based on evaluation feedback while maintaining accuracy."
|
| 519 |
+
|
| 520 |
+
result = await call_agent(
|
| 521 |
+
user_prompt=prompt,
|
| 522 |
+
system_prompt=system_prompt,
|
| 523 |
+
model=GEMINI_MODEL,
|
| 524 |
+
temperature=0.3
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
return result.strip()
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
async def gemini_supervisor_check_clarity_async(query: str, answer: str, use_web_search: bool) -> dict:
|
| 531 |
+
"""Gemini Supervisor: Check if answer is unclear or supervisor is unsure"""
|
| 532 |
+
if not use_web_search:
|
| 533 |
+
return {"is_unclear": False, "needs_search": False, "search_queries": []}
|
| 534 |
+
|
| 535 |
+
prompt = f"""You are a supervisor agent evaluating answer clarity and completeness.
|
| 536 |
+
|
| 537 |
+
Query: "{query}"
|
| 538 |
+
|
| 539 |
+
Current Answer:
|
| 540 |
+
{answer[:1500]}
|
| 541 |
+
|
| 542 |
+
Evaluate:
|
| 543 |
+
1. Is the answer unclear or incomplete?
|
| 544 |
+
2. Are there gaps that web search could fill?
|
| 545 |
+
3. Is the supervisor (you) unsure about certain aspects?
|
| 546 |
+
|
| 547 |
+
Return ONLY valid JSON:
|
| 548 |
+
{{
|
| 549 |
+
"is_unclear": true/false,
|
| 550 |
+
"needs_search": true/false,
|
| 551 |
+
"uncertainty_areas": ["..."],
|
| 552 |
+
"search_queries": ["specific search queries to fill gaps"],
|
| 553 |
+
"rationale": "brief explanation"
|
| 554 |
+
}}
|
| 555 |
+
|
| 556 |
+
Only suggest search if the answer is genuinely unclear or has significant gaps that search could address."""
|
| 557 |
+
|
| 558 |
+
system_prompt = "You are a clarity evaluator. Assess if additional web search is needed. Return ONLY valid JSON."
|
| 559 |
+
|
| 560 |
+
response = await call_agent(
|
| 561 |
+
user_prompt=prompt,
|
| 562 |
+
system_prompt=system_prompt,
|
| 563 |
+
model=GEMINI_MODEL_LITE,
|
| 564 |
+
temperature=0.2
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
try:
|
| 568 |
+
json_start = response.find('{')
|
| 569 |
+
json_end = response.rfind('}') + 1
|
| 570 |
+
if json_start >= 0 and json_end > json_start:
|
| 571 |
+
evaluation = json.loads(response[json_start:json_end])
|
| 572 |
+
logger.info(f"[GEMINI SUPERVISOR] Clarity check: unclear={evaluation.get('is_unclear', False)}, needs_search={evaluation.get('needs_search', False)}")
|
| 573 |
+
return evaluation
|
| 574 |
+
else:
|
| 575 |
+
raise ValueError("Clarity check JSON not found")
|
| 576 |
+
except Exception as exc:
|
| 577 |
+
logger.error(f"[GEMINI SUPERVISOR] Clarity check parsing failed: {exc}")
|
| 578 |
+
return {"is_unclear": False, "needs_search": False, "search_queries": []}
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def gemini_supervisor_synthesize(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
|
| 582 |
+
"""Wrapper to synthesize answer synchronously"""
|
| 583 |
+
if not MCP_AVAILABLE:
|
| 584 |
+
logger.warning("[GEMINI SUPERVISOR] MCP unavailable for synthesis, using simple concatenation")
|
| 585 |
+
return "\n\n".join(medswin_answers)
|
| 586 |
+
|
| 587 |
+
try:
|
| 588 |
+
loop = asyncio.get_event_loop()
|
| 589 |
+
if loop.is_running():
|
| 590 |
+
if nest_asyncio:
|
| 591 |
+
return nest_asyncio.run(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
|
| 592 |
+
else:
|
| 593 |
+
logger.error("[GEMINI SUPERVISOR] Nested synthesis failed: nest_asyncio not available")
|
| 594 |
+
return loop.run_until_complete(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
|
| 595 |
+
except Exception as exc:
|
| 596 |
+
logger.error(f"[GEMINI SUPERVISOR] Synthesis failed: {exc}")
|
| 597 |
+
return "\n\n".join(medswin_answers)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def gemini_supervisor_challenge(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
|
| 601 |
+
"""Wrapper to challenge answer synchronously"""
|
| 602 |
+
if not MCP_AVAILABLE:
|
| 603 |
+
return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
|
| 604 |
+
|
| 605 |
+
try:
|
| 606 |
+
loop = asyncio.get_event_loop()
|
| 607 |
+
if loop.is_running():
|
| 608 |
+
if nest_asyncio:
|
| 609 |
+
return nest_asyncio.run(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
|
| 610 |
+
else:
|
| 611 |
+
logger.error("[GEMINI SUPERVISOR] Nested challenge failed: nest_asyncio not available")
|
| 612 |
+
return loop.run_until_complete(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
|
| 613 |
+
except Exception as exc:
|
| 614 |
+
logger.error(f"[GEMINI SUPERVISOR] Challenge failed: {exc}")
|
| 615 |
+
return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def gemini_supervisor_enhance_answer(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
|
| 619 |
+
"""Wrapper to enhance answer synchronously"""
|
| 620 |
+
if not MCP_AVAILABLE:
|
| 621 |
+
return current_answer
|
| 622 |
+
|
| 623 |
+
try:
|
| 624 |
+
loop = asyncio.get_event_loop()
|
| 625 |
+
if loop.is_running():
|
| 626 |
+
if nest_asyncio:
|
| 627 |
+
return nest_asyncio.run(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
|
| 628 |
+
else:
|
| 629 |
+
logger.error("[GEMINI SUPERVISOR] Nested enhancement failed: nest_asyncio not available")
|
| 630 |
+
return loop.run_until_complete(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
|
| 631 |
+
except Exception as exc:
|
| 632 |
+
logger.error(f"[GEMINI SUPERVISOR] Enhancement failed: {exc}")
|
| 633 |
+
return current_answer
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def gemini_supervisor_check_clarity(query: str, answer: str, use_web_search: bool) -> dict:
|
| 637 |
+
"""Wrapper to check clarity synchronously"""
|
| 638 |
+
if not MCP_AVAILABLE or not use_web_search:
|
| 639 |
+
return {"is_unclear": False, "needs_search": False, "search_queries": []}
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
loop = asyncio.get_event_loop()
|
| 643 |
+
if loop.is_running():
|
| 644 |
+
if nest_asyncio:
|
| 645 |
+
return nest_asyncio.run(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
|
| 646 |
+
else:
|
| 647 |
+
logger.error("[GEMINI SUPERVISOR] Nested clarity check failed: nest_asyncio not available")
|
| 648 |
+
return loop.run_until_complete(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
|
| 649 |
+
except Exception as exc:
|
| 650 |
+
logger.error(f"[GEMINI SUPERVISOR] Clarity check failed: {exc}")
|
| 651 |
+
return {"is_unclear": False, "needs_search": False, "search_queries": []}
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
async def self_reflection_gemini(answer: str, query: str) -> dict:
    """Self-reflection using Gemini MCP"""
    reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
Query: "{query}"
Answer: "{answer[:1000]}"
Evaluate:
1. Completeness: Does it address all aspects of the query?
2. Accuracy: Is the medical information accurate?
3. Clarity: Is it clear and well-structured?
4. Sources: Are sources cited appropriately?
5. Missing Information: What important information might be missing?
Respond in JSON:
{{
"completeness_score": 0-10,
"accuracy_score": 0-10,
"clarity_score": 0-10,
"overall_score": 0-10,
"missing_aspects": ["..."],
"improvement_suggestions": ["..."]
}}"""

    system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback."

    response = await call_agent(
        user_prompt=reflection_prompt,
        system_prompt=system_prompt,
        model=GEMINI_MODEL,
        temperature=0.3
    )

    try:
        json_start = response.find('{')
        json_end = response.rfind('}') + 1
        if json_start >= 0 and json_end > json_start:
            reflection = json.loads(response[json_start:json_end])
        else:
            reflection = {"overall_score": 7, "improvement_suggestions": []}
    except:
        reflection = {"overall_score": 7, "improvement_suggestions": []}

    logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}")
    return reflection


def self_reflection(answer: str, query: str, reasoning: dict) -> dict:
    """Self-reflection: Evaluate answer quality and completeness"""
    if not MCP_AVAILABLE:
        logger.warning("Gemini MCP not available for reflection, using fallback")
        return {"overall_score": 7, "improvement_suggestions": []}

    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                return nest_asyncio.run(self_reflection_gemini(answer, query))
            else:
                logger.error("Error in nested async reflection: nest_asyncio not available")
        else:
            return loop.run_until_complete(self_reflection_gemini(answer, query))
    except Exception as e:
        logger.error(f"Gemini MCP reflection error: {e}")

    return {"overall_score": 7, "improvement_suggestions": []}

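The brace-scanning fallback in self_reflection_gemini (take from the first '{' through the last '}', then json.loads) is easiest to see in isolation. A small self-contained sketch, using the hypothetical name parse_json_block rather than anything from this commit:

import json


def parse_json_block(response: str, default: dict) -> dict:
    """Extract the first {...} block from a model reply, falling back to a default on failure."""
    start = response.find('{')
    end = response.rfind('}') + 1
    if start < 0 or end <= start:
        return default
    try:
        return json.loads(response[start:end])
    except json.JSONDecodeError:
        return default


# Example: a reply that wraps JSON in prose still parses.
reply = 'Evaluation: {"overall_score": 8, "improvement_suggestions": []}'
print(parse_json_block(reply, {"overall_score": 7, "improvement_suggestions": []}))
# -> {'overall_score': 8, 'improvement_suggestions': []}
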
ui.py
ADDED
@@ -0,0 +1,293 @@
"""Gradio UI setup"""
import time
import gradio as gr
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech


def create_demo():
    """Create and return Gradio demo interface"""
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)

        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )

            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with MedSwin... Type your question below.",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your medical question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=10
                    )
                    mic_button = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="",
                        show_label=False,
                        container=False,
                        scale=1
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)

                recording_timer = gr.Textbox(
                    value="",
                    label="",
                    show_label=False,
                    interactive=False,
                    visible=False,
                    container=False,
                    elem_classes="recording-timer"
                )

                recording_start_time = [None]

                def handle_recording_start():
                    """Called when recording starts"""
                    recording_start_time[0] = time.time()
                    return gr.update(visible=True, value="Recording... 0s")

                def handle_recording_stop(audio):
                    """Called when recording stops"""
                    recording_start_time[0] = None
                    if audio is None:
                        return gr.update(visible=False, value=""), ""
                    transcribed = transcribe_audio(audio)
                    return gr.update(visible=False, value=""), transcribed

                mic_button.start_recording(
                    fn=handle_recording_start,
                    outputs=[recording_timer]
                )

                mic_button.stop_recording(
                    fn=handle_recording_stop,
                    inputs=[mic_button],
                    outputs=[recording_timer, message_input]
                )

                with gr.Row(visible=False) as tts_row:
                    tts_text = gr.Textbox(visible=False)
                    tts_audio = gr.Audio(label="Generated Speech", visible=False)

                def generate_speech_from_chat(history):
                    """Extract last assistant message and generate speech"""
                    if not history or len(history) == 0:
                        return None
                    last_msg = history[-1]
                    if last_msg.get("role") == "assistant":
                        text = last_msg.get("content", "").replace(" 🔊", "").strip()
                        if text:
                            audio_path = generate_speech(text)
                            return audio_path
                    return None

                tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")

                def update_tts_button(history):
                    if history and len(history) > 0 and history[-1].get("role") == "assistant":
                        return gr.update(visible=True)
                    return gr.update(visible=False)

                chatbot.change(
                    fn=update_tts_button,
                    inputs=[chatbot],
                    outputs=[tts_button]
                )

                tts_button.click(
                    fn=generate_speech_from_chat,
                    inputs=[chatbot],
                    outputs=[tts_audio]
                )

        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                disable_agentic_reasoning = gr.Checkbox(
                    value=False,
                    label="Disable agentic reasoning",
                    info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
                )
                show_agentic_thought = gr.Button(
                    "Show agentic thought",
                    size="sm"
                )
            agentic_thoughts_box = gr.Textbox(
                label="Agentic Thoughts",
                placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
                lines=8,
                max_lines=15,
                interactive=False,
                visible=False,
                elem_classes="agentic-thoughts"
            )
            with gr.Row():
                use_rag = gr.Checkbox(
                    value=False,
                    label="Enable Document RAG",
                    info="Answer based on uploaded documents (upload required)"
                )
                use_web_search = gr.Checkbox(
                    value=False,
                    label="Enable Web Search (MCP)",
                    info="Fetch knowledge from online medical resources"
                )

            medical_model = gr.Radio(
                choices=list(MEDSWIN_MODELS.keys()),
                value=DEFAULT_MEDICAL_MODEL,
                label="Medical Model",
                info="MedSwin TA (default), others download on first use"
            )

            system_prompt = gr.Textbox(
                value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
                label="System Prompt",
                lines=3
            )

            with gr.Tab("Generation Parameters"):
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.2,
                    label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=512,
                    maximum=4096,
                    step=128,
                    value=2048,
                    label="Max New Tokens",
                    info="Increased for medical models to prevent early stopping"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top K"
                )
                penalty = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.2,
                    label="Repetition Penalty"
                )

            with gr.Tab("Retrieval Parameters"):
                retriever_k = gr.Slider(
                    minimum=5,
                    maximum=30,
                    step=1,
                    value=15,
                    label="Initial Retrieval Size (Top K)"
                )
                merge_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    step=0.1,
                    value=0.5,
                    label="Merge Threshold (lower = more merging)"
                )

        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Toggle visibility of agentic thoughts box"""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state]
        )

        submit_button.click(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )

        message_input.submit(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )

    return demo

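With the interface factored into create_demo(), the slimmed-down app.py presumably only needs to build and launch the demo. A minimal sketch of that entry point (the queue/launch arguments are assumptions; app.py itself is too large to render in this diff):

from ui import create_demo

if __name__ == "__main__":
    demo = create_demo()
    demo.queue()  # queuing keeps streaming responses responsive under concurrent users
    demo.launch(server_name="0.0.0.0", server_port=7860)

Since submit_button.click and message_input.submit pass identical input lists to stream_chat, defining that list once in a local variable would keep the two bindings from drifting apart.
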
utils.py
ADDED
@@ -0,0 +1,109 @@
"""Utility functions for translation, language detection, and formatting"""
import asyncio
from langdetect import detect, LangDetectException
from logger import logger
from mcp import MCP_AVAILABLE, call_agent
from config import GEMINI_MODEL_LITE

try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None


def format_prompt_manually(messages: list, tokenizer) -> str:
    """Manually format prompt for models without chat template"""
    system_content = ""
    user_content = ""

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "system":
            system_content = content
        elif role == "user":
            user_content = content

    if system_content:
        prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
    else:
        prompt = f"Question: {user_content}\n\nAnswer:"

    return prompt


def detect_language(text: str) -> str:
    """Detect language of input text"""
    try:
        lang = detect(text)
        return lang
    except LangDetectException:
        return "en"


def format_url_as_domain(url: str) -> str:
    """Format URL as simple domain name (e.g., www.mayoclinic.org)"""
    if not url:
        return ""
    try:
        from urllib.parse import urlparse
        parsed = urlparse(url)
        domain = parsed.netloc or parsed.path
        if domain.startswith('www.'):
            return domain
        elif domain:
            return domain
        return url
    except Exception:
        if '://' in url:
            domain = url.split('://')[1].split('/')[0]
            return domain
        return url


async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str:
    """Translate text using Gemini MCP"""
    if source_lang:
        user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
    else:
        user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"

    system_prompt = "You are a professional translator. Translate accurately and concisely."

    result = await call_agent(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        model=GEMINI_MODEL_LITE,
        temperature=0.2
    )

    return result.strip()


def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str:
    """Translate text using Gemini MCP"""
    if not MCP_AVAILABLE:
        logger.warning("Gemini MCP not available for translation")
        return text

    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                translated = nest_asyncio.run(translate_text_gemini(text, target_lang, source_lang))
                if translated:
                    logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
                    return translated
            else:
                logger.error("Error in nested async translation: nest_asyncio not available")
        else:
            translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang))
            if translated:
                logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
                return translated
    except Exception as e:
        logger.error(f"Gemini MCP translation error: {e}")

    return text

voice.py
ADDED
@@ -0,0 +1,165 @@
"""Audio transcription and text-to-speech functions"""
import os
import asyncio
import tempfile
import soundfile as sf
from logger import logger
from mcp import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
import config
from models import TTS_AVAILABLE, initialize_tts_model

try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None


async def transcribe_audio_gemini(audio_path: str) -> str:
    """Transcribe audio using Gemini MCP"""
    if not MCP_AVAILABLE:
        return ""

    try:
        audio_path_abs = os.path.abspath(audio_path)
        files = [{"path": audio_path_abs}]

        system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
        user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."

        result = await call_agent(
            user_prompt=user_prompt,
            system_prompt=system_prompt,
            files=files,
            model=config.GEMINI_MODEL_LITE,
            temperature=0.2
        )

        return result.strip()
    except Exception as e:
        logger.error(f"Gemini transcription error: {e}")
        return ""


def transcribe_audio(audio):
    """Transcribe audio to text using Gemini MCP"""
    if audio is None:
        return ""

    try:
        if isinstance(audio, str):
            audio_path = audio
        elif isinstance(audio, tuple):
            sample_rate, audio_data = audio
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
                audio_path = tmp_file.name
        else:
            audio_path = audio

        if MCP_AVAILABLE:
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    if nest_asyncio:
                        transcribed = nest_asyncio.run(transcribe_audio_gemini(audio_path))
                        if transcribed:
                            logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
                            return transcribed
                    else:
                        logger.error("nest_asyncio not available for nested async transcription")
                else:
                    transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
                    if transcribed:
                        logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
                        return transcribed
            except Exception as e:
                logger.error(f"Gemini MCP transcription error: {e}")

        logger.warning("Gemini MCP transcription not available")
        return ""
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return ""


async def generate_speech_mcp(text: str) -> str:
    """Generate speech using MCP TTS tool"""
    if not MCP_AVAILABLE:
        return None

    try:
        session = await get_mcp_session()
        if session is None:
            return None

        tools = await get_cached_mcp_tools()
        tts_tool = None
        for tool in tools:
            tool_name_lower = tool.name.lower()
            if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
                tts_tool = tool
                logger.info(f"Found MCP TTS tool: {tool.name}")
                break

        if tts_tool:
            result = await session.call_tool(
                tts_tool.name,
                arguments={"text": text, "language": "en"}
            )

            if hasattr(result, 'content') and result.content:
                for item in result.content:
                    if hasattr(item, 'text'):
                        if os.path.exists(item.text):
                            return item.text
                    elif hasattr(item, 'data') and item.data:
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                            tmp_file.write(item.data)
                            return tmp_file.name
        return None
    except Exception as e:
        logger.warning(f"MCP TTS error: {e}")
        return None


def generate_speech(text: str):
    """Generate speech from text using TTS model (with MCP fallback)"""
    if not text or len(text.strip()) == 0:
        return None

    if MCP_AVAILABLE:
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                if nest_asyncio:
                    audio_path = nest_asyncio.run(generate_speech_mcp(text))
                    if audio_path:
                        logger.info("Generated speech via MCP")
                        return audio_path
            else:
                audio_path = loop.run_until_complete(generate_speech_mcp(text))
                if audio_path:
                    return audio_path
        except Exception as e:
            pass

    if not TTS_AVAILABLE:
        logger.error("TTS library not installed. Please install TTS to use voice generation.")
        return None

    if config.global_tts_model is None:
        initialize_tts_model()

    if config.global_tts_model is None:
        logger.error("TTS model not available. Please check dependencies.")
        return None

    try:
        wav = config.global_tts_model.tts(text)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            sf.write(tmp_file.name, wav, samplerate=22050)
            return tmp_file.name
    except Exception as e:
        logger.error(f"TTS error: {e}")
        return None
