File size: 20,161 Bytes
ffcfd50
 
 
 
fc23c24
ffcfd50
 
 
fc23c24
 
 
 
 
 
 
 
ffcfd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f3ac98
ffcfd50
6ab08df
2f3ac98
ffcfd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a43fcc
ffcfd50
4a43fcc
ffcfd50
 
4a43fcc
 
ffcfd50
 
4a43fcc
 
ffcfd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc23c24
dd13e35
 
 
fc23c24
 
 
 
 
 
 
 
dd13e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc23c24
dd13e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc23c24
dd13e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc23c24
 
ffcfd50
fc23c24
 
 
 
 
 
 
ffcfd50
 
 
 
 
 
 
 
 
 
 
6ab08df
ffcfd50
 
4a43fcc
ffcfd50
 
6ab08df
 
ffcfd50
 
4a43fcc
ffcfd50
 
6ab08df
4a43fcc
 
 
ffcfd50
 
6ab08df
 
ffcfd50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ab08df
 
 
 
7a7ea02
 
 
 
6ab08df
7a7ea02
6ab08df
 
7a7ea02
6ab08df
 
 
7a7ea02
 
 
6ab08df
7a7ea02
6ab08df
 
fc23c24
 
 
 
 
 
 
6ab08df
 
 
 
 
 
 
fc23c24
 
 
 
 
 
 
6ab08df
ffcfd50
 
 
 
 
4a43fcc
6ab08df
4a43fcc
ffcfd50
 
 
4a43fcc
 
6ab08df
4a43fcc
6ab08df
4a43fcc
 
fc23c24
 
 
 
 
 
 
ffcfd50
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
"""MCP session management and tool caching"""
import os
import time
import asyncio
import base64
from logger import logger
import config

# Direct Gemini API imports
# GEMINI_DIRECT_AVAILABLE records whether the optional google-genai package
# could be imported; it stays False when the import raises ImportError.
GEMINI_DIRECT_AVAILABLE = False
try:
    from google import genai
    GEMINI_DIRECT_AVAILABLE = True
except ImportError:
    GEMINI_DIRECT_AVAILABLE = False

# MCP imports
# Both flags start in the "unavailable" state. MCP_AVAILABLE is switched on
# once the MCP SDK imports below succeed; any failure resets both values.
MCP_CLIENT_INFO = None
MCP_AVAILABLE = False
try:
    from mcp import ClientSession, StdioServerParameters
    from mcp import types as mcp_types
    from mcp.client.stdio import stdio_client
    MCP_AVAILABLE = True
    # nest_asyncio is optional: applied when installed, silently skipped otherwise.
    # NOTE(review): presumably needed because the host app already runs an
    # event loop - confirm.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass
    # Client identity passed to MCP sessions; the version is read from the
    # SPACE_VERSION environment variable and defaults to "local".
    MCP_CLIENT_INFO = mcp_types.Implementation(
        name="MedLLM-Agent",
        version=os.environ.get("SPACE_VERSION", "local"),
    )
    logger.info("βœ… MCP SDK imported successfully")
except ImportError as e:
    # SDK missing: log how to install it and leave MCP disabled.
    logger.warning(f"❌ MCP SDK import failed: {e}")
    logger.info("   Install with: pip install mcp>=0.1.0")
    logger.info("   The app will continue to work with fallback functionality (direct API calls)")
    MCP_AVAILABLE = False
    MCP_CLIENT_INFO = None
except Exception as e:
    # Any other failure during setup (for example while building the client
    # info object) also disables MCP instead of aborting the module import.
    logger.error(f"❌ Unexpected error initializing MCP: {type(e).__name__}: {e}")
    logger.info("   The app will continue to work with fallback functionality")
    MCP_AVAILABLE = False
    MCP_CLIENT_INFO = None


