"""MCP session management and tool caching""" import os import time import asyncio import base64 from logger import logger import config # Direct Gemini API imports GEMINI_DIRECT_AVAILABLE = False try: from google import genai GEMINI_DIRECT_AVAILABLE = True except ImportError: GEMINI_DIRECT_AVAILABLE = False # MCP imports MCP_CLIENT_INFO = None MCP_AVAILABLE = False try: from mcp import ClientSession, StdioServerParameters from mcp import types as mcp_types from mcp.client.stdio import stdio_client MCP_AVAILABLE = True try: import nest_asyncio nest_asyncio.apply() except ImportError: pass MCP_CLIENT_INFO = mcp_types.Implementation( name="MedLLM-Agent", version=os.environ.get("SPACE_VERSION", "local"), ) logger.info("✅ MCP SDK imported successfully") except ImportError as e: logger.warning(f"❌ MCP SDK import failed: {e}") logger.info(" Install with: pip install mcp>=0.1.0") logger.info(" The app will continue to work with fallback functionality (direct API calls)") MCP_AVAILABLE = False MCP_CLIENT_INFO = None except Exception as e: logger.error(f"❌ Unexpected error initializing MCP: {type(e).__name__}: {e}") logger.info(" The app will continue to work with fallback functionality") MCP_AVAILABLE = False MCP_CLIENT_INFO = None async def get_mcp_session(): """Get or create MCP client session with proper context management""" if not MCP_AVAILABLE: logger.warning("MCP not available - SDK not installed") return None # Reuse existing session if available if config.global_mcp_session is not None: logger.debug("Reusing existing MCP session") return config.global_mcp_session try: mcp_env = os.environ.copy() if config.GEMINI_API_KEY: mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY logger.info(f"✅ GEMINI_API_KEY found: {config.GEMINI_API_KEY[:10]}...{config.GEMINI_API_KEY[-4:]}") else: logger.warning("❌ GEMINI_API_KEY not set in environment. Gemini MCP features will not work.") logger.warning(" Set it with: export GEMINI_API_KEY='your-api-key'") return None if os.environ.get("GEMINI_MODEL"): mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL") if os.environ.get("GEMINI_TIMEOUT"): mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT") if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"): mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS") if os.environ.get("GEMINI_TEMPERATURE"): mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE") logger.info(f"Creating MCP client session... 
async def get_mcp_session():
    """Get or create MCP client session with proper context management"""
    if not MCP_AVAILABLE:
        logger.warning("MCP not available - SDK not installed")
        return None

    # Reuse existing session if available
    if config.global_mcp_session is not None:
        logger.debug("Reusing existing MCP session")
        return config.global_mcp_session

    try:
        mcp_env = os.environ.copy()
        if config.GEMINI_API_KEY:
            mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
            logger.info(f"✅ GEMINI_API_KEY found: {config.GEMINI_API_KEY[:10]}...{config.GEMINI_API_KEY[-4:]}")
        else:
            logger.warning("❌ GEMINI_API_KEY not set in environment. Gemini MCP features will not work.")
            logger.warning(" Set it with: export GEMINI_API_KEY='your-api-key'")
            return None

        if os.environ.get("GEMINI_MODEL"):
            mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
        if os.environ.get("GEMINI_TIMEOUT"):
            mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
        if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
            mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
        if os.environ.get("GEMINI_TEMPERATURE"):
            mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")

        logger.info(f"Creating MCP client session... (command: {config.MCP_SERVER_COMMAND} {config.MCP_SERVER_ARGS})")
        server_params = StdioServerParameters(
            command=config.MCP_SERVER_COMMAND,
            args=config.MCP_SERVER_ARGS,
            env=mcp_env
        )

        stdio_ctx = stdio_client(server_params)
        read, write = await stdio_ctx.__aenter__()
        session = ClientSession(
            read,
            write,
            client_info=MCP_CLIENT_INFO,
        )
        try:
            await session.__aenter__()
            init_result = await session.initialize()
            server_info = getattr(init_result, "serverInfo", None)
            server_name = getattr(server_info, "name", "unknown")
            server_version = getattr(server_info, "version", "unknown")
            logger.info(f"✅ MCP session initialized successfully (server={server_name} v{server_version})")
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
            logger.error(" This might be due to:")
            logger.error(" - Invalid GEMINI_API_KEY")
            logger.error(" - agent.py server not starting correctly")
            logger.error(" - Network/firewall issues")
            logger.error(" - MCP server process crashed or timed out")
            import traceback
            logger.error(f" Full traceback: {traceback.format_exc()}")
            try:
                await session.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Session cleanup error (ignored): {cleanup_error}")
            try:
                await stdio_ctx.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
            return None

        config.global_mcp_session = session
        config.global_mcp_stdio_ctx = stdio_ctx
        logger.info("✅ MCP client session created successfully")
        return session
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
        config.global_mcp_session = None
        config.global_mcp_stdio_ctx = None
        return None


def invalidate_mcp_tools_cache():
    """Invalidate cached MCP tool metadata"""
    config.global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}


async def get_cached_mcp_tools(force_refresh: bool = False):
    """Return cached MCP tools list to avoid repeated list_tools calls"""
    if not MCP_AVAILABLE:
        return []

    now = time.time()
    if (
        not force_refresh
        and config.global_mcp_tools_cache["tools"]
        and now - config.global_mcp_tools_cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL
    ):
        return config.global_mcp_tools_cache["tools"]

    session = await get_mcp_session()
    if session is None:
        return []

    try:
        tools_resp = await session.list_tools()
        tools_list = list(getattr(tools_resp, "tools", []) or [])
        config.global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
        return tools_list
    except Exception as e:
        logger.error(f"Failed to refresh MCP tools: {e}")
        invalidate_mcp_tools_cache()
        return []

async def test_mcp_connection() -> bool:
    """Test MCP connection and return True if successful"""
    if not MCP_AVAILABLE:
        logger.warning("Cannot test MCP: SDK not available")
        return False
    if not config.GEMINI_API_KEY:
        logger.warning("Cannot test MCP: GEMINI_API_KEY not set")
        return False

    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP connection test failed: Could not create session")
            return False

        # Try to list tools as a connectivity test
        tools = await get_cached_mcp_tools()
        if tools:
            logger.info(f"✅ MCP connection test successful! Found {len(tools)} tools")
            return True
        else:
            logger.warning("MCP connection test: Session created but no tools found")
            return False
    except Exception as e:
        logger.error(f"MCP connection test failed: {type(e).__name__}: {e}")
        return False


async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None,
                                model: str = None, temperature: float = 0.2) -> str:
    """Call Gemini API directly without MCP

    Includes retry logic with exponential backoff to handle GPU task aborted errors
    """
    if not GEMINI_DIRECT_AVAILABLE:
        logger.error("❌ google-genai not installed - cannot use direct API")
        return ""
    if not config.GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
        return ""

    max_retries = 3
    base_delay = 1.0  # Base delay in seconds

    for attempt in range(max_retries):
        try:
            gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)
            model_name = model or config.GEMINI_MODEL
            temp = temperature if temperature is not None else 0.2

            # Prepare content
            contents = user_prompt
            if system_prompt:
                contents = f"{system_prompt}\n\n{user_prompt}"
            gemini_contents = [contents]

            # Handle files if provided
            if files:
                for file_obj in files:
                    try:
                        if "path" in file_obj:
                            file_path = file_obj["path"]
                            mime_type = file_obj.get("type")
                            if not os.path.exists(file_path):
                                logger.warning(f"File not found: {file_path}")
                                continue
                            with open(file_path, 'rb') as f:
                                file_data = f.read()
                            if not mime_type:
                                from mimetypes import guess_type
                                mime_type, _ = guess_type(file_path)
                                if not mime_type:
                                    mime_type = "application/octet-stream"
                            gemini_contents.append({
                                "inline_data": {
                                    "mime_type": mime_type,
                                    "data": base64.b64encode(file_data).decode('utf-8')
                                }
                            })
                        elif "content" in file_obj:
                            file_data = base64.b64decode(file_obj["content"])
                            mime_type = file_obj.get("type", "application/octet-stream")
                            gemini_contents.append({
                                "inline_data": {
                                    "mime_type": mime_type,
                                    "data": file_obj["content"]
                                }
                            })
                    except Exception as e:
                        logger.warning(f"Error processing file: {e}")
                        continue

            generation_config = {
                "temperature": temp,
                "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
            }

            logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}")

            def generate_sync():
                return gemini_client.models.generate_content(
                    model=model_name,
                    contents=gemini_contents,
                    config=generation_config,
                )

            timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)
            response = await asyncio.wait_for(
                asyncio.to_thread(generate_sync),
                timeout=timeout_seconds
            )

            logger.info("✅ Gemini API call completed successfully")

            # Extract text from response
            if response and hasattr(response, 'text') and response.text:
                return response.text.strip()
            elif response and hasattr(response, 'candidates') and response.candidates:
                text_parts = []
                for candidate in response.candidates:
                    if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
                        for part in candidate.content.parts:
                            if hasattr(part, 'text'):
                                text_parts.append(part.text)
                if text_parts:
                    return ''.join(text_parts).strip()

            logger.warning("⚠️ Gemini API returned empty response")
            return ""
        except asyncio.TimeoutError:
            if attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
                logger.warning(f"⏳ Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
                await asyncio.sleep(delay)
                continue
            else:
                logger.error(f"❌ Gemini API call timed out after {max_retries} attempts")
                return ""
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e).lower()
            is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg)
            if is_gpu_error and attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
                logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
                await asyncio.sleep(delay)
                continue
            else:
                logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}")
                if attempt == max_retries - 1:
                    import traceback
                    logger.error(f"Full traceback: {traceback.format_exc()}")
                return ""

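# Illustrative sketch (not part of the original module): shows the `files`
# payload shape that call_agent_direct_api() and call_agent() accept, matching
# the two branches handled above: {"path": ..., "type": ...} for a file on disk,
# or {"content": <base64 string>, "type": ...} for in-memory data. The paths,
# prompts, and helper name are placeholders for demonstration only.
async def _example_call_with_files() -> str:
    files = [
        {"path": "/tmp/report.pdf", "type": "application/pdf"},
        {"content": base64.b64encode(b"plain text attachment").decode("utf-8"), "type": "text/plain"},
    ]
    return await call_agent_direct_api(
        user_prompt="Summarize the attached documents.",
        system_prompt="You are a concise summarizer.",
        files=files,
    )
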
async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None,
                     model: str = None, temperature: float = 0.2) -> str:
    """Call Gemini - either via MCP or direct API based on USE_API config"""
    # Check if we should use direct API
    if config.USE_API:
        logger.info("🔵 Using direct Gemini API (USE_API=true)")
        return await call_agent_direct_api(user_prompt, system_prompt, files, model, temperature)

    # Otherwise use MCP
    if not MCP_AVAILABLE:
        logger.debug("MCP not available for Gemini call")
        return ""
    if not config.GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not set - cannot use Gemini MCP")
        return ""

    try:
        session = await get_mcp_session()
        if session is None:
            logger.error("❌ Failed to get MCP session for Gemini call - check GEMINI_API_KEY and agent.py")
            # Invalidate session to force retry on next call
            config.global_mcp_session = None
            config.global_mcp_stdio_ctx = None
            return ""

        logger.debug(f"MCP session obtained: {type(session).__name__}")

        tools = await get_cached_mcp_tools()
        if not tools:
            logger.info("MCP tools cache empty, refreshing...")
            tools = await get_cached_mcp_tools(force_refresh=True)
        if not tools:
            logger.error("❌ Unable to obtain MCP tool catalog for Gemini calls")
            # Invalidate session to force retry on next call
            config.global_mcp_session = None
            config.global_mcp_stdio_ctx = None
            return ""

        logger.debug(f"Found {len(tools)} MCP tools available")

        generate_tool = None
        for tool in tools:
            if tool.name == "generate_content" or "generate_content" in tool.name.lower():
                generate_tool = tool
                logger.info(f"Found Gemini MCP tool: {tool.name}")
                break
        if not generate_tool:
            logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
            invalidate_mcp_tools_cache()
            return ""

        arguments = {
            "user_prompt": user_prompt
        }
        if system_prompt:
            arguments["system_prompt"] = system_prompt
        if files:
            arguments["files"] = files
        if model:
            arguments["model"] = model
        if temperature is not None:
            arguments["temperature"] = temperature

        logger.info(f"🔵 Calling MCP tool '{generate_tool.name}' with model={model or 'default'}, temperature={temperature}")
        logger.debug(f"MCP tool arguments keys: {list(arguments.keys())}")
        logger.debug(f"User prompt length: {len(user_prompt)} chars")

        # Add timeout to prevent hanging
        # Client timeout should be longer than server timeout to account for communication overhead
        # Server timeout is ~18s, so client should wait ~25s to allow for processing + communication
        client_timeout = 25.0
        try:
            logger.debug(f"Starting MCP tool call with {client_timeout}s timeout...")
            result = await asyncio.wait_for(
                session.call_tool(generate_tool.name, arguments=arguments),
                timeout=client_timeout
            )
            logger.info("✅ MCP tool call completed successfully")
        except asyncio.TimeoutError:
            logger.error(f"❌ MCP tool call timed out after {client_timeout}s")
            logger.error(f" Tool: {generate_tool.name}, Model: {model or 'default'}")
            logger.error(" This suggests the MCP server (agent.py) is not responding or the Gemini API call is taking too long")
            logger.error(" Check if agent.py process is still running and responsive")
            logger.error(" Consider increasing GEMINI_TIMEOUT or checking network connectivity")
            # Invalidate session on timeout to force retry
            config.global_mcp_session = None
            # Properly cleanup stdio context
            if config.global_mcp_stdio_ctx is not None:
                try:
                    await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
                except Exception as cleanup_error:
                    logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
                config.global_mcp_stdio_ctx = None
            return ""
        except Exception as call_error:
            logger.error(f"❌ MCP tool call failed with exception: {type(call_error).__name__}: {call_error}")
            import traceback
            logger.error(f" Traceback: {traceback.format_exc()}")
            # Invalidate session on error to force retry
            config.global_mcp_session = None
            # Properly cleanup stdio context
            if config.global_mcp_stdio_ctx is not None:
                try:
                    await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
                except Exception as cleanup_error:
                    logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
                config.global_mcp_stdio_ctx = None
            raise  # Re-raise to be caught by outer exception handler

        if hasattr(result, 'content') and result.content:
            for item in result.content:
                if hasattr(item, 'text'):
                    response_text = item.text.strip()
                    if response_text:
                        logger.info(f"✅ Gemini MCP returned {len(response_text)} chars")
                        return response_text

        logger.warning("⚠️ Gemini MCP returned empty or invalid result")
        return ""
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Gemini MCP call error: {error_type}: {error_msg}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Invalidate session on error to force retry
        config.global_mcp_session = None
        # Properly cleanup stdio context
        if config.global_mcp_stdio_ctx is not None:
            try:
                await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
            except Exception as cleanup_error:
                logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
            config.global_mcp_stdio_ctx = None
        return ""
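
# Minimal smoke test, added as an illustrative sketch only: it assumes the module
# is run directly with GEMINI_API_KEY set and, for the MCP path, the agent.py
# server configured via config.MCP_SERVER_COMMAND / config.MCP_SERVER_ARGS.
# It probes the MCP connection and sends one trivial prompt through call_agent.
if __name__ == "__main__":
    async def _smoke_test():
        reachable = await test_mcp_connection()
        logger.info(f"MCP reachable: {reachable}")
        reply = await call_agent(
            user_prompt="Reply with the single word: ready",
            system_prompt="You are a connectivity check.",
            temperature=0.0,
        )
        logger.info(f"call_agent returned {len(reply)} chars")

    asyncio.run(_smoke_test())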