Spaces:
Running
on
Zero
Running
on
Zero
| """MCP session management and tool caching""" | |
| import os | |
| import time | |
| import asyncio | |
| import base64 | |
| from logger import logger | |
| import config | |
| # Direct Gemini API imports | |
| GEMINI_DIRECT_AVAILABLE = False | |
| try: | |
| from google import genai | |
| GEMINI_DIRECT_AVAILABLE = True | |
| except ImportError: | |
| GEMINI_DIRECT_AVAILABLE = False | |
| # MCP imports | |
| MCP_CLIENT_INFO = None | |
| MCP_AVAILABLE = False | |
| try: | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp import types as mcp_types | |
| from mcp.client.stdio import stdio_client | |
| MCP_AVAILABLE = True | |
| try: | |
| import nest_asyncio | |
| nest_asyncio.apply() | |
| except ImportError: | |
| pass | |
| MCP_CLIENT_INFO = mcp_types.Implementation( | |
| name="MedLLM-Agent", | |
| version=os.environ.get("SPACE_VERSION", "local"), | |
| ) | |
| logger.info("✅ MCP SDK imported successfully") | |
| except ImportError as e: | |
| logger.warning(f"❌ MCP SDK import failed: {e}") | |
| logger.info(" Install with: pip install mcp>=0.1.0") | |
| logger.info(" The app will continue to work with fallback functionality (direct API calls)") | |
| MCP_AVAILABLE = False | |
| MCP_CLIENT_INFO = None | |
| except Exception as e: | |
| logger.error(f"❌ Unexpected error initializing MCP: {type(e).__name__}: {e}") | |
| logger.info(" The app will continue to work with fallback functionality") | |
| MCP_AVAILABLE = False | |
| MCP_CLIENT_INFO = None | |
| async def get_mcp_session(): | |
| """Get or create MCP client session with proper context management""" | |
| if not MCP_AVAILABLE: | |
| logger.warning("MCP not available - SDK not installed") | |
| return None | |
| # Reuse existing session if available | |
| if config.global_mcp_session is not None: | |
| logger.debug("Reusing existing MCP session") | |
| return config.global_mcp_session | |
| try: | |
| mcp_env = os.environ.copy() | |
| if config.GEMINI_API_KEY: | |
| mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY | |
| logger.info(f"✅ GEMINI_API_KEY found: {config.GEMINI_API_KEY[:10]}...{config.GEMINI_API_KEY[-4:]}") | |
| else: | |
| logger.warning("❌ GEMINI_API_KEY not set in environment. Gemini MCP features will not work.") | |
| logger.warning(" Set it with: export GEMINI_API_KEY='your-api-key'") | |
| return None | |
| if os.environ.get("GEMINI_MODEL"): | |
| mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL") | |
| if os.environ.get("GEMINI_TIMEOUT"): | |
| mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT") | |
| if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"): | |
| mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS") | |
| if os.environ.get("GEMINI_TEMPERATURE"): | |
| mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE") | |
| logger.info(f"Creating MCP client session... (command: {config.MCP_SERVER_COMMAND} {config.MCP_SERVER_ARGS})") | |
| server_params = StdioServerParameters( | |
| command=config.MCP_SERVER_COMMAND, | |
| args=config.MCP_SERVER_ARGS, | |
| env=mcp_env | |
| ) | |
| stdio_ctx = stdio_client(server_params) | |
| read, write = await stdio_ctx.__aenter__() | |
| session = ClientSession( | |
| read, | |
| write, | |
| client_info=MCP_CLIENT_INFO, | |
| ) | |
| try: | |
| await session.__aenter__() | |
| init_result = await session.initialize() | |
| server_info = getattr(init_result, "serverInfo", None) | |
| server_name = getattr(server_info, "name", "unknown") | |
| server_version = getattr(server_info, "version", "unknown") | |
| logger.info(f"✅ MCP session initialized successfully (server={server_name} v{server_version})") | |
| except Exception as e: | |
| error_msg = str(e) | |
| error_type = type(e).__name__ | |
| logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}") | |
| logger.error(f" This might be due to:") | |
| logger.error(f" - Invalid GEMINI_API_KEY") | |
| logger.error(f" - agent.py server not starting correctly") | |
| logger.error(f" - Network/firewall issues") | |
| logger.error(f" - MCP server process crashed or timed out") | |
| import traceback | |
| logger.error(f" Full traceback: {traceback.format_exc()}") | |
| try: | |
| await session.__aexit__(None, None, None) | |
| except Exception as cleanup_error: | |
| logger.debug(f"Session cleanup error (ignored): {cleanup_error}") | |
| try: | |
| await stdio_ctx.__aexit__(None, None, None) | |
| except Exception as cleanup_error: | |
| logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}") | |
| return None | |
| config.global_mcp_session = session | |
| config.global_mcp_stdio_ctx = stdio_ctx | |
| logger.info("✅ MCP client session created successfully") | |
| return session | |
| except Exception as e: | |
| error_type = type(e).__name__ | |
| error_msg = str(e) | |
| logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}") | |
| config.global_mcp_session = None | |
| config.global_mcp_stdio_ctx = None | |
| return None | |
| def invalidate_mcp_tools_cache(): | |
| """Invalidate cached MCP tool metadata""" | |
| config.global_mcp_tools_cache = {"timestamp": 0.0, "tools": None} | |
| async def get_cached_mcp_tools(force_refresh: bool = False): | |
| """Return cached MCP tools list to avoid repeated list_tools calls""" | |
| if not MCP_AVAILABLE: | |
| return [] | |
| now = time.time() | |
| if ( | |
| not force_refresh | |
| and config.global_mcp_tools_cache["tools"] | |
| and now - config.global_mcp_tools_cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL | |
| ): | |
| return config.global_mcp_tools_cache["tools"] | |
| session = await get_mcp_session() | |
| if session is None: | |
| return [] | |
| try: | |
| tools_resp = await session.list_tools() | |
| tools_list = list(getattr(tools_resp, "tools", []) or []) | |
| config.global_mcp_tools_cache = {"timestamp": now, "tools": tools_list} | |
| return tools_list | |
| except Exception as e: | |
| logger.error(f"Failed to refresh MCP tools: {e}") | |
| invalidate_mcp_tools_cache() | |
| return [] | |
| async def test_mcp_connection() -> bool: | |
| """Test MCP connection and return True if successful""" | |
| if not MCP_AVAILABLE: | |
| logger.warning("Cannot test MCP: SDK not available") | |
| return False | |
| if not config.GEMINI_API_KEY: | |
| logger.warning("Cannot test MCP: GEMINI_API_KEY not set") | |
| return False | |
| try: | |
| session = await get_mcp_session() | |
| if session is None: | |
| logger.warning("MCP connection test failed: Could not create session") | |
| return False | |
| # Try to list tools as a connectivity test | |
| tools = await get_cached_mcp_tools() | |
| if tools: | |
| logger.info(f"✅ MCP connection test successful! Found {len(tools)} tools") | |
| return True | |
| else: | |
| logger.warning("MCP connection test: Session created but no tools found") | |
| return False | |
| except Exception as e: | |
| logger.error(f"MCP connection test failed: {type(e).__name__}: {e}") | |
| return False | |
| async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str: | |
| """Call Gemini API directly without MCP | |
| Includes retry logic with exponential backoff to handle GPU task aborted errors | |
| """ | |
| if not GEMINI_DIRECT_AVAILABLE: | |
| logger.error("❌ google-genai not installed - cannot use direct API") | |
| return "" | |
| if not config.GEMINI_API_KEY: | |
| logger.warning("GEMINI_API_KEY not set - cannot use Gemini API") | |
| return "" | |
| max_retries = 3 | |
| base_delay = 1.0 # Base delay in seconds | |
| for attempt in range(max_retries): | |
| try: | |
| gemini_client = genai.Client(api_key=config.GEMINI_API_KEY) | |
| model_name = model or config.GEMINI_MODEL | |
| temp = temperature if temperature is not None else 0.2 | |
| # Prepare content | |
| contents = user_prompt | |
| if system_prompt: | |
| contents = f"{system_prompt}\n\n{user_prompt}" | |
| gemini_contents = [contents] | |
| # Handle files if provided | |
| if files: | |
| for file_obj in files: | |
| try: | |
| if "path" in file_obj: | |
| file_path = file_obj["path"] | |
| mime_type = file_obj.get("type") | |
| if not os.path.exists(file_path): | |
| logger.warning(f"File not found: {file_path}") | |
| continue | |
| with open(file_path, 'rb') as f: | |
| file_data = f.read() | |
| if not mime_type: | |
| from mimetypes import guess_type | |
| mime_type, _ = guess_type(file_path) | |
| if not mime_type: | |
| mime_type = "application/octet-stream" | |
| gemini_contents.append({ | |
| "inline_data": { | |
| "mime_type": mime_type, | |
| "data": base64.b64encode(file_data).decode('utf-8') | |
| } | |
| }) | |
| elif "content" in file_obj: | |
| file_data = base64.b64decode(file_obj["content"]) | |
| mime_type = file_obj.get("type", "application/octet-stream") | |
| gemini_contents.append({ | |
| "inline_data": { | |
| "mime_type": mime_type, | |
| "data": file_obj["content"] | |
| } | |
| }) | |
| except Exception as e: | |
| logger.warning(f"Error processing file: {e}") | |
| continue | |
| generation_config = { | |
| "temperature": temp, | |
| "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192")) | |
| } | |
| logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}") | |
| def generate_sync(): | |
| return gemini_client.models.generate_content( | |
| model=model_name, | |
| contents=gemini_contents, | |
| config=generation_config, | |
| ) | |
| timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0) | |
| response = await asyncio.wait_for( | |
| asyncio.to_thread(generate_sync), | |
| timeout=timeout_seconds | |
| ) | |
| logger.info(f"✅ Gemini API call completed successfully") | |
| # Extract text from response | |
| if response and hasattr(response, 'text') and response.text: | |
| return response.text.strip() | |
| elif response and hasattr(response, 'candidates') and response.candidates: | |
| text_parts = [] | |
| for candidate in response.candidates: | |
| if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'): | |
| for part in candidate.content.parts: | |
| if hasattr(part, 'text'): | |
| text_parts.append(part.text) | |
| if text_parts: | |
| return ''.join(text_parts).strip() | |
| logger.warning("⚠️ Gemini API returned empty response") | |
| return "" | |
| except asyncio.TimeoutError: | |
| if attempt < max_retries - 1: | |
| delay = base_delay * (2 ** attempt) # Exponential backoff: 1s, 2s, 4s | |
| logger.warning(f"⏳ Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...") | |
| await asyncio.sleep(delay) | |
| continue | |
| else: | |
| logger.error(f"❌ Gemini API call timed out after {max_retries} attempts") | |
| return "" | |
| except Exception as e: | |
| error_type = type(e).__name__ | |
| error_msg = str(e).lower() | |
| is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg) | |
| if is_gpu_error and attempt < max_retries - 1: | |
| delay = base_delay * (2 ** attempt) # Exponential backoff: 1s, 2s, 4s | |
| logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...") | |
| await asyncio.sleep(delay) | |
| continue | |
| else: | |
| logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}") | |
| if attempt == max_retries - 1: | |
| import traceback | |
| logger.error(f"Full traceback: {traceback.format_exc()}") | |
| return "" | |
| async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str: | |
| """Call Gemini - either via MCP or direct API based on USE_API config""" | |
| # Check if we should use direct API | |
| if config.USE_API: | |
| logger.info("🔵 Using direct Gemini API (USE_API=true)") | |
| return await call_agent_direct_api(user_prompt, system_prompt, files, model, temperature) | |
| # Otherwise use MCP | |
| if not MCP_AVAILABLE: | |
| logger.debug("MCP not available for Gemini call") | |
| return "" | |
| if not config.GEMINI_API_KEY: | |
| logger.warning("GEMINI_API_KEY not set - cannot use Gemini MCP") | |
| return "" | |
| try: | |
| session = await get_mcp_session() | |
| if session is None: | |
| logger.error("❌ Failed to get MCP session for Gemini call - check GEMINI_API_KEY and agent.py") | |
| # Invalidate session to force retry on next call | |
| config.global_mcp_session = None | |
| config.global_mcp_stdio_ctx = None | |
| return "" | |
| logger.debug(f"MCP session obtained: {type(session).__name__}") | |
| tools = await get_cached_mcp_tools() | |
| if not tools: | |
| logger.info("MCP tools cache empty, refreshing...") | |
| tools = await get_cached_mcp_tools(force_refresh=True) | |
| if not tools: | |
| logger.error("❌ Unable to obtain MCP tool catalog for Gemini calls") | |
| # Invalidate session to force retry on next call | |
| config.global_mcp_session = None | |
| config.global_mcp_stdio_ctx = None | |
| return "" | |
| logger.debug(f"Found {len(tools)} MCP tools available") | |
| generate_tool = None | |
| for tool in tools: | |
| if tool.name == "generate_content" or "generate_content" in tool.name.lower(): | |
| generate_tool = tool | |
| logger.info(f"Found Gemini MCP tool: {tool.name}") | |
| break | |
| if not generate_tool: | |
| logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}") | |
| invalidate_mcp_tools_cache() | |
| return "" | |
| arguments = { | |
| "user_prompt": user_prompt | |
| } | |
| if system_prompt: | |
| arguments["system_prompt"] = system_prompt | |
| if files: | |
| arguments["files"] = files | |
| if model: | |
| arguments["model"] = model | |
| if temperature is not None: | |
| arguments["temperature"] = temperature | |
| logger.info(f"🔵 Calling MCP tool '{generate_tool.name}' with model={model or 'default'}, temperature={temperature}") | |
| logger.debug(f"MCP tool arguments keys: {list(arguments.keys())}") | |
| logger.debug(f"User prompt length: {len(user_prompt)} chars") | |
| # Add timeout to prevent hanging | |
| # Client timeout should be longer than server timeout to account for communication overhead | |
| # Server timeout is ~18s, so client should wait ~25s to allow for processing + communication | |
| client_timeout = 25.0 | |
| try: | |
| logger.debug(f"Starting MCP tool call with {client_timeout}s timeout...") | |
| result = await asyncio.wait_for( | |
| session.call_tool(generate_tool.name, arguments=arguments), | |
| timeout=client_timeout | |
| ) | |
| logger.info(f"✅ MCP tool call completed successfully") | |
| except asyncio.TimeoutError: | |
| logger.error(f"❌ MCP tool call timed out after {client_timeout}s") | |
| logger.error(f" Tool: {generate_tool.name}, Model: {model or 'default'}") | |
| logger.error(f" This suggests the MCP server (agent.py) is not responding or the Gemini API call is taking too long") | |
| logger.error(f" Check if agent.py process is still running and responsive") | |
| logger.error(f" Consider increasing GEMINI_TIMEOUT or checking network connectivity") | |
| # Invalidate session on timeout to force retry | |
| config.global_mcp_session = None | |
| # Properly cleanup stdio context | |
| if config.global_mcp_stdio_ctx is not None: | |
| try: | |
| await config.global_mcp_stdio_ctx.__aexit__(None, None, None) | |
| except Exception as cleanup_error: | |
| logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}") | |
| config.global_mcp_stdio_ctx = None | |
| return "" | |
| except Exception as call_error: | |
| logger.error(f"❌ MCP tool call failed with exception: {type(call_error).__name__}: {call_error}") | |
| import traceback | |
| logger.error(f" Traceback: {traceback.format_exc()}") | |
| # Invalidate session on error to force retry | |
| config.global_mcp_session = None | |
| # Properly cleanup stdio context | |
| if config.global_mcp_stdio_ctx is not None: | |
| try: | |
| await config.global_mcp_stdio_ctx.__aexit__(None, None, None) | |
| except Exception as cleanup_error: | |
| logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}") | |
| config.global_mcp_stdio_ctx = None | |
| raise # Re-raise to be caught by outer exception handler | |
| if hasattr(result, 'content') and result.content: | |
| for item in result.content: | |
| if hasattr(item, 'text'): | |
| response_text = item.text.strip() | |
| if response_text: | |
| logger.info(f"✅ Gemini MCP returned {len(response_text)} chars") | |
| return response_text | |
| logger.warning("⚠️ Gemini MCP returned empty or invalid result") | |
| return "" | |
| except Exception as e: | |
| error_type = type(e).__name__ | |
| error_msg = str(e) | |
| logger.error(f"❌ Gemini MCP call error: {error_type}: {error_msg}") | |
| import traceback | |
| logger.error(f"Full traceback: {traceback.format_exc()}") | |
| # Invalidate session on error to force retry | |
| config.global_mcp_session = None | |
| # Properly cleanup stdio context | |
| if config.global_mcp_stdio_ctx is not None: | |
| try: | |
| await config.global_mcp_stdio_ctx.__aexit__(None, None, None) | |
| except Exception as cleanup_error: | |
| logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}") | |
| config.global_mcp_stdio_ctx = None | |
| return "" | |