async def get_mcp_session():
    """Return the shared MCP client session, creating it on first use.

    A session already stored in ``config.global_mcp_session`` is reused.
    Otherwise the MCP server is started over stdio
    (``config.MCP_SERVER_COMMAND`` with ``config.MCP_SERVER_ARGS``), a
    ``ClientSession`` is opened and initialized, and both the session and the
    stdio context are stored on ``config`` for later reuse and cleanup.

    The context managers are entered manually (``__aenter__``) because the
    session must outlive this call, so every failure path exits whatever was
    already entered.

    Returns:
        The initialized session, or ``None`` when the MCP SDK is unavailable,
        ``config.GEMINI_API_KEY`` is not set, or the session cannot be created.
    """
    if not MCP_AVAILABLE:
        logger.warning("MCP not available - SDK not installed")
        return None

    # Reuse existing session if available
    if config.global_mcp_session is not None:
        logger.debug("Reusing existing MCP session")
        return config.global_mcp_session

    stdio_ctx = None
    stdio_entered = False  # True once the server transport has been opened
    try:
        if not config.GEMINI_API_KEY:
            logger.warning("❌ GEMINI_API_KEY not set in environment. Gemini MCP features will not work.")
            logger.warning("   Set it with: export GEMINI_API_KEY='your-api-key'")
            return None

        # The server inherits the current environment, so GEMINI_MODEL,
        # GEMINI_TIMEOUT, GEMINI_MAX_OUTPUT_TOKENS and GEMINI_TEMPERATURE are
        # passed through by the copy; only the API key is set explicitly.
        mcp_env = os.environ.copy()
        mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
        # Never write any part of the key to the logs.
        logger.info("βœ… GEMINI_API_KEY found (value redacted)")

        logger.info(f"Creating MCP client session... (command: {config.MCP_SERVER_COMMAND} {config.MCP_SERVER_ARGS})")

        server_params = StdioServerParameters(
            command=config.MCP_SERVER_COMMAND,
            args=config.MCP_SERVER_ARGS,
            env=mcp_env
        )

        stdio_ctx = stdio_client(server_params)
        read, write = await stdio_ctx.__aenter__()
        stdio_entered = True

        session = ClientSession(
            read,
            write,
            client_info=MCP_CLIENT_INFO,
        )

        try:
            await session.__aenter__()
            init_result = await session.initialize()
            server_info = getattr(init_result, "serverInfo", None)
            server_name = getattr(server_info, "name", "unknown")
            server_version = getattr(server_info, "version", "unknown")
            logger.info(f"βœ… MCP session initialized successfully (server={server_name} v{server_version})")
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
            logger.error("   This might be due to:")
            logger.error("   - Invalid GEMINI_API_KEY")
            logger.error("   - agent.py server not starting correctly")
            logger.error("   - Network/firewall issues")
            logger.error("   - MCP server process crashed or timed out")
            import traceback
            logger.error(f"   Full traceback: {traceback.format_exc()}")
            # Best-effort teardown, innermost context first.
            try:
                await session.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Session cleanup error (ignored): {cleanup_error}")
            try:
                await stdio_ctx.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
            return None

        config.global_mcp_session = session
        config.global_mcp_stdio_ctx = stdio_ctx
        logger.info("βœ… MCP client session created successfully")
        return session
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
        # If the transport was opened before the failure, close it so the
        # server process is not left behind.
        if stdio_entered:
            try:
                await stdio_ctx.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
        config.global_mcp_session = None
        config.global_mcp_stdio_ctx = None
        return None


def invalidate_mcp_tools_cache():
    """Reset ``config.global_mcp_tools_cache`` to its empty state (timestamp 0.0, no tools)."""
    empty_cache = dict(timestamp=0.0, tools=None)
    config.global_mcp_tools_cache = empty_cache


async def get_cached_mcp_tools(force_refresh: bool = False):
    """Return the MCP tool list, served from the cache while it is still fresh.

    The cached list is returned when ``force_refresh`` is false, the cache
    holds a non-empty list, and its age is below ``config.MCP_TOOLS_CACHE_TTL``.
    Otherwise the tools are fetched through the MCP session and the cache is
    replaced. An empty list is returned when MCP is unavailable, no session
    can be obtained, or the fetch fails (the cache is invalidated in that case).
    """
    if not MCP_AVAILABLE:
        return []

    started_at = time.time()
    if not force_refresh:
        cached_tools = config.global_mcp_tools_cache["tools"]
        if cached_tools:
            cache_age = started_at - config.global_mcp_tools_cache["timestamp"]
            if cache_age < config.MCP_TOOLS_CACHE_TTL:
                return config.global_mcp_tools_cache["tools"]

    session = await get_mcp_session()
    if session is None:
        return []

    try:
        response = await session.list_tools()
        raw_tools = getattr(response, "tools", [])
        fetched = list(raw_tools or [])
        # Stamp the cache with the time taken before the request was issued.
        config.global_mcp_tools_cache = {"timestamp": started_at, "tools": fetched}
        return fetched
    except Exception as e:
        logger.error(f"Failed to refresh MCP tools: {e}")
        invalidate_mcp_tools_cache()
        return []


