"""MCP session management and tool caching"""
import os
import time
import asyncio
import base64
from logger import logger
import config
# Direct Gemini API imports
GEMINI_DIRECT_AVAILABLE = False
try:
from google import genai
GEMINI_DIRECT_AVAILABLE = True
except ImportError:
GEMINI_DIRECT_AVAILABLE = False
# MCP imports
MCP_CLIENT_INFO = None
MCP_AVAILABLE = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types
from mcp.client.stdio import stdio_client
MCP_AVAILABLE = True
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
MCP_CLIENT_INFO = mcp_types.Implementation(
name="MedLLM-Agent",
version=os.environ.get("SPACE_VERSION", "local"),
)
logger.info("✅ MCP SDK imported successfully")
except ImportError as e:
logger.warning(f"❌ MCP SDK import failed: {e}")
logger.info(" Install with: pip install mcp>=0.1.0")
logger.info(" The app will continue to work with fallback functionality (direct API calls)")
MCP_AVAILABLE = False
MCP_CLIENT_INFO = None
except Exception as e:
logger.error(f"❌ Unexpected error initializing MCP: {type(e).__name__}: {e}")
logger.info(" The app will continue to work with fallback functionality")
MCP_AVAILABLE = False
MCP_CLIENT_INFO = None
async def get_mcp_session():
"""Get or create MCP client session with proper context management"""
if not MCP_AVAILABLE:
logger.warning("MCP not available - SDK not installed")
return None
# Reuse existing session if available
if config.global_mcp_session is not None:
logger.debug("Reusing existing MCP session")
return config.global_mcp_session
try:
mcp_env = os.environ.copy()
if config.GEMINI_API_KEY:
mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
logger.info(f"✅ GEMINI_API_KEY found: {config.GEMINI_API_KEY[:10]}...{config.GEMINI_API_KEY[-4:]}")
else:
logger.warning("❌ GEMINI_API_KEY not set in environment. Gemini MCP features will not work.")
logger.warning(" Set it with: export GEMINI_API_KEY='your-api-key'")
return None
if os.environ.get("GEMINI_MODEL"):
mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
if os.environ.get("GEMINI_TIMEOUT"):
mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
if os.environ.get("GEMINI_TEMPERATURE"):
mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")
logger.info(f"Creating MCP client session... (command: {config.MCP_SERVER_COMMAND} {config.MCP_SERVER_ARGS})")
server_params = StdioServerParameters(
command=config.MCP_SERVER_COMMAND,
args=config.MCP_SERVER_ARGS,
env=mcp_env
)
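        # stdio_client launches the MCP server (the local agent.py, per config.MCP_SERVER_COMMAND /
        # config.MCP_SERVER_ARGS) as a subprocess and exposes its read/write streams.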
stdio_ctx = stdio_client(server_params)
read, write = await stdio_ctx.__aenter__()
session = ClientSession(
read,
write,
client_info=MCP_CLIENT_INFO,
)
try:
await session.__aenter__()
init_result = await session.initialize()
server_info = getattr(init_result, "serverInfo", None)
server_name = getattr(server_info, "name", "unknown")
server_version = getattr(server_info, "version", "unknown")
logger.info(f"✅ MCP session initialized successfully (server={server_name} v{server_version})")
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
logger.error(f" This might be due to:")
logger.error(f" - Invalid GEMINI_API_KEY")
logger.error(f" - agent.py server not starting correctly")
logger.error(f" - Network/firewall issues")
logger.error(f" - MCP server process crashed or timed out")
import traceback
logger.error(f" Full traceback: {traceback.format_exc()}")
try:
await session.__aexit__(None, None, None)
except Exception as cleanup_error:
logger.debug(f"Session cleanup error (ignored): {cleanup_error}")
try:
await stdio_ctx.__aexit__(None, None, None)
except Exception as cleanup_error:
logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
return None
config.global_mcp_session = session
config.global_mcp_stdio_ctx = stdio_ctx
logger.info("✅ MCP client session created successfully")
return session
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
config.global_mcp_session = None
config.global_mcp_stdio_ctx = None
return None
def invalidate_mcp_tools_cache():
"""Invalidate cached MCP tool metadata"""
config.global_mcp_tools_cache = {"timestamp": 0.0, "tools": None}
async def get_cached_mcp_tools(force_refresh: bool = False):
"""Return cached MCP tools list to avoid repeated list_tools calls"""
if not MCP_AVAILABLE:
return []
now = time.time()
if (
not force_refresh
and config.global_mcp_tools_cache["tools"]
and now - config.global_mcp_tools_cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL
):
return config.global_mcp_tools_cache["tools"]
session = await get_mcp_session()
if session is None:
return []
try:
tools_resp = await session.list_tools()
tools_list = list(getattr(tools_resp, "tools", []) or [])
config.global_mcp_tools_cache = {"timestamp": now, "tools": tools_list}
return tools_list
except Exception as e:
logger.error(f"Failed to refresh MCP tools: {e}")
invalidate_mcp_tools_cache()
return []
async def test_mcp_connection() -> bool:
"""Test MCP connection and return True if successful"""
if not MCP_AVAILABLE:
logger.warning("Cannot test MCP: SDK not available")
return False
if not config.GEMINI_API_KEY:
logger.warning("Cannot test MCP: GEMINI_API_KEY not set")
return False
try:
session = await get_mcp_session()
if session is None:
logger.warning("MCP connection test failed: Could not create session")
return False
# Try to list tools as a connectivity test
tools = await get_cached_mcp_tools()
if tools:
logger.info(f"✅ MCP connection test successful! Found {len(tools)} tools")
return True
else:
logger.warning("MCP connection test: Session created but no tools found")
return False
except Exception as e:
logger.error(f"MCP connection test failed: {type(e).__name__}: {e}")
return False
async def call_agent_direct_api(user_prompt: str, system_prompt: str | None = None, files: list | None = None, model: str | None = None, temperature: float = 0.2) -> str:
"""Call Gemini API directly without MCP
Includes retry logic with exponential backoff to handle GPU task aborted errors
"""
if not GEMINI_DIRECT_AVAILABLE:
logger.error("❌ google-genai not installed - cannot use direct API")
return ""
if not config.GEMINI_API_KEY:
logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
return ""
max_retries = 3
base_delay = 1.0 # Base delay in seconds
for attempt in range(max_retries):
try:
gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)
model_name = model or config.GEMINI_MODEL
temp = temperature if temperature is not None else 0.2
# Prepare content
contents = user_prompt
if system_prompt:
contents = f"{system_prompt}\n\n{user_prompt}"
gemini_contents = [contents]
# Handle files if provided
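            # Each entry in `files` is a dict: either {"path": <local file path>, "type": <mime type, optional>}
            # or {"content": <base64-encoded data>, "type": <mime type, optional>}; both are forwarded to
            # Gemini as inline_data parts.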
if files:
for file_obj in files:
try:
if "path" in file_obj:
file_path = file_obj["path"]
mime_type = file_obj.get("type")
if not os.path.exists(file_path):
logger.warning(f"File not found: {file_path}")
continue
with open(file_path, 'rb') as f:
file_data = f.read()
if not mime_type:
from mimetypes import guess_type
mime_type, _ = guess_type(file_path)
if not mime_type:
mime_type = "application/octet-stream"
gemini_contents.append({
"inline_data": {
"mime_type": mime_type,
"data": base64.b64encode(file_data).decode('utf-8')
}
})
elif "content" in file_obj:
file_data = base64.b64decode(file_obj["content"])
mime_type = file_obj.get("type", "application/octet-stream")
gemini_contents.append({
"inline_data": {
"mime_type": mime_type,
"data": file_obj["content"]
}
})
except Exception as e:
logger.warning(f"Error processing file: {e}")
continue
generation_config = {
"temperature": temp,
"max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
}
logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}")
def generate_sync():
return gemini_client.models.generate_content(
model=model_name,
contents=gemini_contents,
config=generation_config,
)
timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)
response = await asyncio.wait_for(
asyncio.to_thread(generate_sync),
timeout=timeout_seconds
)
logger.info(f"✅ Gemini API call completed successfully")
# Extract text from response
if response and hasattr(response, 'text') and response.text:
return response.text.strip()
elif response and hasattr(response, 'candidates') and response.candidates:
text_parts = []
for candidate in response.candidates:
if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
for part in candidate.content.parts:
if hasattr(part, 'text'):
text_parts.append(part.text)
if text_parts:
return ''.join(text_parts).strip()
logger.warning("⚠️ Gemini API returned empty response")
return ""
except asyncio.TimeoutError:
if attempt < max_retries - 1:
delay = base_delay * (2 ** attempt) # Exponential backoff: 1s, 2s, 4s
logger.warning(f"⏳ Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
await asyncio.sleep(delay)
continue
else:
logger.error(f"❌ Gemini API call timed out after {max_retries} attempts")
return ""
except Exception as e:
error_type = type(e).__name__
error_msg = str(e).lower()
is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg)
if is_gpu_error and attempt < max_retries - 1:
delay = base_delay * (2 ** attempt) # Exponential backoff: 1s, 2s, 4s
logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
await asyncio.sleep(delay)
continue
else:
logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}")
if attempt == max_retries - 1:
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
return ""
async def call_agent(user_prompt: str, system_prompt: str | None = None, files: list | None = None, model: str | None = None, temperature: float = 0.2) -> str:
"""Call Gemini - either via MCP or direct API based on USE_API config"""
# Check if we should use direct API
if config.USE_API:
logger.info("🔵 Using direct Gemini API (USE_API=true)")
return await call_agent_direct_api(user_prompt, system_prompt, files, model, temperature)
# Otherwise use MCP
if not MCP_AVAILABLE:
logger.debug("MCP not available for Gemini call")
return ""
if not config.GEMINI_API_KEY:
logger.warning("GEMINI_API_KEY not set - cannot use Gemini MCP")
return ""
try:
session = await get_mcp_session()
if session is None:
logger.error("❌ Failed to get MCP session for Gemini call - check GEMINI_API_KEY and agent.py")
# Invalidate session to force retry on next call
config.global_mcp_session = None
config.global_mcp_stdio_ctx = None
return ""
logger.debug(f"MCP session obtained: {type(session).__name__}")
tools = await get_cached_mcp_tools()
if not tools:
logger.info("MCP tools cache empty, refreshing...")
tools = await get_cached_mcp_tools(force_refresh=True)
if not tools:
logger.error("❌ Unable to obtain MCP tool catalog for Gemini calls")
# Invalidate session to force retry on next call
config.global_mcp_session = None
config.global_mcp_stdio_ctx = None
return ""
logger.debug(f"Found {len(tools)} MCP tools available")
generate_tool = None
for tool in tools:
if tool.name == "generate_content" or "generate_content" in tool.name.lower():
generate_tool = tool
logger.info(f"Found Gemini MCP tool: {tool.name}")
break
if not generate_tool:
logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
invalidate_mcp_tools_cache()
return ""
arguments = {
"user_prompt": user_prompt
}
if system_prompt:
arguments["system_prompt"] = system_prompt
if files:
arguments["files"] = files
if model:
arguments["model"] = model
if temperature is not None:
arguments["temperature"] = temperature
logger.info(f"🔵 Calling MCP tool '{generate_tool.name}' with model={model or 'default'}, temperature={temperature}")
logger.debug(f"MCP tool arguments keys: {list(arguments.keys())}")
logger.debug(f"User prompt length: {len(user_prompt)} chars")
# Add timeout to prevent hanging
# Client timeout should be longer than server timeout to account for communication overhead
# Server timeout is ~18s, so client should wait ~25s to allow for processing + communication
client_timeout = 25.0
try:
logger.debug(f"Starting MCP tool call with {client_timeout}s timeout...")
result = await asyncio.wait_for(
session.call_tool(generate_tool.name, arguments=arguments),
timeout=client_timeout
)
logger.info(f"✅ MCP tool call completed successfully")
except asyncio.TimeoutError:
logger.error(f"❌ MCP tool call timed out after {client_timeout}s")
logger.error(f" Tool: {generate_tool.name}, Model: {model or 'default'}")
logger.error(f" This suggests the MCP server (agent.py) is not responding or the Gemini API call is taking too long")
logger.error(f" Check if agent.py process is still running and responsive")
logger.error(f" Consider increasing GEMINI_TIMEOUT or checking network connectivity")
# Invalidate session on timeout to force retry
config.global_mcp_session = None
# Properly cleanup stdio context
if config.global_mcp_stdio_ctx is not None:
try:
await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
except Exception as cleanup_error:
logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
config.global_mcp_stdio_ctx = None
return ""
except Exception as call_error:
logger.error(f"❌ MCP tool call failed with exception: {type(call_error).__name__}: {call_error}")
import traceback
logger.error(f" Traceback: {traceback.format_exc()}")
# Invalidate session on error to force retry
config.global_mcp_session = None
# Properly cleanup stdio context
if config.global_mcp_stdio_ctx is not None:
try:
await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
except Exception as cleanup_error:
logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
config.global_mcp_stdio_ctx = None
raise # Re-raise to be caught by outer exception handler
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
response_text = item.text.strip()
if response_text:
logger.info(f"✅ Gemini MCP returned {len(response_text)} chars")
return response_text
logger.warning("⚠️ Gemini MCP returned empty or invalid result")
return ""
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logger.error(f"❌ Gemini MCP call error: {error_type}: {error_msg}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
# Invalidate session on error to force retry
config.global_mcp_session = None
# Properly cleanup stdio context
if config.global_mcp_stdio_ctx is not None:
try:
await config.global_mcp_stdio_ctx.__aexit__(None, None, None)
except Exception as cleanup_error:
logger.debug(f"Stdio context cleanup error (ignored): {cleanup_error}")
config.global_mcp_stdio_ctx = None
return ""