Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,136 Bytes
52b4ed7 c67b4e7 52b4ed7 |
|
"""MCP session management and tool caching"""
import os
import time
import asyncio
from logger import logger
import config
# Optional MCP SDK bootstrap.
# After this section runs, MCP_AVAILABLE says whether the SDK could be
# imported and MCP_CLIENT_INFO holds the client identity sent to the server
# (None whenever the SDK is unusable).
MCP_CLIENT_INFO = None
try:
    from mcp import ClientSession, StdioServerParameters, types as mcp_types
    from mcp.client.stdio import stdio_client

    # nest_asyncio is optional; it is applied only when it is installed.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass

    MCP_AVAILABLE = True
    MCP_CLIENT_INFO = mcp_types.Implementation(
        name="MedLLM-Agent",
        version=os.environ.get("SPACE_VERSION", "local"),
    )
except ImportError as import_err:
    # SDK missing: the app keeps running without MCP.
    MCP_AVAILABLE = False
    MCP_CLIENT_INFO = None
    logger.warning(f"MCP SDK not available: {import_err}")
    logger.info("The app will continue to work with fallback functionality (direct API calls)")
except Exception as init_err:
    # SDK present but broken (for example, building the client info failed).
    MCP_AVAILABLE = False
    MCP_CLIENT_INFO = None
    logger.error(f"Unexpected error initializing MCP: {init_err}")
    logger.info("The app will continue to work with fallback functionality")
async def _close_mcp_contexts(entered):
    """Exit entered async context managers in reverse order, logging failures.

    Args:
        entered: list of ``(context_manager, label)`` pairs in the order they
            were entered. The list is emptied before returning.
    """
    for ctx, label in reversed(entered):
        try:
            await ctx.__aexit__(None, None, None)
        except Exception as e:
            # Best-effort cleanup: report the problem but keep unwinding.
            logger.warning(f"Error while closing {label}: {type(e).__name__}: {e}")
    entered.clear()


async def get_mcp_session():
    """Return the shared MCP client session, creating it on first use.

    The session and its stdio transport context are stored on
    ``config.global_mcp_session`` and ``config.global_mcp_stdio_ctx`` so that
    later calls reuse them.

    Returns:
        The initialized ``ClientSession``, or ``None`` when the MCP SDK is
        unavailable or the session could not be created or initialized.
    """
    if not MCP_AVAILABLE:
        logger.warning("MCP not available - SDK not installed")
        return None
    if config.global_mcp_session is not None:
        return config.global_mcp_session

    # Context managers that were successfully entered, so that any failure
    # path can unwind exactly those (and never one that was not entered).
    entered = []
    try:
        # The copy already carries every GEMINI_* variable from the current
        # environment (GEMINI_MODEL, GEMINI_TIMEOUT, ...); only the API key
        # from config has to be injected explicitly.
        mcp_env = os.environ.copy()
        if config.GEMINI_API_KEY:
            mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
        else:
            logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")

        logger.info("Creating MCP client session...")
        server_params = StdioServerParameters(
            command=config.MCP_SERVER_COMMAND,
            args=config.MCP_SERVER_ARGS,
            env=mcp_env,
        )
        stdio_ctx = stdio_client(server_params)
        read, write = await stdio_ctx.__aenter__()
        entered.append((stdio_ctx, "MCP stdio transport"))

        # Constructed after the transport is open: if this raises, the outer
        # handler below closes the transport instead of leaking it.
        session = ClientSession(
            read,
            write,
            client_info=MCP_CLIENT_INFO,
        )
        try:
            await session.__aenter__()
            entered.append((session, "MCP session"))
            init_result = await session.initialize()
            server_info = getattr(init_result, "serverInfo", None)
            server_name = getattr(server_info, "name", "unknown")
            server_version = getattr(server_info, "version", "unknown")
            logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})")
        except Exception as e:
            logger.error(f"❌ MCP session initialization failed: {type(e).__name__}: {e}")
            await _close_mcp_contexts(entered)
            return None

        config.global_mcp_session = session
        config.global_mcp_stdio_ctx = stdio_ctx
        logger.info("✅ MCP client session created successfully")
        return session
    except Exception as e:
        logger.error(f"❌ Failed to create MCP client session: {type(e).__name__}: {e}")
        await _close_mcp_contexts(entered)
        config.global_mcp_session = None
        config.global_mcp_stdio_ctx = None
        return None