async def test_mcp_connection() -> bool:
    """Check MCP connectivity by opening a session and listing its tools.

    Returns True only when a session is obtained and at least one tool is
    found; every other outcome (SDK missing, no API key, no session, no tools,
    or an exception) is logged and reported as False.
    """
    if not MCP_AVAILABLE:
        logger.warning("Cannot test MCP: SDK not available")
        return False

    if not config.GEMINI_API_KEY:
        logger.warning("Cannot test MCP: GEMINI_API_KEY not set")
        return False

    try:
        if await get_mcp_session() is None:
            logger.warning("MCP connection test failed: Could not create session")
            return False

        # Listing the tools doubles as the connectivity probe.
        found_tools = await get_cached_mcp_tools()
        if not found_tools:
            logger.warning("MCP connection test: Session created but no tools found")
            return False

        logger.info(f"βœ… MCP connection test successful! Found {len(found_tools)} tools")
        return True
    except Exception as e:
        logger.error(f"MCP connection test failed: {type(e).__name__}: {e}")
        return False


def _build_gemini_contents(user_prompt, system_prompt, files):
    """Build the ``contents`` list for Gemini: the prompt text plus one inline part per usable file."""
    prompt_text = f"{system_prompt}\n\n{user_prompt}" if system_prompt else user_prompt
    gemini_contents = [prompt_text]

    for file_obj in files or []:
        try:
            if "path" in file_obj:
                file_path = file_obj["path"]
                mime_type = file_obj.get("type")

                if not os.path.exists(file_path):
                    logger.warning(f"File not found: {file_path}")
                    continue

                with open(file_path, 'rb') as f:
                    file_data = f.read()

                if not mime_type:
                    from mimetypes import guess_type
                    mime_type = guess_type(file_path)[0] or "application/octet-stream"

                gemini_contents.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": base64.b64encode(file_data).decode('utf-8')
                    }
                })
            elif "content" in file_obj:
                # Decoded only so a malformed payload raises and the file is
                # skipped; the original base64 text is what gets sent.
                base64.b64decode(file_obj["content"])
                gemini_contents.append({
                    "inline_data": {
                        "mime_type": file_obj.get("type", "application/octet-stream"),
                        "data": file_obj["content"]
                    }
                })
        except Exception as e:
            # A bad file is skipped; the remaining files are still attached.
            logger.warning(f"Error processing file: {e}")

    return gemini_contents


def _extract_gemini_text(response):
    """Return the stripped text of a Gemini response, or ``None`` when it carries no text."""
    if not response:
        return None

    direct_text = getattr(response, "text", None)
    if direct_text:
        return direct_text.strip()

    # Fall back to concatenating the text of every candidate part. Missing
    # content/parts and parts without text are skipped instead of raising.
    chunks = []
    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        for part in getattr(content, "parts", None) or []:
            part_text = getattr(part, "text", None)
            if part_text is not None:
                chunks.append(part_text)

    if chunks:
        return ''.join(chunks).strip()
    return None


