Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,136 Bytes
52b4ed7 c67b4e7 52b4ed7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
"""MCP session management and tool caching"""
import os
import time
import asyncio
from logger import logger
import config
# Optional MCP SDK bootstrap.
# After this section runs, MCP_AVAILABLE says whether the SDK imported cleanly
# and MCP_CLIENT_INFO holds the client identity passed to ClientSession
# (None whenever the SDK could not be set up).
MCP_CLIENT_INFO = None
try:
    from mcp import ClientSession, StdioServerParameters, types as mcp_types
    from mcp.client.stdio import stdio_client

    # nest_asyncio is optional: apply it when installed, carry on otherwise.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass

    MCP_AVAILABLE = True
    # The advertised version comes from SPACE_VERSION, defaulting to "local".
    MCP_CLIENT_INFO = mcp_types.Implementation(
        name="MedLLM-Agent",
        version=os.environ.get("SPACE_VERSION", "local"),
    )
except ImportError as e:
    # SDK (or one of its required names) is missing: degrade gracefully.
    MCP_AVAILABLE, MCP_CLIENT_INFO = False, None
    logger.warning(f"MCP SDK not available: {e}")
    logger.info("The app will continue to work with fallback functionality (direct API calls)")
except Exception as e:
    # Any other failure during setup also disables MCP instead of crashing.
    MCP_AVAILABLE, MCP_CLIENT_INFO = False, None
    logger.error(f"Unexpected error initializing MCP: {e}")
    logger.info("The app will continue to work with fallback functionality")
async def _close_mcp_context(ctx, label):
    """Best-effort ``__aexit__`` on *ctx*; log a warning instead of raising."""
    try:
        await ctx.__aexit__(None, None, None)
    except Exception as e:
        logger.warning(f"Ignoring error while closing MCP {label}: {type(e).__name__}: {e}")


async def get_mcp_session():
    """Get or create the shared MCP client session.

    Returns the session stored in ``config.global_mcp_session`` when one
    exists. Otherwise starts the stdio server described by
    ``config.MCP_SERVER_COMMAND`` / ``config.MCP_SERVER_ARGS``, opens and
    initializes a ``ClientSession`` on it, stores both the session and the
    stdio context on ``config`` and returns the session.

    Returns:
        The initialized session, or ``None`` when the MCP SDK is unavailable
        or the session could not be created/initialized. Errors are logged,
        never raised.
    """
    if not MCP_AVAILABLE:
        logger.warning("MCP not available - SDK not installed")
        return None
    if config.global_mcp_session is not None:
        return config.global_mcp_session

    # Set only while the stdio context is entered and not yet handed over or
    # closed, so the outer handler knows whether there is something to clean up.
    open_stdio_ctx = None
    try:
        # The child inherits the full environment; GEMINI_MODEL, GEMINI_TIMEOUT,
        # GEMINI_MAX_OUTPUT_TOKENS and GEMINI_TEMPERATURE are therefore already
        # present in the copy and need no explicit forwarding.
        mcp_env = os.environ.copy()
        if config.GEMINI_API_KEY:
            mcp_env["GEMINI_API_KEY"] = config.GEMINI_API_KEY
        else:
            logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")

        logger.info("Creating MCP client session...")
        server_params = StdioServerParameters(
            command=config.MCP_SERVER_COMMAND,
            args=config.MCP_SERVER_ARGS,
            env=mcp_env
        )
        stdio_ctx = stdio_client(server_params)
        read, write = await stdio_ctx.__aenter__()
        open_stdio_ctx = stdio_ctx

        session = ClientSession(
            read,
            write,
            client_info=MCP_CLIENT_INFO,
        )
        try:
            await session.__aenter__()
            init_result = await session.initialize()
            server_info = getattr(init_result, "serverInfo", None)
            server_name = getattr(server_info, "name", "unknown")
            server_version = getattr(server_info, "version", "unknown")
            logger.info(f"✅ MCP session initialized (server={server_name} v{server_version})")
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            logger.error(f"❌ MCP session initialization failed: {error_type}: {error_msg}")
            # Tear down in reverse order of setup; failures here are only logged.
            await _close_mcp_context(session, "session")
            open_stdio_ctx = None
            await _close_mcp_context(stdio_ctx, "stdio transport")
            return None

        config.global_mcp_session = session
        config.global_mcp_stdio_ctx = stdio_ctx
        # Ownership now lives on config; do not close it from the handler below.
        open_stdio_ctx = None
        logger.info("✅ MCP client session created successfully")
        return session
    except Exception as e:
        error_type = type(e).__name__
        error_msg = str(e)
        logger.error(f"❌ Failed to create MCP client session: {error_type}: {error_msg}")
        # Close a transport that was opened but never handed over, so the
        # spawned server process is not leaked.
        if open_stdio_ctx is not None:
            await _close_mcp_context(open_stdio_ctx, "stdio transport")
        config.global_mcp_session = None
        config.global_mcp_stdio_ctx = None
        return None