def invalidate_mcp_tools_cache():
    """Reset the cached MCP tool list to its empty state."""
    empty_cache = {"timestamp": 0.0, "tools": None}
    config.global_mcp_tools_cache = empty_cache
async def get_cached_mcp_tools(force_refresh: bool = False):
    """Return the MCP tool list, served from the cache while it is fresh.

    Args:
        force_refresh: when True, skip the cache and query the server.

    Returns:
        A list of tools; an empty list when MCP is unavailable, no session
        could be obtained, or listing the tools failed.
    """
    if not MCP_AVAILABLE:
        return []

    now = time.time()
    if not force_refresh:
        cache = config.global_mcp_tools_cache
        cached_tools = cache["tools"]
        # A missing or empty tool list never counts as a cache hit.
        if cached_tools and now - cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL:
            return cached_tools

    session = await get_mcp_session()
    if session is None:
        return []

    try:
        response = await session.list_tools()
        fresh_tools = list(getattr(response, "tools", []) or [])
        config.global_mcp_tools_cache = {"timestamp": now, "tools": fresh_tools}
    except Exception as exc:
        logger.error(f"Failed to refresh MCP tools: {exc}")
        invalidate_mcp_tools_cache()
        return []
    return fresh_tools
def _find_generate_content_tool(tools):
    """Pick the generate_content tool, preferring an exact name match.

    Falls back to the first tool whose lower-cased name contains
    "generate_content"; returns None when nothing matches.
    """
    substring_match = None
    for tool in tools:
        if tool.name == "generate_content":
            return tool
        if substring_match is None and "generate_content" in tool.name.lower():
            substring_match = tool
    return substring_match


def _first_text(result):
    """Return the first non-blank, stripped text item of a tool result, or ""."""
    for item in getattr(result, "content", None) or []:
        text = getattr(item, "text", None)
        # Skip items without text, or whose text is None or blank.
        if isinstance(text, str) and text.strip():
            return text.strip()
    return ""


async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call the Gemini MCP generate_content tool and return its text output.

    Args:
        user_prompt: prompt text, always sent to the tool.
        system_prompt: optional system prompt; sent only when truthy.
        files: optional list passed through as the "files" argument.
        model: optional model name; sent only when truthy.
        temperature: sent unless it is None.

    Returns:
        The stripped text of the first non-blank text item in the tool
        result, or "" when MCP is unavailable, no suitable tool is found, the
        tool reports an error, the result holds no text, or any exception is
        raised along the way.
    """
    if not MCP_AVAILABLE:
        logger.warning("MCP not available for Gemini call")
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("Failed to get MCP session for Gemini call")
            return ""

        tools = await get_cached_mcp_tools()
        if not tools:
            tools = await get_cached_mcp_tools(force_refresh=True)
        if not tools:
            logger.error("Unable to obtain MCP tool catalog for Gemini calls")
            return ""

        generate_tool = _find_generate_content_tool(tools)
        if generate_tool is None:
            logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
            # The cached catalog lacks the tool; drop it so the next call re-lists.
            invalidate_mcp_tools_cache()
            return ""
        logger.info(f"Found Gemini MCP tool: {generate_tool.name}")

        arguments = {"user_prompt": user_prompt}
        if system_prompt:
            arguments["system_prompt"] = system_prompt
        if files:
            arguments["files"] = files
        if model:
            arguments["model"] = model
        if temperature is not None:
            arguments["temperature"] = temperature

        result = await session.call_tool(generate_tool.name, arguments=arguments)

        # A tool-level failure carries its message in `content`; it must not
        # be handed back to the caller as if it were model output.
        if getattr(result, "isError", False):
            logger.error(f"Gemini MCP tool reported an error: {_first_text(result) or 'no details'}")
            return ""

        response_text = _first_text(result)
        if response_text:
            return response_text
        logger.warning("⚠️ Gemini MCP returned empty or invalid result")
        return ""
    except Exception as e:
        logger.error(f"Gemini MCP call error: {e}")
        return ""
|