LiamKhoaLe committed on
Commit 52b4ed7 · 1 Parent(s): b720259

Refactor app.py into modular files for better scalability

Files changed (13)
  1. app.py +0 -0
  2. config.py +148 -0
  3. indexing.py +236 -0
  4. logger.py +49 -0
  5. mcp.py +194 -0
  6. models.py +77 -0
  7. pipeline.py +438 -0
  8. reasoning.py +178 -0
  9. search.py +252 -0
  10. supervisor.py +717 -0
  11. ui.py +293 -0
  12. utils.py +109 -0
  13. voice.py +165 -0
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
config.py ADDED
@@ -0,0 +1,148 @@
+ """Configuration constants and global storage"""
+ import os
+ import threading
+
+ # Model configurations
+ MEDSWIN_MODELS = {
+     "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
+     "MedSwin KD": "MedSwin/MedSwin-7B-KD",
+     "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
+ }
+ DEFAULT_MEDICAL_MODEL = "MedSwin TA"
+ EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1"
+ TTS_MODEL = "maya-research/maya1"
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ if not HF_TOKEN:
+     raise ValueError("HF_TOKEN not found in environment variables")
+
+ # Gemini MCP configuration
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
+ GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
+ GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite")
+
+ # MCP server configuration
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ agent_path = os.path.join(script_dir, "agent.py")
+ MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python")
+ MCP_SERVER_ARGS = os.environ.get("MCP_SERVER_ARGS", agent_path).split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path]
+ MCP_TOOLS_CACHE_TTL = int(os.environ.get("MCP_TOOLS_CACHE_TTL", "60"))
+
+ # Global model storage
+ global_medical_models = {}
+ global_medical_tokenizers = {}
+ global_file_info = {}
+ global_tts_model = None
+ global_embed_model = None
+
+ # MCP client storage
+ global_mcp_session = None
+ global_mcp_stdio_ctx = None
+ global_mcp_lock = threading.Lock()
+ global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
+
+ # UI constants
+ TITLE = "<h1><center>🩺 MedLLM Agent - Medical RAG & Web Search System</center></h1>"
+ DESCRIPTION = """
+ <center>
+ <p><strong>Advanced Medical AI Assistant</strong> powered by MedSwin models</p>
+ <p>📄 <strong>Document RAG:</strong> Answer based on uploaded medical documents</p>
+ <p>🌐 <strong>Web Search:</strong> Fetch knowledge from reliable online medical resources</p>
+ <p>🌍 <strong>Multi-language:</strong> Automatic translation for non-English queries</p>
+ <p>Upload PDF or text files to get started!</p>
+ </center>
+ """
+ CSS = """
+ .upload-section {
+     max-width: 400px;
+     margin: 0 auto;
+     padding: 10px;
+     border: 2px dashed #ccc;
+     border-radius: 10px;
+ }
+ .upload-button {
+     background: #34c759 !important;
+     color: white !important;
+     border-radius: 25px !important;
+ }
+ .chatbot-container {
+     margin-top: 20px;
+ }
+ .status-output {
+     margin-top: 10px;
+     font-size: 14px;
+ }
+ .processing-info {
+     margin-top: 5px;
+     font-size: 12px;
+     color: #666;
+ }
+ .info-container {
+     margin-top: 10px;
+     padding: 10px;
+     border-radius: 5px;
+ }
+ .file-list {
+     margin-top: 0;
+     max-height: 200px;
+     overflow-y: auto;
+     padding: 5px;
+     border: 1px solid #eee;
+     border-radius: 5px;
+ }
+ .stats-box {
+     margin-top: 10px;
+     padding: 10px;
+     border-radius: 5px;
+     font-size: 12px;
+ }
+ .submit-btn {
+     background: #1a73e8 !important;
+     color: white !important;
+     border-radius: 25px !important;
+     margin-left: 10px;
+     padding: 5px 10px;
+     font-size: 16px;
+ }
+ .input-row {
+     display: flex;
+     align-items: center;
+ }
+ .recording-timer {
+     font-size: 12px;
+     color: #666;
+     text-align: center;
+     margin-top: 5px;
+ }
+ .feature-badge {
+     display: inline-block;
+     padding: 3px 8px;
+     margin: 2px;
+     border-radius: 12px;
+     font-size: 11px;
+     font-weight: bold;
+ }
+ .badge-rag {
+     background: #e3f2fd;
+     color: #1976d2;
+ }
+ .badge-web {
+     background: #f3e5f5;
+     color: #7b1fa2;
+ }
+ @media (min-width: 768px) {
+     .main-container {
+         display: flex;
+         justify-content: space-between;
+         gap: 20px;
+     }
+     .upload-section {
+         flex: 1;
+         max-width: 300px;
+     }
+     .chatbot-container {
+         flex: 2;
+         margin-top: 0;
+     }
+ }
+ """
+
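config.py doubles as a process-wide registry: the other modules import it and mutate the global_* attributes in place. A minimal sketch of that pattern, mirroring how models.py uses it (the assertion is illustrative only, not part of this commit):

    import config
    from models import get_or_create_embed_model

    # First call loads the MedEmbed weights and caches them on the config module;
    # later calls (indexing, RAG retrieval) reuse the same instance.
    embed_model = get_or_create_embed_model()
    assert config.global_embed_model is embed_model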
indexing.py ADDED
@@ -0,0 +1,236 @@
+ """Document parsing and indexing functions"""
+ import os
+ import base64
+ import asyncio
+ import tempfile
+ import time
+ import gradio as gr
+ import spaces
+ from llama_index.core import (
+     StorageContext,
+     VectorStoreIndex,
+     load_index_from_storage,
+     Document as LlamaDocument,
+ )
+ from llama_index.core import Settings
+ from llama_index.core.node_parser import (
+     HierarchicalNodeParser,
+     get_leaf_nodes,
+     get_root_nodes,
+ )
+ from llama_index.core.storage.docstore import SimpleDocumentStore
+ from tqdm import tqdm
+ from logger import logger
+ from mcp import MCP_AVAILABLE, call_agent
+ import config
+ from models import get_llm_for_rag, get_or_create_embed_model
+
+ try:
+     import nest_asyncio
+ except ImportError:
+     nest_asyncio = None
+
+
+ async def parse_document_gemini(file_path: str, file_extension: str) -> str:
+     """Parse document using Gemini MCP"""
+     if not MCP_AVAILABLE:
+         return ""
+
+     try:
+         with open(file_path, 'rb') as f:
+             file_content = base64.b64encode(f.read()).decode('utf-8')
+
+         mime_type_map = {
+             '.pdf': 'application/pdf',
+             '.doc': 'application/msword',
+             '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+             '.txt': 'text/plain',
+             '.md': 'text/markdown',
+             '.json': 'application/json',
+             '.xml': 'application/xml',
+             '.csv': 'text/csv'
+         }
+         mime_type = mime_type_map.get(file_extension, 'application/octet-stream')
+
+         files = [{
+             "content": file_content,
+             "type": mime_type
+         }]
+
+         system_prompt = "Extract all text content from the document accurately."
+         user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible."
+
+         result = await call_agent(
+             user_prompt=user_prompt,
+             system_prompt=system_prompt,
+             files=files,
+             model=config.GEMINI_MODEL_LITE,
+             temperature=0.2
+         )
+
+         return result.strip()
+     except Exception as e:
+         logger.error(f"Gemini document parsing error: {e}")
+         return ""
+
+
+ def extract_text_from_document(file):
+     """Extract text from document using Gemini MCP"""
+     file_name = file.name
+     file_extension = os.path.splitext(file_name)[1].lower()
+
+     if file_extension == '.txt':
+         text = file.read().decode('utf-8')
+         return text, len(text.split()), None
+
+     try:
+         with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
+             file.seek(0)
+             tmp_file.write(file.read())
+             tmp_file_path = tmp_file.name
+
+         if MCP_AVAILABLE:
+             try:
+                 loop = asyncio.get_event_loop()
+                 if loop.is_running():
+                     if nest_asyncio:
+                         # nest_asyncio exposes apply(), not run(); patch the loop, then re-enter it
+                         nest_asyncio.apply(loop)
+                         text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
+                     else:
+                         logger.error("Error in nested async document parsing: nest_asyncio not available")
+                         text = ""
+                 else:
+                     text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
+
+                 try:
+                     os.unlink(tmp_file_path)
+                 except:
+                     pass
+
+                 if text:
+                     return text, len(text.split()), None
+                 else:
+                     return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP")
+             except Exception as e:
+                 logger.error(f"Gemini MCP document parsing error: {e}")
+                 try:
+                     os.unlink(tmp_file_path)
+                 except:
+                     pass
+                 return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}")
+         else:
+             try:
+                 os.unlink(tmp_file_path)
+             except:
+                 pass
+             return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.")
+     except Exception as e:
+         logger.error(f"Error processing document: {e}")
+         return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}")
+
+
+ @spaces.GPU(max_duration=120)
+ def create_or_update_index(files, request: gr.Request):
+     """Create or update RAG index from uploaded files"""
+     if not files:
+         return "Please provide files.", ""
+
+     start_time = time.time()
+     user_id = request.session_hash
+     save_dir = f"./{user_id}_index"
+
+     llm = get_llm_for_rag()
+     embed_model = get_or_create_embed_model()
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+     file_stats = []
+     new_documents = []
+
+     for file in tqdm(files, desc="Processing files"):
+         file_basename = os.path.basename(file.name)
+         text, word_count, error = extract_text_from_document(file)
+         if error:
+             logger.error(f"Error processing file {file_basename}: {str(error)}")
+             file_stats.append({
+                 "name": file_basename,
+                 "words": 0,
+                 "status": f"error: {str(error)}"
+             })
+             continue
+
+         doc = LlamaDocument(
+             text=text,
+             metadata={
+                 "file_name": file_basename,
+                 "word_count": word_count,
+                 "source": "user_upload"
+             }
+         )
+         new_documents.append(doc)
+
+         file_stats.append({
+             "name": file_basename,
+             "words": word_count,
+             "status": "processed"
+         })
+
+         config.global_file_info[file_basename] = {
+             "word_count": word_count,
+             "processed_at": time.time()
+         }
+
+     node_parser = HierarchicalNodeParser.from_defaults(
+         chunk_sizes=[2048, 512, 128],
+         chunk_overlap=20
+     )
+     logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
+     new_nodes = node_parser.get_nodes_from_documents(new_documents)
+     new_leaf_nodes = get_leaf_nodes(new_nodes)
+     new_root_nodes = get_root_nodes(new_nodes)
+     logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
+
+     if os.path.exists(save_dir):
+         logger.info(f"Loading existing index from {save_dir}")
+         storage_context = StorageContext.from_defaults(persist_dir=save_dir)
+         index = load_index_from_storage(storage_context, settings=Settings)
+         docstore = storage_context.docstore
+
+         docstore.add_documents(new_nodes)
+         for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
+             index.insert_nodes([node])
+
+         total_docs = len(docstore.docs)
+         logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
+     else:
+         logger.info("Creating new index")
+         docstore = SimpleDocumentStore()
+         storage_context = StorageContext.from_defaults(docstore=docstore)
+         docstore.add_documents(new_nodes)
+
+         index = VectorStoreIndex(
+             new_leaf_nodes,
+             storage_context=storage_context,
+             settings=Settings
+         )
+         total_docs = len(new_documents)
+         logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
+
+     index.storage_context.persist(persist_dir=save_dir)
+
+     file_list_html = "<div class='file-list'>"
+     for stat in file_stats:
+         status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
+         file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
+     file_list_html += "</div>"
+     processing_time = time.time() - start_time
+     stats_output = f"<div class='stats-box'>"
+     stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
+     stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
+     stats_output += f"✓ Total documents in index: {total_docs}<br>"
+     stats_output += f"✓ Index saved to: {save_dir}<br>"
+     stats_output += "</div>"
+     output_container = f"<div class='info-container'>"
+     output_container += file_list_html
+     output_container += stats_output
+     output_container += "</div>"
+     return f"Successfully indexed {len(files)} files.", output_container
+
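create_or_update_index is designed to be bound directly to a Gradio event; the gr.Request parameter is injected by Gradio because the handler annotates it. The actual wiring lives in ui.py, which is not rendered here, so the sketch below is illustrative only and the component names are assumptions:

    import gradio as gr
    from indexing import create_or_update_index

    with gr.Blocks() as demo:
        files = gr.File(file_count="multiple", label="Upload medical documents")
        build_btn = gr.Button("Build index")
        status = gr.Textbox(label="Status")
        details = gr.HTML()
        # gr.Request is supplied automatically at call time
        build_btn.click(create_or_update_index, inputs=[files], outputs=[status, details])

    demo.launch()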
logger.py ADDED
@@ -0,0 +1,49 @@
+ """Logging configuration and custom handlers"""
+ import logging
+ import threading
+ from transformers import logging as hf_logging
+
+ # Set logging to INFO level for cleaner output
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Custom logger handler to capture agentic thoughts
+ class ThoughtCaptureHandler(logging.Handler):
+     """Custom handler to capture internal thoughts from MedSwin and supervisor"""
+     def __init__(self):
+         super().__init__()
+         self.thoughts = []
+         self.lock = threading.Lock()
+
+     def emit(self, record):
+         """Capture log messages that contain agentic thoughts"""
+         try:
+             msg = self.format(record)
+             # Only capture messages from GEMINI SUPERVISOR or MEDSWIN
+             if "[GEMINI SUPERVISOR]" in msg or "[MEDSWIN]" in msg or "[MAC]" in msg:
+                 # Remove timestamp and logger name for cleaner display
+                 parts = msg.split(" - ", 3)
+                 if len(parts) >= 4:
+                     clean_msg = parts[-1]
+                 else:
+                     clean_msg = msg
+                 with self.lock:
+                     self.thoughts.append(clean_msg)
+         except Exception:
+             pass
+
+     def get_thoughts(self):
+         """Get all captured thoughts as a formatted string"""
+         with self.lock:
+             return "\n".join(self.thoughts)
+
+     def clear(self):
+         """Clear captured thoughts"""
+         with self.lock:
+             self.thoughts = []
+
+ # Set MCP client logging to WARNING to reduce noise
+ mcp_client_logger = logging.getLogger("mcp.client")
+ mcp_client_logger.setLevel(logging.WARNING)
+ hf_logging.set_verbosity_error()
+
+
mcp.py ADDED
@@ -0,0 +1,194 @@
+ """MCP session management and tool caching"""
+ import os
+ import time
+ import asyncio
+ from logger import logger
+ import config
+
+ # MCP imports
+ MCP_CLIENT_INFO = None
+ try:
+     from mcp import ClientSession, StdioServerParameters
+     from mcp import types as mcp_types
+     from mcp.client.stdio import stdio_client
+     try:
+         import nest_asyncio
+         nest_asyncio.apply()
+     except ImportError:
+         pass
+     MCP_AVAILABLE = True
+     MCP_CLIENT_INFO = mcp_types.Implementation(
+         name="MedLLM-Agent",
+         version=os.environ.get("SPACE_VERSION", "local"),
+     )
+ except ImportError as e:
+     logger.warning(f"MCP SDK not available: {e}")
+     MCP_AVAILABLE = False
+     MCP_CLIENT_INFO = None
+
+
+ async def get_mcp_session():
+     """Get or create MCP client session with proper context management"""
+     if not MCP_AVAILABLE:
+         logger.warning("MCP not available - SDK not installed")
+         return None
+
+     if config.global_mcp_session is not None:
+         return config.global_mcp_session
+
+     try:
+         mcp_env = os.environ.copy()
+         if config.GEMINI_API_KEY:
+             mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
+         else:
+             logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")
+
+         if os.environ.get("GEMINI_MODEL"):
+             mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
+         if os.environ.get("GEMINI_TIMEOUT"):
+             mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
+         if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
+             mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
+         if os.environ.get("GEMINI_TEMPERATURE"):
+             mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")
+
+         logger.info("Creating MCP client session...")
+
+         server_params = StdioServerParameters(
+             command=config.MCP_SERVER_COMMAND,
+             args=config.MCP_SERVER_ARGS,
+             env=mcp_env
+         )
+
+         stdio_ctx = stdio_client(server_params)
+         read, write = await stdio_ctx.__aenter__()
+
+         session = ClientSession(
+             read,
+             write,
+             client_info=MCP_CLIENT_INFO,
+         )
+
+         try:
+             await session.__aenter__()
+             init_result = await session.initialize()
+             server_info = getattr(init_result, "serverInfo", None)
+             server_name = getattr(server_info, "name", "unknown")
+             server_version = getattr(server_info, "version", "unknown")
+             logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})")
+         except Exception as e:
+             error_msg = str(e)
+             error_type = type(e).__name__
+             logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
+             try:
+                 await session.__aexit__(None, None, None)
+             except Exception:
+                 pass
+             try:
+                 await stdio_ctx.__aexit__(None, None, None)
+             except Exception:
+                 pass
+             return None
+
+         config.global_mcp_session = session
+         config.global_mcp_stdio_ctx = stdio_ctx
+         logger.info("✅ MCP client session created successfully")
+         return session
+     except Exception as e:
+         error_type = type(e).__name__
+         error_msg = str(e)
+         logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
+         config.global_mcp_session = None
+         config.global_mcp_stdio_ctx = None
+         return None
+
+
+ def invalidate_mcp_tools_cache():
+     """Invalidate cached MCP tool metadata"""
+     config.global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
+
+
+ async def get_cached_mcp_tools(force_refresh: bool = False):
+     """Return cached MCP tools list to avoid repeated list_tools calls"""
+     if not MCP_AVAILABLE:
+         return []
+
+     now = time.time()
+     if (
+         not force_refresh
+         and config.global_mcp_tools_cache["tools"]
+         and now - config.global_mcp_tools_cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL
+     ):
+         return config.global_mcp_tools_cache["tools"]
+
+     session = await get_mcp_session()
+     if session is None:
+         return []
+
+     try:
+         tools_resp = await session.list_tools()
+         tools_list = list(getattr(tools_resp, "tools", []) or [])
+         config.global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
+         return tools_list
+     except Exception as e:
+         logger.error(f"Failed to refresh MCP tools: {e}")
+         invalidate_mcp_tools_cache()
+         return []
+
+
+ async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
+     """Call Gemini MCP generate_content tool"""
+     if not MCP_AVAILABLE:
+         logger.warning("MCP not available for Gemini call")
+         return ""
+
+     try:
+         session = await get_mcp_session()
+         if session is None:
+             logger.warning("Failed to get MCP session for Gemini call")
+             return ""
+
+         tools = await get_cached_mcp_tools()
+         if not tools:
+             tools = await get_cached_mcp_tools(force_refresh=True)
+         if not tools:
+             logger.error("Unable to obtain MCP tool catalog for Gemini calls")
+             return ""
+
+         generate_tool = None
+         for tool in tools:
+             if tool.name == "generate_content" or "generate_content" in tool.name.lower():
+                 generate_tool = tool
+                 logger.info(f"Found Gemini MCP tool: {tool.name}")
+                 break
+
+         if not generate_tool:
+             logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
+             invalidate_mcp_tools_cache()
+             return ""
+
+         arguments = {
+             "user_prompt": user_prompt
+         }
+         if system_prompt:
+             arguments["system_prompt"] = system_prompt
+         if files:
+             arguments["files"] = files
+         if model:
+             arguments["model"] = model
+         if temperature is not None:
+             arguments["temperature"] = temperature
+
+         result = await session.call_tool(generate_tool.name, arguments=arguments)
+
+         if hasattr(result, 'content') and result.content:
+             for item in result.content:
+                 if hasattr(item, 'text'):
+                     response_text = item.text.strip()
+                     return response_text
+         logger.warning("⚠️ Gemini MCP returned empty or invalid result")
+         return ""
+     except Exception as e:
+         logger.error(f"Gemini MCP call error: {e}")
+         return ""
+
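call_agent is async and returns an empty string on any failure, so callers can degrade gracefully. A minimal sketch of driving it from synchronous code (the prompt text is a placeholder, not from this repo):

    import asyncio
    from mcp import call_agent  # the module added above

    answer = asyncio.run(call_agent(
        user_prompt="Summarize the key contraindications of metformin in two bullet points.",
        system_prompt="You are a concise medical assistant.",
        temperature=0.2,
    ))
    print(answer or "Gemini MCP unavailable or returned nothing")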
models.py ADDED
@@ -0,0 +1,77 @@
+ """Model initialization and management"""
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from logger import logger
+ import config
+
+ try:
+     from TTS.api import TTS
+     TTS_AVAILABLE = True
+ except ImportError:
+     TTS_AVAILABLE = False
+     TTS = None
+
+
+ def initialize_medical_model(model_name: str):
+     """Initialize medical model (MedSwin) - download on demand"""
+     if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
+         logger.info(f"Initializing medical model: {model_name}...")
+         model_path = config.MEDSWIN_MODELS[model_name]
+         tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             device_map="auto",
+             trust_remote_code=True,
+             token=config.HF_TOKEN,
+             torch_dtype=torch.float16
+         )
+         config.global_medical_models[model_name] = model
+         config.global_medical_tokenizers[model_name] = tokenizer
+         logger.info(f"Medical model {model_name} initialized successfully")
+     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
+
+
+ def initialize_tts_model():
+     """Initialize TTS model for text-to-speech"""
+     if not TTS_AVAILABLE:
+         logger.warning("TTS library not installed. TTS features will be disabled.")
+         return None
+     if config.global_tts_model is None:
+         try:
+             logger.info("Initializing TTS model for voice generation...")
+             config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
+             logger.info("TTS model initialized successfully")
+         except Exception as e:
+             logger.warning(f"TTS model initialization failed: {e}")
+             logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
+             config.global_tts_model = None
+     return config.global_tts_model
+
+
+ def get_or_create_embed_model():
+     """Reuse embedding model to avoid reloading weights each request"""
+     if config.global_embed_model is None:
+         logger.info("Initializing shared embedding model for RAG retrieval...")
+         config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
+     return config.global_embed_model
+
+
+ def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
+     """Get LLM for RAG indexing (uses medical model)"""
+     medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
+
+     return HuggingFaceLLM(
+         context_window=4096,
+         max_new_tokens=max_new_tokens,
+         tokenizer=medical_tokenizer,
+         model=medical_model_obj,
+         generate_kwargs={
+             "do_sample": True,
+             "temperature": temperature,
+             "top_k": top_k,
+             "top_p": top_p
+         }
+     )
+
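These factories are combined by setting the llama_index globals before indexing or retrieval, exactly as indexing.py and pipeline.py do; a condensed sketch:

    from llama_index.core import Settings
    from models import get_llm_for_rag, get_or_create_embed_model

    # Both objects are cached on the config module, so repeated calls are cheap.
    Settings.llm = get_llm_for_rag(max_new_tokens=256)
    Settings.embed_model = get_or_create_embed_model()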
pipeline.py ADDED
@@ -0,0 +1,438 @@
+ """Main chat pipeline - stream_chat function"""
+ import os
+ import json
+ import time
+ import logging
+ import concurrent.futures
+ import gradio as gr
+ import spaces
+ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
+ from llama_index.core import Settings
+ from llama_index.core.retrievers import AutoMergingRetriever
+ from logger import logger, ThoughtCaptureHandler
+ from models import initialize_medical_model, get_or_create_embed_model
+ from utils import detect_language, translate_text, format_url_as_domain
+ from search import search_web, summarize_web_content
+ from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy
+ from supervisor import (
+     gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
+     gemini_supervisor_rag_brainstorm, execute_medswin_task,
+     gemini_supervisor_synthesize, gemini_supervisor_challenge,
+     gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity
+ )
+
+
+ @spaces.GPU(max_duration=120)
+ def stream_chat(
+     message: str,
+     history: list,
+     system_prompt: str,
+     temperature: float,
+     max_new_tokens: int,
+     top_p: float,
+     top_k: int,
+     penalty: float,
+     retriever_k: int,
+     merge_threshold: float,
+     use_rag: bool,
+     medical_model: str,
+     use_web_search: bool,
+     disable_agentic_reasoning: bool,
+     show_thoughts: bool,
+     request: gr.Request
+ ):
+     """Main chat pipeline implementing MAC architecture"""
+     if not request:
+         yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}], ""
+         return
+
+     thought_handler = None
+     if show_thoughts:
+         thought_handler = ThoughtCaptureHandler()
+         thought_handler.setLevel(logging.INFO)
+         thought_handler.clear()
+         logger.addHandler(thought_handler)
+
+     session_start = time.time()
+     soft_timeout = 100
+     hard_timeout = 118
+
+     def elapsed():
+         return time.time() - session_start
+
+     user_id = request.session_hash
+     index_dir = f"./{user_id}_index"
+     has_rag_index = os.path.exists(index_dir)
+
+     original_lang = detect_language(message)
+     original_message = message
+     needs_translation = original_lang != "en"
+
+     pipeline_diagnostics = {
+         "reasoning": None,
+         "plan": None,
+         "strategy_decisions": [],
+         "stage_metrics": {},
+         "search": {"strategies": [], "total_results": 0}
+     }
+
+     def record_stage(stage_name: str, start_time: float):
+         pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
+
+     translation_stage_start = time.time()
+     if needs_translation:
+         logger.info(f"[GEMINI SUPERVISOR] Detected non-English language: {original_lang}, translating...")
+         message = translate_text(message, target_lang="en", source_lang=original_lang)
+         logger.info(f"[GEMINI SUPERVISOR] Translated query: {message[:100]}...")
+     record_stage("translation", translation_stage_start)
+
+     final_use_rag = use_rag and has_rag_index and not disable_agentic_reasoning
+     final_use_web_search = use_web_search and not disable_agentic_reasoning
+
+     plan = None
+     if not disable_agentic_reasoning:
+         reasoning_stage_start = time.time()
+         reasoning = autonomous_reasoning(message, history)
+         record_stage("autonomous_reasoning", reasoning_stage_start)
+         pipeline_diagnostics["reasoning"] = reasoning
+         plan = create_execution_plan(reasoning, message, has_rag_index)
+         pipeline_diagnostics["plan"] = plan
+         execution_strategy = autonomous_execution_strategy(
+             reasoning, plan, final_use_rag, final_use_web_search, has_rag_index
+         )
+
+         if final_use_rag and not reasoning.get("requires_rag", True):
+             final_use_rag = False
+             pipeline_diagnostics["strategy_decisions"].append("Skipped RAG per autonomous reasoning")
+         elif not final_use_rag and reasoning.get("requires_rag", True) and not has_rag_index:
+             pipeline_diagnostics["strategy_decisions"].append("Reasoning wanted RAG but no index available")
+
+         if final_use_web_search and not reasoning.get("requires_web_search", False):
+             final_use_web_search = False
+             pipeline_diagnostics["strategy_decisions"].append("Skipped web search per autonomous reasoning")
+         elif not final_use_web_search and reasoning.get("requires_web_search", False):
+             if not use_web_search:
+                 pipeline_diagnostics["strategy_decisions"].append("User disabled web search despite reasoning request")
+             else:
+                 pipeline_diagnostics["strategy_decisions"].append("Web search requested by reasoning but disabled by mode")
+     else:
+         pipeline_diagnostics["strategy_decisions"].append("Agentic reasoning disabled by user")
+
+     if disable_agentic_reasoning:
+         logger.info("[MAC] Agentic reasoning disabled - using MedSwin alone")
+         breakdown = {
+             "sub_topics": [
+                 {"id": 1, "topic": "Answer", "instruction": message, "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
+             ],
+             "strategy": "Direct answer",
+             "exploration_note": "Direct mode - no breakdown"
+         }
+     else:
+         logger.info("[GEMINI SUPERVISOR] Breaking query into sub-topics...")
+         breakdown = gemini_supervisor_breakdown(message, final_use_rag, final_use_web_search, elapsed(), max_duration=120)
+         logger.info(f"[GEMINI SUPERVISOR] Created {len(breakdown.get('sub_topics', []))} sub-topics")
+
+     search_contexts = []
+     web_urls = []
+     if final_use_web_search:
+         search_stage_start = time.time()
+         logger.info("[GEMINI SUPERVISOR] Search mode: Creating search strategies...")
+         search_strategies = gemini_supervisor_search_strategies(message, elapsed())
+
+         all_search_results = []
+         strategy_jobs = []
+         for strategy in search_strategies.get("search_strategies", [])[:4]:
+             search_query = strategy.get("strategy", message)
+             target_sources = strategy.get("target_sources", 2)
+             strategy_jobs.append({
+                 "query": search_query,
+                 "target_sources": target_sources,
+                 "meta": strategy
+             })
+
+         def execute_search(job):
+             job_start = time.time()
+             try:
+                 results = search_web(job["query"], max_results=job["target_sources"])
+                 duration = time.time() - job_start
+                 return results, duration, None
+             except Exception as exc:
+                 return [], time.time() - job_start, exc
+
+         def record_search_diag(job, duration, results_count, error=None):
+             entry = {
+                 "query": job["query"],
+                 "target_sources": job["target_sources"],
+                 "duration": round(duration, 3),
+                 "results": results_count
+             }
+             if error:
+                 entry["error"] = str(error)
+             pipeline_diagnostics["search"]["strategies"].append(entry)
+
+         if strategy_jobs:
+             max_workers = min(len(strategy_jobs), 4)
+             if len(strategy_jobs) > 1:
+                 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                     future_map = {executor.submit(execute_search, job): job for job in strategy_jobs}
+                     for future in concurrent.futures.as_completed(future_map):
+                         job = future_map[future]
+                         try:
+                             results, duration, error = future.result()
+                         except Exception as exc:
+                             results, duration, error = [], 0.0, exc
+                         record_search_diag(job, duration, len(results), error)
+                         if not error and results:
+                             all_search_results.extend(results)
+                             web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+             else:
+                 job = strategy_jobs[0]
+                 results, duration, error = execute_search(job)
+                 record_search_diag(job, duration, len(results), error)
+                 if not error and results:
+                     all_search_results.extend(results)
+                     web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+         else:
+             pipeline_diagnostics["strategy_decisions"].append("No viable web search strategies returned")
+
+         pipeline_diagnostics["search"]["total_results"] = len(all_search_results)
+
+         if all_search_results:
+             logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(all_search_results)} search results...")
+             search_summary = summarize_web_content(all_search_results, message)
+             if search_summary:
+                 search_contexts.append(search_summary)
+                 logger.info(f"[GEMINI SUPERVISOR] Search summary created: {len(search_summary)} chars")
+         record_stage("web_search", search_stage_start)
+
+     rag_contexts = []
+     if final_use_rag and has_rag_index:
+         rag_stage_start = time.time()
+         if elapsed() >= soft_timeout - 10:
+             logger.warning("[GEMINI SUPERVISOR] Skipping RAG due to time pressure")
+             final_use_rag = False
+         else:
+             logger.info("[GEMINI SUPERVISOR] RAG mode: Retrieving documents...")
+             embed_model = get_or_create_embed_model()
+             Settings.embed_model = embed_model
+             storage_context = StorageContext.from_defaults(persist_dir=index_dir)
+             index = load_index_from_storage(storage_context, settings=Settings)
+             base_retriever = index.as_retriever(similarity_top_k=retriever_k)
+             auto_merging_retriever = AutoMergingRetriever(
+                 base_retriever,
+                 storage_context=storage_context,
+                 simple_ratio_thresh=merge_threshold,
+                 verbose=False
+             )
+             merged_nodes = auto_merging_retriever.retrieve(message)
+             retrieved_docs = "\n\n".join([n.node.text for n in merged_nodes])
+             logger.info(f"[GEMINI SUPERVISOR] Retrieved {len(merged_nodes)} document nodes")
+
+             logger.info("[GEMINI SUPERVISOR] Brainstorming RAG contexts...")
+             rag_brainstorm = gemini_supervisor_rag_brainstorm(message, retrieved_docs, elapsed())
+             rag_contexts = [ctx.get("context", "") for ctx in rag_brainstorm.get("contexts", [])]
+             logger.info(f"[GEMINI SUPERVISOR] Created {len(rag_contexts)} RAG contexts")
+         record_stage("rag_retrieval", rag_stage_start)
+
+     medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
+
+     base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
+
+     combined_context = ""
+     if rag_contexts:
+         combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4])
+     if search_contexts:
+         if combined_context:
+             combined_context += "\n\n"
+         combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)
+
+     logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
+     medswin_answers = []
+
+     updated_history = history + [
+         {"role": "user", "content": original_message},
+         {"role": "assistant", "content": ""}
+     ]
+     thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
+     yield updated_history, thoughts_text
+
+     medswin_stage_start = time.time()
+     for idx, sub_topic in enumerate(breakdown.get("sub_topics", []), 1):
+         if elapsed() >= hard_timeout - 5:
+             logger.warning(f"[MEDSWIN] Time limit approaching, stopping at task {idx}")
+             break
+
+         task_instruction = sub_topic.get("instruction", "")
+         topic_name = sub_topic.get("topic", f"Topic {idx}")
+         priority = sub_topic.get("priority", "medium")
+
+         logger.info(f"[MEDSWIN] Executing task {idx}/{len(breakdown.get('sub_topics', []))}: {topic_name} (priority: {priority})")
+
+         task_context = combined_context
+         if len(rag_contexts) > 1 and idx <= len(rag_contexts):
+             task_context = rag_contexts[idx - 1] if idx <= len(rag_contexts) else combined_context
+
+         try:
+             task_answer = execute_medswin_task(
+                 medical_model_obj=medical_model_obj,
+                 medical_tokenizer=medical_tokenizer,
+                 task_instruction=task_instruction,
+                 context=task_context if task_context else "",
+                 system_prompt_base=base_system_prompt,
+                 temperature=temperature,
+                 max_new_tokens=min(max_new_tokens, 800),
+                 top_p=top_p,
+                 top_k=top_k,
+                 penalty=penalty
+             )
+
+             formatted_answer = f"## {topic_name}\n\n{task_answer}"
+             medswin_answers.append(formatted_answer)
+             logger.info(f"[MEDSWIN] Task {idx} completed: {len(task_answer)} chars")
+
+             partial_final = "\n\n".join(medswin_answers)
+             updated_history[-1]["content"] = partial_final
+             thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
+             yield updated_history, thoughts_text
+
+         except Exception as e:
+             logger.error(f"[MEDSWIN] Task {idx} failed: {e}")
+             continue
+     record_stage("medswin_tasks", medswin_stage_start)
+
+     logger.info("[GEMINI SUPERVISOR] Synthesizing final answer from all MedSwin responses...")
+     raw_medswin_answers = [ans.split('\n\n', 1)[1] if '\n\n' in ans else ans for ans in medswin_answers]
+     synthesis_stage_start = time.time()
+     final_answer = gemini_supervisor_synthesize(message, raw_medswin_answers, rag_contexts, search_contexts, breakdown)
+     record_stage("synthesis", synthesis_stage_start)
+
+     if not final_answer or len(final_answer.strip()) < 50:
+         logger.warning("[GEMINI SUPERVISOR] Synthesis failed or too short, using concatenation")
+         final_answer = "\n\n".join(medswin_answers) if medswin_answers else "I apologize, but I was unable to generate a response."
+
+     if "|" in final_answer and "---" in final_answer:
+         logger.warning("[MEDSWIN] Final answer contains tables, converting to bullets")
+         lines = final_answer.split('\n')
+         cleaned_lines = []
+         for line in lines:
+             if '|' in line and '---' not in line:
+                 cells = [cell.strip() for cell in line.split('|') if cell.strip()]
+                 if cells:
+                     cleaned_lines.append(f"- {' / '.join(cells)}")
+             elif '---' not in line:
+                 cleaned_lines.append(line)
+         final_answer = '\n'.join(cleaned_lines)
+
+     max_challenge_iterations = 2
+     challenge_iteration = 0
+     challenge_stage_start = time.time()
+
+     while challenge_iteration < max_challenge_iterations and elapsed() < soft_timeout - 15:
+         challenge_iteration += 1
+         logger.info(f"[GEMINI SUPERVISOR] Challenge iteration {challenge_iteration}/{max_challenge_iterations}...")
+
+         evaluation = gemini_supervisor_challenge(message, final_answer, raw_medswin_answers, rag_contexts, search_contexts)
+
+         if evaluation.get("is_optimal", False):
+             logger.info(f"[GEMINI SUPERVISOR] Answer confirmed optimal after {challenge_iteration} iteration(s)")
+             break
+
+         enhancement_instructions = evaluation.get("enhancement_instructions", "")
+         if not enhancement_instructions:
+             logger.info("[GEMINI SUPERVISOR] No enhancement instructions, considering answer optimal")
+             break
+
+         logger.info(f"[GEMINI SUPERVISOR] Enhancing answer based on feedback...")
+         enhanced_answer = gemini_supervisor_enhance_answer(
+             message, final_answer, enhancement_instructions, raw_medswin_answers, rag_contexts, search_contexts
+         )
+
+         if enhanced_answer and len(enhanced_answer.strip()) > len(final_answer.strip()) * 0.8:
+             final_answer = enhanced_answer
+             logger.info(f"[GEMINI SUPERVISOR] Answer enhanced (new length: {len(final_answer)} chars)")
+         else:
+             logger.info("[GEMINI SUPERVISOR] Enhancement did not improve answer significantly, stopping")
+             break
+     record_stage("challenge_loop", challenge_stage_start)
+
+     if final_use_web_search and elapsed() < soft_timeout - 10:
+         logger.info("[GEMINI SUPERVISOR] Checking if additional search is needed...")
+         clarity_stage_start = time.time()
+         clarity_check = gemini_supervisor_check_clarity(message, final_answer, final_use_web_search)
+         record_stage("clarity_check", clarity_stage_start)
+
+         if clarity_check.get("needs_search", False) and clarity_check.get("search_queries"):
+             logger.info(f"[GEMINI SUPERVISOR] Triggering additional search: {clarity_check.get('search_queries', [])}")
+             additional_search_results = []
+             followup_stage_start = time.time()
+             for search_query in clarity_check.get("search_queries", [])[:3]:
+                 if elapsed() >= soft_timeout - 5:
+                     break
+                 extra_start = time.time()
+                 results = search_web(search_query, max_results=2)
+                 extra_duration = time.time() - extra_start
+                 pipeline_diagnostics["search"]["strategies"].append({
+                     "query": search_query,
+                     "target_sources": 2,
+                     "duration": round(extra_duration, 3),
+                     "results": len(results),
+                     "type": "followup"
+                 })
+                 additional_search_results.extend(results)
+                 web_urls.extend([r.get('url', '') for r in results if r.get('url')])
+
+             if additional_search_results:
+                 pipeline_diagnostics["search"]["total_results"] += len(additional_search_results)
+                 logger.info(f"[GEMINI SUPERVISOR] Summarizing {len(additional_search_results)} additional search results...")
+                 additional_summary = summarize_web_content(additional_search_results, message)
+                 if additional_summary:
+                     search_contexts.append(additional_summary)
+                     logger.info("[GEMINI SUPERVISOR] Enhancing answer with additional search context...")
+                     enhanced_with_search = gemini_supervisor_enhance_answer(
+                         message, final_answer,
+                         f"Incorporate the following additional information from web search: {additional_summary}",
+                         raw_medswin_answers, rag_contexts, search_contexts
+                     )
+                     if enhanced_with_search and len(enhanced_with_search.strip()) > 50:
+                         final_answer = enhanced_with_search
+                         logger.info("[GEMINI SUPERVISOR] Answer enhanced with additional search context")
+             record_stage("followup_search", followup_stage_start)
+
+     citations_text = ""
+
+     if needs_translation and final_answer:
+         logger.info(f"[GEMINI SUPERVISOR] Translating response back to {original_lang}...")
+         final_answer = translate_text(final_answer, target_lang=original_lang, source_lang="en")
+
+     if web_urls:
+         unique_urls = list(dict.fromkeys(web_urls))
+         citation_links = []
+         for url in unique_urls[:5]:
+             domain = format_url_as_domain(url)
+             if domain:
+                 citation_links.append(f"[{domain}]({url})")
+
+         if citation_links:
+             citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
+
+     speaker_icon = ' 🔊'
+     final_answer_with_metadata = final_answer + citations_text + speaker_icon
+
+     updated_history[-1]["content"] = final_answer_with_metadata
+     thoughts_text = thought_handler.get_thoughts() if thought_handler else ""
+     yield updated_history, thoughts_text
+
+     if thought_handler:
+         logger.removeHandler(thought_handler)
+
+     diag_summary = {
+         "stage_metrics": pipeline_diagnostics["stage_metrics"],
+         "decisions": pipeline_diagnostics["strategy_decisions"],
+         "search": pipeline_diagnostics["search"],
+     }
+     try:
+         logger.info(f"[MAC] Diagnostics summary: {json.dumps(diag_summary)[:1200]}")
+     except Exception:
+         logger.info(f"[MAC] Diagnostics summary (non-serializable)")
+     logger.info(f"[MAC] Final answer generated: {len(final_answer)} chars, {len(breakdown.get('sub_topics', []))} tasks completed")
+
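stream_chat is a generator that repeatedly yields (chat_history, thoughts) so Gradio can stream partial answers. The real bindings live in ui.py, which is not rendered in this diff; the sketch below shows the expected wiring, with component names and default values as assumptions only:

    import gradio as gr
    from pipeline import stream_chat

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        thoughts = gr.Textbox(label="Agent thoughts", lines=6)
        msg = gr.Textbox(label="Ask a medical question")
        system_prompt = gr.Textbox(value="", visible=False)
        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
        max_new_tokens = gr.Slider(64, 1024, value=512, step=64, label="Max new tokens")
        top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        penalty = gr.Slider(1.0, 2.0, value=1.1, label="Repetition penalty")
        retriever_k = gr.Slider(1, 30, value=15, step=1, label="Retriever top-k")
        merge_threshold = gr.Slider(0.1, 1.0, value=0.5, label="Merge threshold")
        use_rag = gr.Checkbox(value=True, label="Use document RAG")
        medical_model = gr.Dropdown(["MedSwin SFT", "MedSwin KD", "MedSwin TA"], value="MedSwin TA")
        use_web_search = gr.Checkbox(value=False, label="Use web search")
        disable_agentic = gr.Checkbox(value=False, label="Disable agentic reasoning")
        show_thoughts = gr.Checkbox(value=True, label="Show agent thoughts")

        # stream_chat's gr.Request parameter is injected by Gradio at call time
        msg.submit(
            stream_chat,
            inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k,
                    penalty, retriever_k, merge_threshold, use_rag, medical_model,
                    use_web_search, disable_agentic, show_thoughts],
            outputs=[chatbot, thoughts],
        )

    demo.launch()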
reasoning.py ADDED
@@ -0,0 +1,178 @@
+ """Autonomous reasoning and execution planning"""
+ import json
+ import asyncio
+ from logger import logger
+ from mcp import MCP_AVAILABLE, call_agent
+ from config import GEMINI_MODEL
+
+ try:
+     import nest_asyncio
+ except ImportError:
+     nest_asyncio = None
+
+
+ async def autonomous_reasoning_gemini(query: str) -> dict:
+     """Autonomous reasoning using Gemini MCP"""
+     reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
+ Query: "{query}"
+ Analyze:
+ 1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
+ 2. Complexity: (simple, moderate, complex, multi_faceted)
+ 3. Information Needs: What specific information is required?
+ 4. Requires RAG: (yes/no) - Does this need document context?
+ 5. Requires Web Search: (yes/no) - Does this need current/updated information?
+ 6. Sub-questions: Break down into key sub-questions if complex
+ Respond in JSON format:
+ {{
+ "query_type": "...",
+ "complexity": "...",
+ "information_needs": ["..."],
+ "requires_rag": true/false,
+ "requires_web_search": true/false,
+ "sub_questions": ["..."]
+ }}"""
+
+     system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses."
+
+     response = await call_agent(
+         user_prompt=reasoning_prompt,
+         system_prompt=system_prompt,
+         model=GEMINI_MODEL,
+         temperature=0.3
+     )
+
+     try:
+         json_start = response.find('{')
+         json_end = response.rfind('}') + 1
+         if json_start >= 0 and json_end > json_start:
+             reasoning = json.loads(response[json_start:json_end])
+         else:
+             raise ValueError("No JSON found")
+     except:
+         reasoning = {
+             "query_type": "general_info",
+             "complexity": "moderate",
+             "information_needs": ["medical information"],
+             "requires_rag": True,
+             "requires_web_search": False,
+             "sub_questions": [query]
+         }
+
+     logger.info(f"Reasoning analysis: {reasoning}")
+     return reasoning
+
+
+ def autonomous_reasoning(query: str, history: list) -> dict:
+     """Autonomous reasoning: Analyze query complexity, intent, and information needs"""
+     if not MCP_AVAILABLE:
+         logger.warning("⚠️ Gemini MCP not available for reasoning, using fallback")
+         return {
+             "query_type": "general_info",
+             "complexity": "moderate",
+             "information_needs": ["medical information"],
+             "requires_rag": True,
+             "requires_web_search": False,
+             "sub_questions": [query]
+         }
+
+     try:
+         loop = asyncio.get_event_loop()
+         if loop.is_running():
+             if nest_asyncio:
+                 # nest_asyncio exposes apply(), not run(); patch the loop, then re-enter it
+                 nest_asyncio.apply(loop)
+                 reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
+                 return reasoning
+             else:
+                 logger.error("Error in nested async reasoning: nest_asyncio not available")
+         else:
+             reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
+             return reasoning
+     except Exception as e:
+         logger.error(f"Gemini MCP reasoning error: {e}")
+
+     logger.warning("⚠️ Falling back to default reasoning")
+     return {
+         "query_type": "general_info",
+         "complexity": "moderate",
+         "information_needs": ["medical information"],
+         "requires_rag": True,
+         "requires_web_search": False,
+         "sub_questions": [query]
+     }
+
+
+ def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict:
+     """Planning: Create multi-step execution plan based on reasoning analysis"""
+     plan = {
+         "steps": [],
+         "strategy": "sequential",
+         "iterations": 1
+     }
+
+     if reasoning["complexity"] in ["complex", "multi_faceted"]:
+         plan["strategy"] = "iterative"
+         plan["iterations"] = 2
+
+     plan["steps"].append({
+         "step": 1,
+         "action": "detect_language",
+         "description": "Detect query language and translate if needed"
+     })
+
+     if reasoning.get("requires_rag", True) and has_rag_index:
+         plan["steps"].append({
+             "step": 2,
+             "action": "rag_retrieval",
+             "description": "Retrieve relevant document context",
+             "parameters": {"top_k": 15, "merge_threshold": 0.5}
+         })
+
+     if reasoning.get("requires_web_search", False):
+         plan["steps"].append({
+             "step": 3,
+             "action": "web_search",
+             "description": "Search web for current/updated information",
+             "parameters": {"max_results": 5}
+         })
+
+     if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1:
+         plan["steps"].append({
+             "step": 4,
+             "action": "multi_step_reasoning",
+             "description": "Process sub-questions iteratively",
+             "sub_questions": reasoning["sub_questions"]
+         })
+
+     plan["steps"].append({
+         "step": len(plan["steps"]) + 1,
+         "action": "synthesize_answer",
+         "description": "Generate comprehensive answer from all sources"
+     })
+
+     if reasoning["complexity"] in ["complex", "multi_faceted"]:
+         plan["steps"].append({
+             "step": len(plan["steps"]) + 1,
+             "action": "self_reflection",
+             "description": "Evaluate answer quality and completeness"
+         })
+
+     logger.info(f"Execution plan created: {len(plan['steps'])} steps")
+     return plan
+
+
+ def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict:
+     """Autonomous execution: Make decisions on information gathering strategy"""
+     strategy = {
+         "use_rag": use_rag,
+         "use_web_search": use_web_search,
+         "reasoning_override": False,
+         "rationale": ""
+     }
+
+     if reasoning.get("requires_web_search", False) and not use_web_search:
+         strategy["rationale"] = "Reasoning suggests web search for current information, but the user kept it disabled."
+
+     if strategy["rationale"]:
+         logger.info(f"Autonomous reasoning note: {strategy['rationale']}")
+
+     return strategy
+
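A quick illustration of how the planner consumes the reasoning dict (the example query is an assumption; without MCP the fallback dict is returned):

    from reasoning import autonomous_reasoning, create_execution_plan

    query = "What are the current first-line treatments for type 2 diabetes?"
    reasoning = autonomous_reasoning(query, history=[])
    plan = create_execution_plan(reasoning, query, has_rag_index=False)
    # RAG is omitted from the plan because has_rag_index is False
    print(plan["strategy"], [step["action"] for step in plan["steps"]])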
search.py ADDED
@@ -0,0 +1,252 @@
+ """Web search functions"""
+ import json
+ import asyncio
+ import concurrent.futures
+ from logger import logger
+ from mcp import MCP_AVAILABLE, get_mcp_session, get_cached_mcp_tools, call_agent
+ from config import GEMINI_MODEL
+
+ try:
+     import nest_asyncio
+ except ImportError:
+     nest_asyncio = None
+
+
+ async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
+     """Search web using MCP web search tool (e.g., DuckDuckGo MCP server)"""
+     if not MCP_AVAILABLE:
+         return []
+
+     try:
+         tools = await get_cached_mcp_tools()
+         if not tools:
+             return []
+
+         search_tool = None
+         for tool in tools:
+             tool_name_lower = tool.name.lower()
+             if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
+                 search_tool = tool
+                 logger.info(f"Found web search MCP tool: {tool.name}")
+                 break
+
+         if not search_tool:
+             tools = await get_cached_mcp_tools(force_refresh=True)
+             for tool in tools:
+                 tool_name_lower = tool.name.lower()
+                 if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
+                     search_tool = tool
+                     logger.info(f"Found web search MCP tool after refresh: {tool.name}")
+                     break
+
+         if search_tool:
+             try:
+                 session = await get_mcp_session()
+                 if session is None:
+                     return []
+
+                 result = await session.call_tool(
+                     search_tool.name,
+                     arguments={"query": query, "max_results": max_results}
+                 )
+
+                 web_content = []
+                 if hasattr(result, 'content') and result.content:
+                     for item in result.content:
+                         if hasattr(item, 'text'):
+                             try:
+                                 data = json.loads(item.text)
+                                 if isinstance(data, list):
+                                     for entry in data[:max_results]:
+                                         web_content.append({
+                                             'title': entry.get('title', ''),
+                                             'url': entry.get('url', entry.get('href', '')),
+                                             'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
+                                         })
+                                 elif isinstance(data, dict):
+                                     if 'results' in data:
+                                         for entry in data['results'][:max_results]:
+                                             web_content.append({
+                                                 'title': entry.get('title', ''),
+                                                 'url': entry.get('url', entry.get('href', '')),
+                                                 'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
+                                             })
+                                     else:
+                                         web_content.append({
+                                             'title': data.get('title', ''),
+                                             'url': data.get('url', data.get('href', '')),
+                                             'content': data.get('body', data.get('snippet', data.get('content', '')))
+                                         })
+                             except json.JSONDecodeError:
+                                 web_content.append({
+                                     'title': '',
+                                     'url': '',
+                                     'content': item.text[:1000]
+                                 })
+
+                 if web_content:
+                     return web_content
+             except Exception as e:
+                 logger.error(f"Error calling web search MCP tool: {e}")
+
+         else:
+             logger.debug("No MCP web search tool discovered in current catalog")
+         return []
+     except Exception as e:
+         logger.error(f"Web search MCP tool error: {e}")
+         return []
+
+
+ async def search_web_mcp(query: str, max_results: int = 5) -> list:
+     """Search web using MCP tools - tries web search MCP tool first, then falls back to direct search"""
+     results = await search_web_mcp_tool(query, max_results)
+     if results:
+         logger.info(f"✅ Web search via MCP tool: found {len(results)} results")
+         return results
+
+     logger.info("ℹ️ [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)")
+     return search_web_fallback(query, max_results)
+
+
+ def search_web_fallback(query: str, max_results: int = 5) -> list:
+     """Fallback web search using DuckDuckGo directly (when MCP is not available)"""
+     logger.info(f"🔍 [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...")
+     try:
+         from ddgs import DDGS
+         import requests
+         from bs4 import BeautifulSoup
+     except ImportError:
+         logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available")
+         return []
+
+     try:
+         with DDGS() as ddgs:
+             results = list(ddgs.text(query, max_results=max_results))
+         web_content = []
+         for result in results:
+             try:
+                 url = result.get('href', '')
+                 title = result.get('title', '')
+                 snippet = result.get('body', '')
+
+                 try:
+                     response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'})
+                     if response.status_code == 200:
+                         soup = BeautifulSoup(response.content, 'html.parser')
+                         for script in soup(["script", "style"]):
+                             script.decompose()
+                         text = soup.get_text()
+                         lines = (line.strip() for line in text.splitlines())
+                         chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
+                         text = ' '.join(chunk for chunk in chunks if chunk)
+                         if len(text) > 1000:
+                             text = text[:1000] + "..."
+                         web_content.append({
+                             'title': title,
+                             'url': url,
+                             'content': snippet + "\n" + text[:500] if text else snippet
+                         })
+                     else:
+                         web_content.append({
+                             'title': title,
+                             'url': url,
+                             'content': snippet
+                         })
+                 except:
+                     web_content.append({
+                         'title': title,
+                         'url': url,
+                         'content': snippet
+                     })
+             except Exception as e:
+                 logger.error(f"Error processing search result: {e}")
+                 continue
+         logger.info(f"✅ [Direct API] Web search completed: {len(web_content)} results")
+         return web_content
+     except Exception as e:
+         logger.error(f"❌ [Direct API] Web search error: {e}")
+         return []
+
+
+ def search_web(query: str, max_results: int = 5) -> list:
+     """Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs"""
+     if MCP_AVAILABLE:
+         try:
+             try:
+                 loop = asyncio.get_event_loop()
+             except RuntimeError:
+                 loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(loop)
+
+             if loop.is_running():
+                 if nest_asyncio:
+                     # nest_asyncio exposes apply(), not run(); patch the loop, then re-enter it
+                     nest_asyncio.apply(loop)
+                     results = loop.run_until_complete(search_web_mcp(query, max_results))
+                     if results:
+                         return results
+                 else:
+                     with concurrent.futures.ThreadPoolExecutor() as executor:
+                         future = executor.submit(asyncio.run, search_web_mcp(query, max_results))
+                         results = future.result(timeout=30)
+                         if results:
+                             return results
+             else:
+                 results = loop.run_until_complete(search_web_mcp(query, max_results))
+                 if results:
+                     return results
+         except Exception as e:
+             logger.error(f"Error running async MCP search: {e}")
+
+     logger.info("ℹ️ [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)")
+     return search_web_fallback(query, max_results)
+
+
+ async def summarize_web_content_gemini(content_list: list, query: str) -> str:
+     """Summarize web search results using Gemini MCP"""
+     combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
+
+     user_prompt = f"""Summarize the following web search results related to the query: "{query}"
+ Extract key medical information, facts, and insights. Be concise and focus on reliable information.
+ Search Results:
+ {combined_content}
+ Summary:"""
+
+     system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately."
+
+     result = await call_agent(
+         user_prompt=user_prompt,
+         system_prompt=system_prompt,
+         model=GEMINI_MODEL,
+         temperature=0.5
+     )
+
+     return result.strip()
+
+
+ def summarize_web_content(content_list: list, query: str) -> str:
+     """Summarize web search results using Gemini MCP"""
227
+ if not MCP_AVAILABLE:
228
+ logger.warning("Gemini MCP not available for summarization")
229
+ if content_list:
230
+ return content_list[0].get('content', '')[:500]
231
+ return ""
232
+
233
+ try:
234
+ loop = asyncio.get_event_loop()
235
+ if loop.is_running():
236
+ if nest_asyncio:
237
+ summary = nest_asyncio.run(summarize_web_content_gemini(content_list, query))
238
+ if summary:
239
+ return summary
240
+ else:
241
+ logger.error("Error in nested async summarization: nest_asyncio not available")
242
+ else:
243
+ summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query))
244
+ if summary:
245
+ return summary
246
+ except Exception as e:
247
+ logger.error(f"Gemini MCP summarization error: {e}")
248
+
249
+ if content_list:
250
+ return content_list[0].get('content', '')[:500]
251
+ return ""
252
+
supervisor.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gemini Supervisor functions for MAC architecture"""
2
+ import json
3
+ import asyncio
4
+ import torch
5
+ import spaces
6
+ from logger import logger
7
+ from mcp import MCP_AVAILABLE, call_agent
8
+ from config import GEMINI_MODEL, GEMINI_MODEL_LITE
9
+ from utils import format_prompt_manually
10
+
11
+ try:
12
+ import nest_asyncio
13
+ except ImportError:
14
+ nest_asyncio = None
15
+
16
+
17
+ async def gemini_supervisor_breakdown_async(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
18
+ """Gemini Supervisor: Break user query into sub-topics"""
19
+ remaining_time = max(15, max_duration - time_elapsed)
20
+
21
+ mode_description = []
22
+ if use_rag:
23
+ mode_description.append("RAG mode enabled - will use retrieved documents")
24
+ if use_web_search:
25
+ mode_description.append("Web search mode enabled - will search online sources")
26
+ if not mode_description:
27
+ mode_description.append("Direct answer mode - no additional context")
28
+
29
+ estimated_time_per_task = 8
30
+ max_topics_by_time = max(2, int((remaining_time - 20) / estimated_time_per_task))
31
+ max_topics = min(max_topics_by_time, 10)
32
+
33
+ prompt = f"""You are a supervisor agent coordinating with a MedSwin medical specialist model.
34
+ Break the following medical query into focused sub-topics that MedSwin can answer sequentially.
35
+ Explore different potential approaches to comprehensively address the topic.
36
+
37
+ Query: "{query}"
38
+ Mode: {', '.join(mode_description)}
39
+ Time Remaining: ~{remaining_time:.1f}s
40
+ Maximum Topics: {max_topics} (adjust based on complexity - use as many as needed for thorough coverage)
41
+
42
+ Return ONLY valid JSON (no markdown, no tables, no explanations):
43
+ {{
44
+ "sub_topics": [
45
+ {{
46
+ "id": 1,
47
+ "topic": "concise topic name",
48
+ "instruction": "specific directive for MedSwin to answer this topic",
49
+ "expected_tokens": 200,
50
+ "priority": "high|medium|low",
51
+ "approach": "brief description of approach/angle for this topic"
52
+ }},
53
+ ...
54
+ ],
55
+ "strategy": "brief strategy description explaining the breakdown approach",
56
+ "exploration_note": "brief note on different approaches explored"
57
+ }}
58
+
59
+ Guidelines:
60
+ - Break down the query into as many subtasks as needed for comprehensive coverage
61
+ - Explore different angles/approaches (e.g., clinical, diagnostic, treatment, prevention, research perspectives)
62
+ - Each topic should be focused and answerable in ~200 tokens by MedSwin
63
+ - Prioritize topics by importance (high priority first)
64
+ - Don't limit yourself to 4 topics - use more if the query is complex or multi-faceted"""
65
+
66
+ system_prompt = "You are a medical query supervisor. Break queries into structured JSON sub-topics, exploring different approaches. Return ONLY valid JSON."
67
+
68
+ response = await call_agent(
69
+ user_prompt=prompt,
70
+ system_prompt=system_prompt,
71
+ model=GEMINI_MODEL,
72
+ temperature=0.3
73
+ )
74
+
75
+ try:
76
+ json_start = response.find('{')
77
+ json_end = response.rfind('}') + 1
78
+ if json_start >= 0 and json_end > json_start:
79
+ breakdown = json.loads(response[json_start:json_end])
80
+ logger.info(f"[GEMINI SUPERVISOR] Query broken into {len(breakdown.get('sub_topics', []))} sub-topics")
81
+ return breakdown
82
+ else:
83
+ raise ValueError("Supervisor JSON not found")
84
+ except Exception as exc:
85
+ logger.error(f"[GEMINI SUPERVISOR] Breakdown parsing failed: {exc}")
86
+ breakdown = {
87
+ "sub_topics": [
88
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
89
+ {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
90
+ ],
91
+ "strategy": "Sequential answer with key points",
92
+ "exploration_note": "Fallback breakdown - basic coverage"
93
+ }
94
+ logger.warning(f"[GEMINI SUPERVISOR] Using fallback breakdown")
95
+ return breakdown
96
+
97
+
98
+ async def gemini_supervisor_search_strategies_async(query: str, time_elapsed: float) -> dict:
99
+ """Gemini Supervisor: In search mode, break query into 1-4 searching strategies"""
100
+ prompt = f"""You are supervising web search for a medical query.
101
+ Break this query into 1-4 focused search strategies (each targeting 1-2 sources).
102
+
103
+ Query: "{query}"
104
+
105
+ Return ONLY valid JSON:
106
+ {{
107
+ "search_strategies": [
108
+ {{
109
+ "id": 1,
110
+ "strategy": "search query string",
111
+ "target_sources": 1,
112
+ "focus": "what to search for"
113
+ }},
114
+ ...
115
+ ],
116
+ "max_strategies": 4
117
+ }}
118
+
119
+ Keep strategies focused and avoid overlap."""
120
+
121
+ system_prompt = "You are a search strategy supervisor. Create focused search queries. Return ONLY valid JSON."
122
+
123
+ response = await call_agent(
124
+ user_prompt=prompt,
125
+ system_prompt=system_prompt,
126
+ model=GEMINI_MODEL_LITE,
127
+ temperature=0.2
128
+ )
129
+
130
+ try:
131
+ json_start = response.find('{')
132
+ json_end = response.rfind('}') + 1
133
+ if json_start >= 0 and json_end > json_start:
134
+ strategies = json.loads(response[json_start:json_end])
135
+ logger.info(f"[GEMINI SUPERVISOR] Created {len(strategies.get('search_strategies', []))} search strategies")
136
+ return strategies
137
+ else:
138
+ raise ValueError("Search strategies JSON not found")
139
+ except Exception as exc:
140
+ logger.error(f"[GEMINI SUPERVISOR] Search strategies parsing failed: {exc}")
141
+ return {
142
+ "search_strategies": [
143
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
144
+ ],
145
+ "max_strategies": 1
146
+ }
147
+
148
+
149
+ async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
150
+ """Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts"""
151
+ max_doc_length = 3000
152
+ if len(retrieved_docs) > max_doc_length:
153
+ retrieved_docs = retrieved_docs[:max_doc_length] + "..."
154
+
155
+ prompt = f"""You are supervising RAG context preparation for a medical query.
156
+ Brainstorm the retrieved documents into 1-4 concise, focused contexts that MedSwin can use.
157
+
158
+ Query: "{query}"
159
+ Retrieved Documents:
160
+ {retrieved_docs}
161
+
162
+ Return ONLY valid JSON:
163
+ {{
164
+ "contexts": [
165
+ {{
166
+ "id": 1,
167
+ "context": "concise summary of relevant information (keep under 500 chars)",
168
+ "focus": "what this context covers",
169
+ "relevance": "high|medium|low"
170
+ }},
171
+ ...
172
+ ],
173
+ "max_contexts": 4
174
+ }}
175
+
176
+ Keep contexts brief and factual. Avoid redundancy."""
177
+
178
+ system_prompt = "You are a RAG context supervisor. Summarize documents into concise contexts. Return ONLY valid JSON."
179
+
180
+ response = await call_agent(
181
+ user_prompt=prompt,
182
+ system_prompt=system_prompt,
183
+ model=GEMINI_MODEL_LITE,
184
+ temperature=0.2
185
+ )
186
+
187
+ try:
188
+ json_start = response.find('{')
189
+ json_end = response.rfind('}') + 1
190
+ if json_start >= 0 and json_end > json_start:
191
+ contexts = json.loads(response[json_start:json_end])
192
+ logger.info(f"[GEMINI SUPERVISOR] Brainstormed {len(contexts.get('contexts', []))} RAG contexts")
193
+ return contexts
194
+ else:
195
+ raise ValueError("RAG contexts JSON not found")
196
+ except Exception as exc:
197
+ logger.error(f"[GEMINI SUPERVISOR] RAG brainstorming parsing failed: {exc}")
198
+ return {
199
+ "contexts": [
200
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
201
+ ],
202
+ "max_contexts": 1
203
+ }
204
+
205
+
206
+ def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
207
+ """Wrapper to obtain supervisor breakdown synchronously"""
208
+ if not MCP_AVAILABLE:
209
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable, using fallback breakdown")
210
+ return {
211
+ "sub_topics": [
212
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
213
+ {"id": 2, "topic": "Clinical Details", "instruction": "Provide key clinical insights", "expected_tokens": 200, "priority": "medium", "approach": "clinical perspective"},
214
+ ],
215
+ "strategy": "Sequential answer with key points",
216
+ "exploration_note": "Fallback breakdown - basic coverage"
217
+ }
218
+
219
+ try:
220
+ loop = asyncio.get_event_loop()
221
+ if loop.is_running():
222
+ if nest_asyncio:
223
+ return nest_asyncio.run(
224
+ gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
225
+ )
226
+ else:
227
+ logger.error("[GEMINI SUPERVISOR] Nested breakdown execution failed: nest_asyncio not available")
228
+ return loop.run_until_complete(
229
+ gemini_supervisor_breakdown_async(query, use_rag, use_web_search, time_elapsed, max_duration)
230
+ )
231
+ except Exception as exc:
232
+ logger.error(f"[GEMINI SUPERVISOR] Breakdown request failed: {exc}")
233
+ return {
234
+ "sub_topics": [
235
+ {"id": 1, "topic": "Core Question", "instruction": "Address the main medical question", "expected_tokens": 200, "priority": "high", "approach": "direct answer"},
236
+ ],
237
+ "strategy": "Direct answer",
238
+ "exploration_note": "Fallback breakdown - single topic"
239
+ }
240
+
241
+
242
+ def gemini_supervisor_search_strategies(query: str, time_elapsed: float) -> dict:
243
+ """Wrapper to obtain search strategies synchronously"""
244
+ if not MCP_AVAILABLE:
245
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable for search strategies")
246
+ return {
247
+ "search_strategies": [
248
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
249
+ ],
250
+ "max_strategies": 1
251
+ }
252
+
253
+ try:
254
+ loop = asyncio.get_event_loop()
255
+ if loop.is_running():
256
+ if nest_asyncio:
257
+ return nest_asyncio.run(gemini_supervisor_search_strategies_async(query, time_elapsed))
258
+ else:
259
+ logger.error("[GEMINI SUPERVISOR] Nested search strategies execution failed: nest_asyncio not available")
260
+ return loop.run_until_complete(gemini_supervisor_search_strategies_async(query, time_elapsed))
261
+ except Exception as exc:
262
+ logger.error(f"[GEMINI SUPERVISOR] Search strategies request failed: {exc}")
263
+ return {
264
+ "search_strategies": [
265
+ {"id": 1, "strategy": query, "target_sources": 2, "focus": "main query"}
266
+ ],
267
+ "max_strategies": 1
268
+ }
269
+
270
+
271
+ def gemini_supervisor_rag_brainstorm(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
272
+ """Wrapper to obtain RAG brainstorm synchronously"""
273
+ if not MCP_AVAILABLE:
274
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable for RAG brainstorm")
275
+ return {
276
+ "contexts": [
277
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
278
+ ],
279
+ "max_contexts": 1
280
+ }
281
+
282
+ try:
283
+ loop = asyncio.get_event_loop()
284
+ if loop.is_running():
285
+ if nest_asyncio:
286
+ return nest_asyncio.run(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
287
+ else:
288
+ logger.error("[GEMINI SUPERVISOR] Nested RAG brainstorm execution failed: nest_asyncio not available")
289
+ return loop.run_until_complete(gemini_supervisor_rag_brainstorm_async(query, retrieved_docs, time_elapsed))
290
+ except Exception as exc:
291
+ logger.error(f"[GEMINI SUPERVISOR] RAG brainstorm request failed: {exc}")
292
+ return {
293
+ "contexts": [
294
+ {"id": 1, "context": retrieved_docs[:500], "focus": "retrieved information", "relevance": "high"}
295
+ ],
296
+ "max_contexts": 1
297
+ }
298
+
299
+
300
+ @spaces.GPU(max_duration=120)
301
+ def execute_medswin_task(
302
+ medical_model_obj,
303
+ medical_tokenizer,
304
+ task_instruction: str,
305
+ context: str,
306
+ system_prompt_base: str,
307
+ temperature: float,
308
+ max_new_tokens: int,
309
+ top_p: float,
310
+ top_k: int,
311
+ penalty: float
312
+ ) -> str:
313
+ """MedSwin Specialist: Execute a single task assigned by Gemini Supervisor"""
314
+ if context:
315
+ full_prompt = f"{system_prompt_base}\n\nContext:\n{context}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
316
+ else:
317
+ full_prompt = f"{system_prompt_base}\n\nTask: {task_instruction}\n\nAnswer concisely with key bullet points (Markdown format, no tables):"
318
+
319
+ messages = [{"role": "system", "content": full_prompt}]
320
+
321
+ if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
322
+ try:
323
+ prompt = medical_tokenizer.apply_chat_template(
324
+ messages,
325
+ tokenize=False,
326
+ add_generation_prompt=True
327
+ )
328
+ except Exception as e:
329
+ logger.warning(f"[MEDSWIN] Chat template failed, using manual formatting: {e}")
330
+ prompt = format_prompt_manually(messages, medical_tokenizer)
331
+ else:
332
+ prompt = format_prompt_manually(messages, medical_tokenizer)
333
+
334
+ inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
335
+
336
+ eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
337
+
338
+ with torch.no_grad():
339
+ outputs = medical_model_obj.generate(
340
+ **inputs,
341
+ max_new_tokens=min(max_new_tokens, 800),
342
+ temperature=temperature,
343
+ top_p=top_p,
344
+ top_k=top_k,
345
+ repetition_penalty=penalty,
346
+ do_sample=True,
347
+ eos_token_id=eos_token_id,
348
+ pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
349
+ )
350
+
351
+ response = medical_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
352
+
353
+ response = response.strip()
354
+ if "|" in response and "---" in response:
355
+ logger.warning("[MEDSWIN] Detected table format, converting to Markdown bullets")
356
+ lines = [line.strip() for line in response.split('\n') if line.strip() and not line.strip().startswith('|') and '---' not in line]
357
+ response = '\n'.join([f"- {line}" if not line.startswith('-') else line for line in lines])
358
+
359
+ logger.info(f"[MEDSWIN] Task completed: {len(response)} chars generated")
360
+ return response
361
+
362
+
363
+ async def gemini_supervisor_synthesize_async(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
364
+ """Gemini Supervisor: Synthesize final answer from all MedSwin responses"""
365
+ context_summary = ""
366
+ if rag_contexts:
367
+ context_summary += f"Document Context Available: {len(rag_contexts)} context(s) from uploaded documents.\n"
368
+ if search_contexts:
369
+ context_summary += f"Web Search Context Available: {len(search_contexts)} search result(s).\n"
370
+
371
+ all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
372
+
373
+ prompt = f"""You are a supervisor agent synthesizing a comprehensive medical answer from multiple specialist responses.
374
+
375
+ Original Query: "{query}"
376
+
377
+ Context Available:
378
+ {context_summary}
379
+
380
+ MedSwin Specialist Responses (from {len(medswin_answers)} sub-topics):
381
+ {all_answers_text}
382
+
383
+ Your task:
384
+ 1. Synthesize all responses into a coherent, comprehensive final answer
385
+ 2. Integrate information from all sub-topics seamlessly
386
+ 3. Ensure the answer directly addresses the original query
387
+ 4. Maintain clinical accuracy and clarity
388
+ 5. Use clear structure with appropriate headings and bullet points
389
+ 6. Remove redundancy and contradictions
390
+ 7. Ensure all important points from MedSwin responses are included
391
+
392
+ Return the final synthesized answer in Markdown format. Do not add meta-commentary or explanations - just provide the final answer."""
393
+
394
+ system_prompt = "You are a medical answer synthesis supervisor. Create comprehensive, well-structured final answers from multiple specialist responses."
395
+
396
+ result = await call_agent(
397
+ user_prompt=prompt,
398
+ system_prompt=system_prompt,
399
+ model=GEMINI_MODEL,
400
+ temperature=0.3
401
+ )
402
+
403
+ return result.strip()
404
+
405
+
406
+ async def gemini_supervisor_challenge_async(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
407
+ """Gemini Supervisor: Challenge and evaluate the current answer"""
408
+ context_info = ""
409
+ if rag_contexts:
410
+ context_info += f"Document contexts: {len(rag_contexts)} available.\n"
411
+ if search_contexts:
412
+ context_info += f"Search contexts: {len(search_contexts)} available.\n"
413
+
414
+ all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
415
+
416
+ prompt = f"""You are a supervisor agent evaluating and challenging a medical answer for quality and completeness.
417
+
418
+ Original Query: "{query}"
419
+
420
+ Available Context:
421
+ {context_info}
422
+
423
+ MedSwin Specialist Responses:
424
+ {all_answers_text}
425
+
426
+ Current Synthesized Answer:
427
+ {current_answer[:2000]}
428
+
429
+ Evaluate this answer and provide:
430
+ 1. Completeness: Does it fully address the query? What's missing?
431
+ 2. Accuracy: Are there any inaccuracies or contradictions?
432
+ 3. Clarity: Is it well-structured and clear?
433
+ 4. Context Usage: Are document/search contexts properly utilized?
434
+ 5. Improvement Suggestions: Specific ways to enhance the answer
435
+
436
+ Return ONLY valid JSON:
437
+ {{
438
+ "is_optimal": true/false,
439
+ "completeness_score": 0-10,
440
+ "accuracy_score": 0-10,
441
+ "clarity_score": 0-10,
442
+ "missing_aspects": ["..."],
443
+ "inaccuracies": ["..."],
444
+ "improvement_suggestions": ["..."],
445
+ "needs_more_context": true/false,
446
+ "enhancement_instructions": "specific instructions for improving the answer"
447
+ }}"""
448
+
449
+ system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback in JSON format. Return ONLY valid JSON."
450
+
451
+ response = await call_agent(
452
+ user_prompt=prompt,
453
+ system_prompt=system_prompt,
454
+ model=GEMINI_MODEL,
455
+ temperature=0.3
456
+ )
457
+
458
+ try:
459
+ json_start = response.find('{')
460
+ json_end = response.rfind('}') + 1
461
+ if json_start >= 0 and json_end > json_start:
462
+ evaluation = json.loads(response[json_start:json_end])
463
+ logger.info(f"[GEMINI SUPERVISOR] Challenge evaluation: optimal={evaluation.get('is_optimal', False)}, scores={evaluation.get('completeness_score', 'N/A')}/{evaluation.get('accuracy_score', 'N/A')}/{evaluation.get('clarity_score', 'N/A')}")
464
+ return evaluation
465
+ else:
466
+ raise ValueError("Evaluation JSON not found")
467
+ except Exception as exc:
468
+ logger.error(f"[GEMINI SUPERVISOR] Challenge evaluation parsing failed: {exc}")
469
+ return {
470
+ "is_optimal": True,
471
+ "completeness_score": 7,
472
+ "accuracy_score": 7,
473
+ "clarity_score": 7,
474
+ "missing_aspects": [],
475
+ "inaccuracies": [],
476
+ "improvement_suggestions": [],
477
+ "needs_more_context": False,
478
+ "enhancement_instructions": ""
479
+ }
480
+
481
+
482
+ async def gemini_supervisor_enhance_answer_async(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
483
+ """Gemini Supervisor: Enhance the answer based on challenge feedback"""
484
+ context_info = ""
485
+ if rag_contexts:
486
+ context_info += f"Document contexts: {len(rag_contexts)} available.\n"
487
+ if search_contexts:
488
+ context_info += f"Search contexts: {len(search_contexts)} available.\n"
489
+
490
+ all_answers_text = "\n\n---\n\n".join([f"## {i+1}. {ans}" for i, ans in enumerate(medswin_answers)])
491
+
492
+ prompt = f"""You are a supervisor agent enhancing a medical answer based on evaluation feedback.
493
+
494
+ Original Query: "{query}"
495
+
496
+ Available Context:
497
+ {context_info}
498
+
499
+ MedSwin Specialist Responses:
500
+ {all_answers_text}
501
+
502
+ Current Answer (to enhance):
503
+ {current_answer}
504
+
505
+ Enhancement Instructions:
506
+ {enhancement_instructions}
507
+
508
+ Create an enhanced version of the answer that:
509
+ 1. Addresses all improvement suggestions
510
+ 2. Fills in missing aspects
511
+ 3. Corrects any inaccuracies
512
+ 4. Improves clarity and structure
513
+ 5. Better utilizes available context
514
+ 6. Maintains all valuable information from the current answer
515
+
516
+ Return the enhanced answer in Markdown format. Do not add meta-commentary."""
517
+
518
+ system_prompt = "You are a medical answer enhancement supervisor. Improve answers based on evaluation feedback while maintaining accuracy."
519
+
520
+ result = await call_agent(
521
+ user_prompt=prompt,
522
+ system_prompt=system_prompt,
523
+ model=GEMINI_MODEL,
524
+ temperature=0.3
525
+ )
526
+
527
+ return result.strip()
528
+
529
+
530
+ async def gemini_supervisor_check_clarity_async(query: str, answer: str, use_web_search: bool) -> dict:
531
+ """Gemini Supervisor: Check if answer is unclear or supervisor is unsure"""
532
+ if not use_web_search:
533
+ return {"is_unclear": False, "needs_search": False, "search_queries": []}
534
+
535
+ prompt = f"""You are a supervisor agent evaluating answer clarity and completeness.
536
+
537
+ Query: "{query}"
538
+
539
+ Current Answer:
540
+ {answer[:1500]}
541
+
542
+ Evaluate:
543
+ 1. Is the answer unclear or incomplete?
544
+ 2. Are there gaps that web search could fill?
545
+ 3. Is the supervisor (you) unsure about certain aspects?
546
+
547
+ Return ONLY valid JSON:
548
+ {{
549
+ "is_unclear": true/false,
550
+ "needs_search": true/false,
551
+ "uncertainty_areas": ["..."],
552
+ "search_queries": ["specific search queries to fill gaps"],
553
+ "rationale": "brief explanation"
554
+ }}
555
+
556
+ Only suggest search if the answer is genuinely unclear or has significant gaps that search could address."""
557
+
558
+ system_prompt = "You are a clarity evaluator. Assess if additional web search is needed. Return ONLY valid JSON."
559
+
560
+ response = await call_agent(
561
+ user_prompt=prompt,
562
+ system_prompt=system_prompt,
563
+ model=GEMINI_MODEL_LITE,
564
+ temperature=0.2
565
+ )
566
+
567
+ try:
568
+ json_start = response.find('{')
569
+ json_end = response.rfind('}') + 1
570
+ if json_start >= 0 and json_end > json_start:
571
+ evaluation = json.loads(response[json_start:json_end])
572
+ logger.info(f"[GEMINI SUPERVISOR] Clarity check: unclear={evaluation.get('is_unclear', False)}, needs_search={evaluation.get('needs_search', False)}")
573
+ return evaluation
574
+ else:
575
+ raise ValueError("Clarity check JSON not found")
576
+ except Exception as exc:
577
+ logger.error(f"[GEMINI SUPERVISOR] Clarity check parsing failed: {exc}")
578
+ return {"is_unclear": False, "needs_search": False, "search_queries": []}
579
+
580
+
581
+ def gemini_supervisor_synthesize(query: str, medswin_answers: list, rag_contexts: list, search_contexts: list, breakdown: dict) -> str:
582
+ """Wrapper to synthesize answer synchronously"""
583
+ if not MCP_AVAILABLE:
584
+ logger.warning("[GEMINI SUPERVISOR] MCP unavailable for synthesis, using simple concatenation")
585
+ return "\n\n".join(medswin_answers)
586
+
587
+ try:
588
+ loop = asyncio.get_event_loop()
589
+ if loop.is_running():
590
+ if nest_asyncio:
591
+ return nest_asyncio.run(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
592
+ else:
593
+ logger.error("[GEMINI SUPERVISOR] Nested synthesis failed: nest_asyncio not available")
594
+ return loop.run_until_complete(gemini_supervisor_synthesize_async(query, medswin_answers, rag_contexts, search_contexts, breakdown))
595
+ except Exception as exc:
596
+ logger.error(f"[GEMINI SUPERVISOR] Synthesis failed: {exc}")
597
+ return "\n\n".join(medswin_answers)
598
+
599
+
600
+ def gemini_supervisor_challenge(query: str, current_answer: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> dict:
601
+ """Wrapper to challenge answer synchronously"""
602
+ if not MCP_AVAILABLE:
603
+ return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
604
+
605
+ try:
606
+ loop = asyncio.get_event_loop()
607
+ if loop.is_running():
608
+ if nest_asyncio:
609
+ return nest_asyncio.run(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
610
+ else:
611
+ logger.error("[GEMINI SUPERVISOR] Nested challenge failed: nest_asyncio not available")
612
+ return loop.run_until_complete(gemini_supervisor_challenge_async(query, current_answer, medswin_answers, rag_contexts, search_contexts))
613
+ except Exception as exc:
614
+ logger.error(f"[GEMINI SUPERVISOR] Challenge failed: {exc}")
615
+ return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
616
+
617
+
618
+ def gemini_supervisor_enhance_answer(query: str, current_answer: str, enhancement_instructions: str, medswin_answers: list, rag_contexts: list, search_contexts: list) -> str:
619
+ """Wrapper to enhance answer synchronously"""
620
+ if not MCP_AVAILABLE:
621
+ return current_answer
622
+
623
+ try:
624
+ loop = asyncio.get_event_loop()
625
+ if loop.is_running():
626
+ if nest_asyncio:
627
+ return nest_asyncio.run(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
628
+ else:
629
+ logger.error("[GEMINI SUPERVISOR] Nested enhancement failed: nest_asyncio not available")
630
+ return loop.run_until_complete(gemini_supervisor_enhance_answer_async(query, current_answer, enhancement_instructions, medswin_answers, rag_contexts, search_contexts))
631
+ except Exception as exc:
632
+ logger.error(f"[GEMINI SUPERVISOR] Enhancement failed: {exc}")
633
+ return current_answer
634
+
635
+
636
+ def gemini_supervisor_check_clarity(query: str, answer: str, use_web_search: bool) -> dict:
637
+ """Wrapper to check clarity synchronously"""
638
+ if not MCP_AVAILABLE or not use_web_search:
639
+ return {"is_unclear": False, "needs_search": False, "search_queries": []}
640
+
641
+ try:
642
+ loop = asyncio.get_event_loop()
643
+ if loop.is_running():
644
+ if nest_asyncio:
645
+ return nest_asyncio.run(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
646
+ else:
647
+ logger.error("[GEMINI SUPERVISOR] Nested clarity check failed: nest_asyncio not available")
648
+ return loop.run_until_complete(gemini_supervisor_check_clarity_async(query, answer, use_web_search))
649
+ except Exception as exc:
650
+ logger.error(f"[GEMINI SUPERVISOR] Clarity check failed: {exc}")
651
+ return {"is_unclear": False, "needs_search": False, "search_queries": []}
652
+
653
+
654
+ async def self_reflection_gemini(answer: str, query: str) -> dict:
655
+ """Self-reflection using Gemini MCP"""
656
+ reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
657
+ Query: "{query}"
658
+ Answer: "{answer[:1000]}"
659
+ Evaluate:
660
+ 1. Completeness: Does it address all aspects of the query?
661
+ 2. Accuracy: Is the medical information accurate?
662
+ 3. Clarity: Is it clear and well-structured?
663
+ 4. Sources: Are sources cited appropriately?
664
+ 5. Missing Information: What important information might be missing?
665
+ Respond in JSON:
666
+ {{
667
+ "completeness_score": 0-10,
668
+ "accuracy_score": 0-10,
669
+ "clarity_score": 0-10,
670
+ "overall_score": 0-10,
671
+ "missing_aspects": ["..."],
672
+ "improvement_suggestions": ["..."]
673
+ }}"""
674
+
675
+ system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback."
676
+
677
+ response = await call_agent(
678
+ user_prompt=reflection_prompt,
679
+ system_prompt=system_prompt,
680
+ model=GEMINI_MODEL,
681
+ temperature=0.3
682
+ )
683
+
684
+ try:
685
+ json_start = response.find('{')
686
+ json_end = response.rfind('}') + 1
687
+ if json_start >= 0 and json_end > json_start:
688
+ reflection = json.loads(response[json_start:json_end])
689
+ else:
690
+ reflection = {"overall_score": 7, "improvement_suggestions": []}
691
+ except:
692
+ reflection = {"overall_score": 7, "improvement_suggestions": []}
693
+
694
+ logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}")
695
+ return reflection
696
+
697
+
698
+ def self_reflection(answer: str, query: str, reasoning: dict) -> dict:
699
+ """Self-reflection: Evaluate answer quality and completeness"""
700
+ if not MCP_AVAILABLE:
701
+ logger.warning("Gemini MCP not available for reflection, using fallback")
702
+ return {"overall_score": 7, "improvement_suggestions": []}
703
+
704
+ try:
705
+ loop = asyncio.get_event_loop()
706
+ if loop.is_running():
707
+ if nest_asyncio:
708
+ return nest_asyncio.run(self_reflection_gemini(answer, query))
709
+ else:
710
+ logger.error("Error in nested async reflection: nest_asyncio not available")
711
+ else:
712
+ return loop.run_until_complete(self_reflection_gemini(answer, query))
713
+ except Exception as e:
714
+ logger.error(f"Gemini MCP reflection error: {e}")
715
+
716
+ return {"overall_score": 7, "improvement_suggestions": []}
717
+
ui.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio UI setup"""
2
+ import time
3
+ import gradio as gr
4
+ from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
5
+ from indexing import create_or_update_index
6
+ from pipeline import stream_chat
7
+ from voice import transcribe_audio, generate_speech
8
+
9
+
10
+ def create_demo():
11
+ """Create and return Gradio demo interface"""
12
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
13
+ gr.HTML(TITLE)
14
+ gr.HTML(DESCRIPTION)
15
+
16
+ with gr.Row(elem_classes="main-container"):
17
+ with gr.Column(elem_classes="upload-section"):
18
+ file_upload = gr.File(
19
+ file_count="multiple",
20
+ label="Drag and Drop Files Here",
21
+ file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
22
+ elem_id="file-upload"
23
+ )
24
+ upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
25
+ status_output = gr.Textbox(
26
+ label="Status",
27
+ placeholder="Upload files to start...",
28
+ interactive=False
29
+ )
30
+ file_info_output = gr.HTML(
31
+ label="File Information",
32
+ elem_classes="processing-info"
33
+ )
34
+ upload_button.click(
35
+ fn=create_or_update_index,
36
+ inputs=[file_upload],
37
+ outputs=[status_output, file_info_output]
38
+ )
39
+
40
+ with gr.Column(elem_classes="chatbot-container"):
41
+ chatbot = gr.Chatbot(
42
+ height=500,
43
+ placeholder="Chat with MedSwin... Type your question below.",
44
+ show_label=False,
45
+ type="messages"
46
+ )
47
+ with gr.Row(elem_classes="input-row"):
48
+ message_input = gr.Textbox(
49
+ placeholder="Type your medical question here...",
50
+ show_label=False,
51
+ container=False,
52
+ lines=1,
53
+ scale=10
54
+ )
55
+ mic_button = gr.Audio(
56
+ sources=["microphone"],
57
+ type="filepath",
58
+ label="",
59
+ show_label=False,
60
+ container=False,
61
+ scale=1
62
+ )
63
+ submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
64
+
65
+ recording_timer = gr.Textbox(
66
+ value="",
67
+ label="",
68
+ show_label=False,
69
+ interactive=False,
70
+ visible=False,
71
+ container=False,
72
+ elem_classes="recording-timer"
73
+ )
74
+
75
+ recording_start_time = [None]
76
+
77
+ def handle_recording_start():
78
+ """Called when recording starts"""
79
+ recording_start_time[0] = time.time()
80
+ return gr.update(visible=True, value="Recording... 0s")
81
+
82
+ def handle_recording_stop(audio):
83
+ """Called when recording stops"""
84
+ recording_start_time[0] = None
85
+ if audio is None:
86
+ return gr.update(visible=False, value=""), ""
87
+ transcribed = transcribe_audio(audio)
88
+ return gr.update(visible=False, value=""), transcribed
89
+
90
+ mic_button.start_recording(
91
+ fn=handle_recording_start,
92
+ outputs=[recording_timer]
93
+ )
94
+
95
+ mic_button.stop_recording(
96
+ fn=handle_recording_stop,
97
+ inputs=[mic_button],
98
+ outputs=[recording_timer, message_input]
99
+ )
100
+
101
+ with gr.Row(visible=False) as tts_row:
102
+ tts_text = gr.Textbox(visible=False)
103
+ tts_audio = gr.Audio(label="Generated Speech", visible=False)
104
+
105
+ def generate_speech_from_chat(history):
106
+ """Extract last assistant message and generate speech"""
107
+ if not history or len(history) == 0:
108
+ return None
109
+ last_msg = history[-1]
110
+ if last_msg.get("role") == "assistant":
111
+ text = last_msg.get("content", "").replace(" 🔊", "").strip()
112
+ if text:
113
+ audio_path = generate_speech(text)
114
+ return audio_path
115
+ return None
116
+
117
+ tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
118
+
119
+ def update_tts_button(history):
120
+ if history and len(history) > 0 and history[-1].get("role") == "assistant":
121
+ return gr.update(visible=True)
122
+ return gr.update(visible=False)
123
+
124
+ chatbot.change(
125
+ fn=update_tts_button,
126
+ inputs=[chatbot],
127
+ outputs=[tts_button]
128
+ )
129
+
130
+ tts_button.click(
131
+ fn=generate_speech_from_chat,
132
+ inputs=[chatbot],
133
+ outputs=[tts_audio]
134
+ )
135
+
136
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
137
+ with gr.Row():
138
+ disable_agentic_reasoning = gr.Checkbox(
139
+ value=False,
140
+ label="Disable agentic reasoning",
141
+ info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
142
+ )
143
+ show_agentic_thought = gr.Button(
144
+ "Show agentic thought",
145
+ size="sm"
146
+ )
147
+ agentic_thoughts_box = gr.Textbox(
148
+ label="Agentic Thoughts",
149
+ placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
150
+ lines=8,
151
+ max_lines=15,
152
+ interactive=False,
153
+ visible=False,
154
+ elem_classes="agentic-thoughts"
155
+ )
156
+ with gr.Row():
157
+ use_rag = gr.Checkbox(
158
+ value=False,
159
+ label="Enable Document RAG",
160
+ info="Answer based on uploaded documents (upload required)"
161
+ )
162
+ use_web_search = gr.Checkbox(
163
+ value=False,
164
+ label="Enable Web Search (MCP)",
165
+ info="Fetch knowledge from online medical resources"
166
+ )
167
+
168
+ medical_model = gr.Radio(
169
+ choices=list(MEDSWIN_MODELS.keys()),
170
+ value=DEFAULT_MEDICAL_MODEL,
171
+ label="Medical Model",
172
+ info="MedSwin TA (default), others download on first use"
173
+ )
174
+
175
+ system_prompt = gr.Textbox(
176
+ value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
177
+ label="System Prompt",
178
+ lines=3
179
+ )
180
+
181
+ with gr.Tab("Generation Parameters"):
182
+ temperature = gr.Slider(
183
+ minimum=0,
184
+ maximum=1,
185
+ step=0.1,
186
+ value=0.2,
187
+ label="Temperature"
188
+ )
189
+ max_new_tokens = gr.Slider(
190
+ minimum=512,
191
+ maximum=4096,
192
+ step=128,
193
+ value=2048,
194
+ label="Max New Tokens",
195
+ info="Increased for medical models to prevent early stopping"
196
+ )
197
+ top_p = gr.Slider(
198
+ minimum=0.0,
199
+ maximum=1.0,
200
+ step=0.1,
201
+ value=0.7,
202
+ label="Top P"
203
+ )
204
+ top_k = gr.Slider(
205
+ minimum=1,
206
+ maximum=100,
207
+ step=1,
208
+ value=50,
209
+ label="Top K"
210
+ )
211
+ penalty = gr.Slider(
212
+ minimum=0.0,
213
+ maximum=2.0,
214
+ step=0.1,
215
+ value=1.2,
216
+ label="Repetition Penalty"
217
+ )
218
+
219
+ with gr.Tab("Retrieval Parameters"):
220
+ retriever_k = gr.Slider(
221
+ minimum=5,
222
+ maximum=30,
223
+ step=1,
224
+ value=15,
225
+ label="Initial Retrieval Size (Top K)"
226
+ )
227
+ merge_threshold = gr.Slider(
228
+ minimum=0.1,
229
+ maximum=0.9,
230
+ step=0.1,
231
+ value=0.5,
232
+ label="Merge Threshold (lower = more merging)"
233
+ )
234
+
235
+ show_thoughts_state = gr.State(value=False)
236
+
237
+ def toggle_thoughts_box(current_state):
238
+ """Toggle visibility of agentic thoughts box"""
239
+ new_state = not current_state
240
+ return gr.update(visible=new_state), new_state
241
+
242
+ show_agentic_thought.click(
243
+ fn=toggle_thoughts_box,
244
+ inputs=[show_thoughts_state],
245
+ outputs=[agentic_thoughts_box, show_thoughts_state]
246
+ )
247
+
248
+ submit_button.click(
249
+ fn=stream_chat,
250
+ inputs=[
251
+ message_input,
252
+ chatbot,
253
+ system_prompt,
254
+ temperature,
255
+ max_new_tokens,
256
+ top_p,
257
+ top_k,
258
+ penalty,
259
+ retriever_k,
260
+ merge_threshold,
261
+ use_rag,
262
+ medical_model,
263
+ use_web_search,
264
+ disable_agentic_reasoning,
265
+ show_thoughts_state
266
+ ],
267
+ outputs=[chatbot, agentic_thoughts_box]
268
+ )
269
+
270
+ message_input.submit(
271
+ fn=stream_chat,
272
+ inputs=[
273
+ message_input,
274
+ chatbot,
275
+ system_prompt,
276
+ temperature,
277
+ max_new_tokens,
278
+ top_p,
279
+ top_k,
280
+ penalty,
281
+ retriever_k,
282
+ merge_threshold,
283
+ use_rag,
284
+ medical_model,
285
+ use_web_search,
286
+ disable_agentic_reasoning,
287
+ show_thoughts_state
288
+ ],
289
+ outputs=[chatbot, agentic_thoughts_box]
290
+ )
291
+
292
+ return demo
293
+
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for translation, language detection, and formatting"""
2
+ import asyncio
3
+ from langdetect import detect, LangDetectException
4
+ from logger import logger
5
+ from mcp import MCP_AVAILABLE, call_agent
6
+ from config import GEMINI_MODEL_LITE
7
+
8
+ try:
9
+ import nest_asyncio
10
+ except ImportError:
11
+ nest_asyncio = None
12
+
13
+
14
+ def format_prompt_manually(messages: list, tokenizer) -> str:
15
+ """Manually format prompt for models without chat template"""
16
+ system_content = ""
17
+ user_content = ""
18
+
19
+ for msg in messages:
20
+ role = msg.get("role", "user")
21
+ content = msg.get("content", "")
22
+
23
+ if role == "system":
24
+ system_content = content
25
+ elif role == "user":
26
+ user_content = content
27
+
28
+ if system_content:
29
+ prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
30
+ else:
31
+ prompt = f"Question: {user_content}\n\nAnswer:"
32
+
33
+ return prompt
34
+
35
+
36
+ def detect_language(text: str) -> str:
37
+ """Detect language of input text"""
38
+ try:
39
+ lang = detect(text)
40
+ return lang
41
+ except LangDetectException:
42
+ return "en"
43
+
44
+
45
+ def format_url_as_domain(url: str) -> str:
46
+ """Format URL as simple domain name (e.g., www.mayoclinic.org)"""
47
+ if not url:
48
+ return ""
49
+ try:
50
+ from urllib.parse import urlparse
51
+ parsed = urlparse(url)
52
+ domain = parsed.netloc or parsed.path
53
+ if domain.startswith('www.'):
54
+ return domain
55
+ elif domain:
56
+ return domain
57
+ return url
58
+ except Exception:
59
+ if '://' in url:
60
+ domain = url.split('://')[1].split('/')[0]
61
+ return domain
62
+ return url
63
+
64
+
65
+ async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str:
66
+ """Translate text using Gemini MCP"""
67
+ if source_lang:
68
+ user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
69
+ else:
70
+ user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
71
+
72
+ system_prompt = "You are a professional translator. Translate accurately and concisely."
73
+
74
+ result = await call_agent(
75
+ user_prompt=user_prompt,
76
+ system_prompt=system_prompt,
77
+ model=GEMINI_MODEL_LITE,
78
+ temperature=0.2
79
+ )
80
+
81
+ return result.strip()
82
+
83
+
84
+ def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str:
85
+ """Translate text using Gemini MCP"""
86
+ if not MCP_AVAILABLE:
87
+ logger.warning("Gemini MCP not available for translation")
88
+ return text
89
+
90
+ try:
91
+ loop = asyncio.get_event_loop()
92
+ if loop.is_running():
93
+ if nest_asyncio:
94
+ translated = nest_asyncio.run(translate_text_gemini(text, target_lang, source_lang))
95
+ if translated:
96
+ logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
97
+ return translated
98
+ else:
99
+ logger.error("Error in nested async translation: nest_asyncio not available")
100
+ else:
101
+ translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang))
102
+ if translated:
103
+ logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
104
+ return translated
105
+ except Exception as e:
106
+ logger.error(f"Gemini MCP translation error: {e}")
107
+
108
+ return text
109
+
voice.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio transcription and text-to-speech functions"""
2
+ import os
3
+ import asyncio
4
+ import tempfile
5
+ import soundfile as sf
6
+ from logger import logger
7
+ from mcp import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
8
+ import config
9
+ from models import TTS_AVAILABLE, initialize_tts_model
10
+
11
+ try:
12
+ import nest_asyncio
13
+ except ImportError:
14
+ nest_asyncio = None
15
+
16
+
17
+ async def transcribe_audio_gemini(audio_path: str) -> str:
18
+ """Transcribe audio using Gemini MCP"""
19
+ if not MCP_AVAILABLE:
20
+ return ""
21
+
22
+ try:
23
+ audio_path_abs = os.path.abspath(audio_path)
24
+ files = [{"path": audio_path_abs}]
25
+
26
+ system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
27
+ user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
28
+
29
+ result = await call_agent(
30
+ user_prompt=user_prompt,
31
+ system_prompt=system_prompt,
32
+ files=files,
33
+ model=config.GEMINI_MODEL_LITE,
34
+ temperature=0.2
35
+ )
36
+
37
+ return result.strip()
38
+ except Exception as e:
39
+ logger.error(f"Gemini transcription error: {e}")
40
+ return ""
41
+
42
+
43
+ def transcribe_audio(audio):
44
+ """Transcribe audio to text using Gemini MCP"""
45
+ if audio is None:
46
+ return ""
47
+
48
+ try:
49
+ if isinstance(audio, str):
50
+ audio_path = audio
51
+ elif isinstance(audio, tuple):
52
+ sample_rate, audio_data = audio
53
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
54
+ sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
55
+ audio_path = tmp_file.name
56
+ else:
57
+ audio_path = audio
58
+
59
+ if MCP_AVAILABLE:
60
+ try:
61
+ loop = asyncio.get_event_loop()
62
+ if loop.is_running():
63
+ if nest_asyncio:
64
+ transcribed = nest_asyncio.run(transcribe_audio_gemini(audio_path))
65
+ if transcribed:
66
+ logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
67
+ return transcribed
68
+ else:
69
+ logger.error("nest_asyncio not available for nested async transcription")
70
+ else:
71
+ transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
72
+ if transcribed:
73
+ logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
74
+ return transcribed
75
+ except Exception as e:
76
+ logger.error(f"Gemini MCP transcription error: {e}")
77
+
78
+ logger.warning("Gemini MCP transcription not available")
79
+ return ""
80
+ except Exception as e:
81
+ logger.error(f"Transcription error: {e}")
82
+ return ""
83
+
84
+
85
+ async def generate_speech_mcp(text: str) -> str:
86
+ """Generate speech using MCP TTS tool"""
87
+ if not MCP_AVAILABLE:
88
+ return None
89
+
90
+ try:
91
+ session = await get_mcp_session()
92
+ if session is None:
93
+ return None
94
+
95
+ tools = await get_cached_mcp_tools()
96
+ tts_tool = None
97
+ for tool in tools:
98
+ tool_name_lower = tool.name.lower()
99
+ if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
100
+ tts_tool = tool
101
+ logger.info(f"Found MCP TTS tool: {tool.name}")
102
+ break
103
+
104
+ if tts_tool:
105
+ result = await session.call_tool(
106
+ tts_tool.name,
107
+ arguments={"text": text, "language": "en"}
108
+ )
109
+
110
+ if hasattr(result, 'content') and result.content:
111
+ for item in result.content:
112
+ if hasattr(item, 'text'):
113
+ if os.path.exists(item.text):
114
+ return item.text
115
+ elif hasattr(item, 'data') and item.data:
116
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
117
+ tmp_file.write(item.data)
118
+ return tmp_file.name
119
+ return None
120
+ except Exception as e:
121
+ logger.warning(f"MCP TTS error: {e}")
122
+ return None
123
+
124
+
125
+ def generate_speech(text: str):
126
+ """Generate speech from text using TTS model (with MCP fallback)"""
127
+ if not text or len(text.strip()) == 0:
128
+ return None
129
+
130
+ if MCP_AVAILABLE:
131
+ try:
132
+ loop = asyncio.get_event_loop()
133
+ if loop.is_running():
134
+ if nest_asyncio:
135
+ audio_path = nest_asyncio.run(generate_speech_mcp(text))
136
+ if audio_path:
137
+ logger.info("Generated speech via MCP")
138
+ return audio_path
139
+ else:
140
+ audio_path = loop.run_until_complete(generate_speech_mcp(text))
141
+ if audio_path:
142
+ return audio_path
143
+ except Exception as e:
144
+ pass
145
+
146
+ if not TTS_AVAILABLE:
147
+ logger.error("TTS library not installed. Please install TTS to use voice generation.")
148
+ return None
149
+
150
+ if config.global_tts_model is None:
151
+ initialize_tts_model()
152
+
153
+ if config.global_tts_model is None:
154
+ logger.error("TTS model not available. Please check dependencies.")
155
+ return None
156
+
157
+ try:
158
+ wav = config.global_tts_model.tts(text)
159
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
160
+ sf.write(tmp_file.name, wav, samplerate=22050)
161
+ return tmp_file.name
162
+ except Exception as e:
163
+ logger.error(f"TTS error: {e}")
164
+ return None
165
+