async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call the Gemini API directly, without going through MCP.

    Args:
        user_prompt: Prompt text sent to the model.
        system_prompt: Optional text placed before ``user_prompt``, separated
            by a blank line.
        files: Optional list of dicts. An entry with a ``"path"`` key is read
            from disk; an entry with a ``"content"`` key supplies base64 text.
            ``"type"`` optionally gives the MIME type. Entries that cannot be
            processed are skipped with a warning.
        model: Model name; defaults to ``config.GEMINI_MODEL``.
        temperature: Sampling temperature; ``None`` falls back to 0.2.

    Returns:
        The stripped response text, or ``""`` when the client library or API
        key is missing, the response is empty, or the call fails.

    The call is attempted up to 3 times. A timeout, or an error whose message
    mentions a GPU abort, is retried with exponential backoff (1s, then 2s);
    any other error ends the call immediately.
    """
    if not GEMINI_DIRECT_AVAILABLE:
        logger.error("❌ google-genai not installed - cannot use direct API")
        return ""

    if not config.GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
        return ""

    max_retries = 3
    base_delay = 1.0  # Base delay in seconds

    # The request payload is identical for every attempt, so build it once
    # instead of re-reading and re-encoding the files on each retry.
    try:
        model_name = model or config.GEMINI_MODEL
        temp = temperature if temperature is not None else 0.2
        gemini_contents = _build_gemini_contents(user_prompt, system_prompt, files)
        generation_config = {
            "temperature": temp,
            "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
        }
        # GEMINI_TIMEOUT is in milliseconds; the wait is capped at 20 seconds.
        timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)
    except Exception as e:
        logger.error(f"Failed to prepare Gemini API request: {type(e).__name__}: {e}")
        return ""

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            # A fresh client per attempt: a timed-out attempt's worker thread
            # may still be running and keeps using its own client.
            gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)

            logger.info(f"πŸ”΅ Calling Gemini API directly with model={model_name}, temperature={temp}")

            # generate_content blocks, so it runs in a worker thread.
            response = await asyncio.wait_for(
                asyncio.to_thread(
                    gemini_client.models.generate_content,
                    model=model_name,
                    contents=gemini_contents,
                    config=generation_config,
                ),
                timeout=timeout_seconds
            )

            logger.info("βœ… Gemini API call completed successfully")

            response_text = _extract_gemini_text(response)
            if response_text is not None:
                return response_text

            logger.warning("⚠️ Gemini API returned empty response")
            return ""

        except asyncio.TimeoutError:
            if is_last_attempt:
                logger.error(f"❌ Gemini API call timed out after {max_retries} attempts")
                return ""
            delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s
            logger.warning(f"⏳ Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
            await asyncio.sleep(delay)
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e).lower()
            is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg)

            if is_gpu_error and not is_last_attempt:
                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s
                logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
                await asyncio.sleep(delay)
                continue

            logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}")
            if is_last_attempt:
                import traceback
                logger.error(f"Full traceback: {traceback.format_exc()}")
            return ""

    # Only reachable if the retry budget is ever configured as zero.
    return ""


async def _reset_mcp_session():
    """Forget the shared MCP session and close it together with its stdio transport (best effort)."""
    session = config.global_mcp_session
    stdio_ctx = config.global_mcp_stdio_ctx
    # Clear the globals first so the next call builds a fresh session even if
    # the teardown below fails.
    config.global_mcp_session = None
    config.global_mcp_stdio_ctx = None

    # Innermost context first: the session, then the transport it runs on.
    if session is not None:
        try:
            await session.__aexit__(None, None, None)
        except Exception as cleanup_error:
            logger.debug(f"Session cleanup error (ignored): {cleanup_error}")
    if stdio_ctx is not None:
        try:
            await stdio_ctx.__aexit__(None, None, None)
        except Exception as cleanup_error:
            logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")


async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call Gemini - either via MCP or the direct API, depending on ``config.USE_API``.

    With ``config.USE_API`` set, the call is delegated to
    ``call_agent_direct_api``. Otherwise the shared MCP session is used: the
    first tool whose name contains ``generate_content`` is invoked with the
    given arguments.

    Args:
        user_prompt: Prompt text, sent as the ``user_prompt`` tool argument.
        system_prompt: Optional ``system_prompt`` tool argument.
        files: Optional ``files`` tool argument.
        model: Optional ``model`` tool argument.
        temperature: ``temperature`` tool argument; omitted when ``None``.

    Returns:
        The first non-empty text item of the tool result, or ``""`` on any
        failure. Failures that may leave the session unusable (no session, no
        tool catalog, timeout, exception) close and discard the shared
        session so the next call starts a new one.
    """
    # Check if we should use direct API
    if config.USE_API:
        logger.info("πŸ”΅ Using direct Gemini API (USE_API=true)")
        return await call_agent_direct_api(user_prompt, system_prompt, files, model, temperature)

    # Otherwise use MCP
    if not MCP_AVAILABLE:
        logger.debug("MCP not available for Gemini call")
        return ""

    if not config.GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not set - cannot use Gemini MCP")
        return ""

    try:
        session = await get_mcp_session()
        if session is None:
            logger.error("❌ Failed to get MCP session for Gemini call - check GEMINI_API_KEY and agent.py")
            # Invalidate session to force retry on next call
            await _reset_mcp_session()
            return ""

        logger.debug(f"MCP session obtained: {type(session).__name__}")

        tools = await get_cached_mcp_tools()
        if not tools:
            logger.info("MCP tools cache empty, refreshing...")
            tools = await get_cached_mcp_tools(force_refresh=True)
        if not tools:
            logger.error("❌ Unable to obtain MCP tool catalog for Gemini calls")
            # Close the session (not just forget it) so the server process
            # is not left running, then retry with a new one on the next call.
            await _reset_mcp_session()
            return ""

        logger.debug(f"Found {len(tools)} MCP tools available")

        # An exact "generate_content" name also satisfies the substring match.
        generate_tool = next(
            (tool for tool in tools if "generate_content" in tool.name.lower()),
            None,
        )
        if generate_tool is None:
            logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
            invalidate_mcp_tools_cache()
            return ""
        logger.info(f"Found Gemini MCP tool: {generate_tool.name}")

        arguments = {
            "user_prompt": user_prompt
        }
        if system_prompt:
            arguments["system_prompt"] = system_prompt
        if files:
            arguments["files"] = files
        if model:
            arguments["model"] = model
        if temperature is not None:
            arguments["temperature"] = temperature

        logger.info(f"πŸ”΅ Calling MCP tool '{generate_tool.name}' with model={model or 'default'}, temperature={temperature}")
        logger.debug(f"MCP tool arguments keys: {list(arguments.keys())}")
        logger.debug(f"User prompt length: {len(user_prompt)} chars")

        # Add timeout to prevent hanging.
        # The client waits longer than the server (server timeout is ~18s per
        # the original note) to leave room for processing and communication.
        client_timeout = 25.0
        try:
            logger.debug(f"Starting MCP tool call with {client_timeout}s timeout...")
            result = await asyncio.wait_for(
                session.call_tool(generate_tool.name, arguments=arguments),
                timeout=client_timeout
            )
        except asyncio.TimeoutError:
            logger.error(f"❌ MCP tool call timed out after {client_timeout}s")
            logger.error(f"   Tool: {generate_tool.name}, Model: {model or 'default'}")
            logger.error("   This suggests the MCP server (agent.py) is not responding or the Gemini API call is taking too long")
            logger.error("   Check if agent.py process is still running and responsive")
            logger.error("   Consider increasing GEMINI_TIMEOUT or checking network connectivity")
            # Invalidate session on timeout to force retry
            await _reset_mcp_session()
            return ""
        except Exception as call_error:
            logger.error(f"❌ MCP tool call failed with exception: {type(call_error).__name__}: {call_error}")
            raise  # The outer handler logs the traceback and resets the session
        logger.info("βœ… MCP tool call completed successfully")

        # A result flagged as an error carries the error description in its
        # content; do not hand that back as if it were model output.
        if getattr(result, "isError", False):
            error_texts = [
                item.text
                for item in getattr(result, "content", None) or []
                if isinstance(getattr(item, "text", None), str)
            ]
            logger.error(f"Gemini MCP tool reported an error: {' '.join(error_texts)}")
            return ""

        for item in getattr(result, "content", None) or []:
            item_text = getattr(item, "text", None)
            if not isinstance(item_text, str):
                continue  # Skip items that carry no text
            response_text = item_text.strip()
            if response_text:
                logger.info(f"βœ… Gemini MCP returned {len(response_text)} chars")
                return response_text
        logger.warning("⚠️ Gemini MCP returned empty or invalid result")
        return ""
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Gemini MCP call error: {error_type}: {error_msg}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Invalidate session on error to force retry
        await _reset_mcp_session()
        return ""