def invalidate_mcp_tools_cache():
    """Reset the MCP tools cache on ``config`` to its empty state."""
    empty_cache = dict(timestamp=0.0, tools=None)
    config.global_mcp_tools_cache = empty_cache
async def get_cached_mcp_tools(force_refresh: bool = False):
    """Return the MCP tool list, served from the cache on ``config`` when fresh.

    Args:
        force_refresh: When true, skip the cache and query the server.

    Returns:
        The list of tools, or ``[]`` when MCP is unavailable, no session can
        be obtained, or listing the tools fails (the cache is then reset).
    """
    if not MCP_AVAILABLE:
        return []

    now = time.time()
    if not force_refresh:
        cache = config.global_mcp_tools_cache
        cached_tools = cache["tools"]
        # An empty or missing tool list never counts as a cache hit.
        if cached_tools and now - cache["timestamp"] < config.MCP_TOOLS_CACHE_TTL:
            return cached_tools

    session = await get_mcp_session()
    if session is None:
        return []

    try:
        response = await session.list_tools()
        fetched = list(getattr(response, "tools", []) or [])
        # The timestamp is the one taken before the request was sent.
        config.global_mcp_tools_cache = {"timestamp": now, "tools": fetched}
    except Exception as e:
        logger.error(f"Failed to refresh MCP tools: {e}")
        invalidate_mcp_tools_cache()
        return []
    return fetched
async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call the Gemini MCP ``generate_content`` tool and return its text reply.

    Args:
        user_prompt: Sent as the ``user_prompt`` tool argument.
        system_prompt: Sent as ``system_prompt``; omitted when falsy.
        files: Sent as ``files``; omitted when falsy.
        model: Sent as ``model``; omitted when falsy.
        temperature: Sent as ``temperature``; omitted only when ``None``.

    Returns:
        The stripped text of the first non-empty text item in the tool
        result. Returns ``""`` when MCP is unavailable, no session or
        matching tool can be found, the tool reports an error, the result
        carries no text, or any exception occurs (all cases are logged).
    """
    if not MCP_AVAILABLE:
        logger.warning("MCP not available for Gemini call")
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("Failed to get MCP session for Gemini call")
            return ""

        # Try the cache first, then force one refresh before giving up.
        tools = await get_cached_mcp_tools()
        if not tools:
            tools = await get_cached_mcp_tools(force_refresh=True)
        if not tools:
            logger.error("Unable to obtain MCP tool catalog for Gemini calls")
            return ""

        # First tool whose name contains "generate_content" (case-insensitive);
        # an exact match is covered by the same substring test.
        generate_tool = next(
            (tool for tool in tools if "generate_content" in tool.name.lower()),
            None,
        )
        if not generate_tool:
            logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools]}")
            # The cached catalog is useless for this call; drop it so the
            # next call queries the server again.
            invalidate_mcp_tools_cache()
            return ""
        logger.info(f"Found Gemini MCP tool: {generate_tool.name}")

        arguments = {
            "user_prompt": user_prompt
        }
        if system_prompt:
            arguments["system_prompt"] = system_prompt
        if files:
            arguments["files"] = files
        if model:
            arguments["model"] = model
        if temperature is not None:
            arguments["temperature"] = temperature

        result = await session.call_tool(generate_tool.name, arguments=arguments)

        # Collect only real text payloads; non-text items and items whose
        # ``text`` is not a string are skipped instead of raising.
        texts = [
            item.text.strip()
            for item in (getattr(result, "content", None) or [])
            if isinstance(getattr(item, "text", None), str)
        ]

        # A tool-level failure puts its message in ``content``; it must not be
        # returned to the caller as if it were the model's answer.
        if getattr(result, "isError", False):
            logger.error(f"Gemini MCP tool reported an error: {' '.join(texts) or 'no details'}")
            return ""

        for text in texts:
            if text:
                return text

        logger.warning("⚠️ Gemini MCP returned empty or invalid result")
        return ""
    except Exception as e:
        logger.error(f"Gemini MCP call error: {e}")
        return ""
|