LiamKhoaLe committed on
Commit 84f64fc · 1 Parent(s): 9b1b152

Revert MedSwin GPU runner to commit #ec4d4b3

Files changed (3):
  1. agent.py +38 -38
  2. app.py +64 -204
  3. model.py +22 -199
agent.py CHANGED
@@ -17,13 +17,11 @@ from pathlib import Path
17
  # MCP imports
18
  try:
19
  from mcp.server import Server
20
- from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource, InitializationOptions, ServerCapabilities
21
- MCP_AVAILABLE = True
22
- except ImportError as e:
23
- print(f"Error: MCP SDK not installed. Install with: pip install mcp", file=sys.stderr)
24
- print(f"Import error details: {e}", file=sys.stderr)
25
- MCP_AVAILABLE = False
26
- # Don't exit immediately - let the main function handle it gracefully
27
 
28
  # Gemini imports
29
  try:
@@ -273,23 +271,13 @@ async def call_tool(name: str, arguments: dict) -> Sequence[TextContent | ImageC
273
 
274
  async def main():
275
  """Main entry point"""
276
- if not MCP_AVAILABLE:
277
- logger.error("MCP SDK not available. Cannot start MCP server.")
278
- print("Error: MCP SDK not installed. Install with: pip install mcp", file=sys.stderr)
279
- sys.exit(1)
280
-
281
  logger.info("Starting Gemini MCP Server...")
282
  logger.info(f"Gemini API Key: {'Set' if GEMINI_API_KEY else 'Not Set'}")
283
  logger.info(f"Default Model: {GEMINI_MODEL}")
284
  logger.info(f"Default Lite Model: {GEMINI_MODEL_LITE}")
285
 
286
  # Use stdio_server from mcp.server.stdio
287
- try:
288
- from mcp.server.stdio import stdio_server
289
- except ImportError as e:
290
- logger.error(f"Failed to import stdio_server: {e}")
291
- print(f"Error: Failed to import MCP stdio_server: {e}", file=sys.stderr)
292
- sys.exit(1)
293
 
294
  # Suppress root logger warnings during initialization
295
  # These are expected during the MCP initialization handshake
@@ -302,41 +290,53 @@ async def main():
302
  logging.getLogger("root").setLevel(original_root_level)
303
  logger.info("✅ MCP server stdio streams ready, starting server...")
304
 
305
- # Create initialization options for the server
306
- # The server capabilities are automatically determined from registered handlers
307
- # Try to use Server's built-in method to get capabilities, or create minimal options
308
  try:
309
- # Try to get capabilities from the server if it has a method
310
  if hasattr(app, 'get_capabilities'):
311
- capabilities = app.get_capabilities()
312
  else:
313
- # Create minimal capabilities - tools are registered via @app.list_tools() decorator
314
- capabilities = ServerCapabilities(tools={})
315
- except:
316
- # Fallback: create minimal capabilities
317
- capabilities = ServerCapabilities(tools={})
 
318
 
319
- initialization_options = InitializationOptions(
 
 
 
320
  server_name="gemini-mcp-server",
321
  server_version="1.0.0",
322
- capabilities=capabilities
323
  )
324
 
325
- # Run the server - the Server class automatically handles initialization
326
- # The server will provide its capabilities based on registered handlers
327
- # (@app.list_tools() and @app.call_tool())
328
  await app.run(
329
  read_stream=streams[0],
330
  write_stream=streams[1],
331
- initialization_options=initialization_options
332
  )
333
  except Exception as e:
334
  logging.getLogger("root").setLevel(original_root_level)
335
  logger.error(f"❌ MCP server error: {e}")
336
- import traceback
337
- logger.debug(traceback.format_exc())
338
  raise
339
 
340
  if __name__ == "__main__":
341
- asyncio.run(main())
342
-
 
17
  # MCP imports
18
  try:
19
  from mcp.server import Server
20
+ from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource
21
+ from mcp.server.models import InitializationOptions
22
+ except ImportError:
23
+ print("Error: MCP SDK not installed. Install with: pip install mcp", file=sys.stderr)
24
+ sys.exit(1)
 
 
25
 
26
  # Gemini imports
27
  try:
 
271
 
272
  async def main():
273
  """Main entry point"""
 
 
 
 
 
274
  logger.info("Starting Gemini MCP Server...")
275
  logger.info(f"Gemini API Key: {'Set' if GEMINI_API_KEY else 'Not Set'}")
276
  logger.info(f"Default Model: {GEMINI_MODEL}")
277
  logger.info(f"Default Lite Model: {GEMINI_MODEL_LITE}")
278
 
279
  # Use stdio_server from mcp.server.stdio
280
+ from mcp.server.stdio import stdio_server
281
 
282
  # Suppress root logger warnings during initialization
283
  # These are expected during the MCP initialization handshake
 
290
  logging.getLogger("root").setLevel(original_root_level)
291
  logger.info("✅ MCP server stdio streams ready, starting server...")
292
 
293
+ # Create initialization options
294
+ # The Server class will automatically provide its capabilities based on
295
+ # the registered @app.list_tools() and @app.call_tool() handlers
296
  try:
297
+ # Try to get capabilities from the server if the method exists
298
  if hasattr(app, 'get_capabilities'):
299
+ try:
300
+ # Try with NotificationOptions if available
301
+ from mcp.server.lowlevel.server import NotificationOptions
302
+ server_capabilities = app.get_capabilities(
303
+ notification_options=NotificationOptions(),
304
+ experimental_capabilities={}
305
+ )
306
+ except (ImportError, AttributeError, TypeError):
307
+ # Fallback: try without NotificationOptions
308
+ try:
309
+ server_capabilities = app.get_capabilities()
310
+ except:
311
+ # If get_capabilities doesn't work, create minimal capabilities
312
+ server_capabilities = {}
313
  else:
314
+ # Server will provide capabilities automatically, use empty dict
315
+ server_capabilities = {}
316
+ except Exception as e:
317
+ logger.debug(f"Could not get server capabilities: {e}, server will provide defaults")
318
+ # Server will handle capabilities automatically
319
+ server_capabilities = {}
320
 
321
+ # Create initialization options
322
+ # The server_name and server_version are required
323
+ # Capabilities will be automatically determined by the Server from registered handlers
324
+ init_options = InitializationOptions(
325
  server_name="gemini-mcp-server",
326
  server_version="1.0.0",
327
+ capabilities=server_capabilities
328
  )
329
 
330
+ # Run the server with initialization options
 
 
331
  await app.run(
332
  read_stream=streams[0],
333
  write_stream=streams[1],
334
+ initialization_options=init_options
335
  )
336
  except Exception as e:
337
  logging.getLogger("root").setLevel(original_root_level)
338
  logger.error(f"❌ MCP server error: {e}")
 
 
339
  raise
340
 
341
  if __name__ == "__main__":
342
+ asyncio.run(main())
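
For reference, the restored agent.py startup follows the standard MCP stdio server flow: build InitializationOptions around the server's own get_capabilities() and hand both stdio streams to Server.run(). A minimal, self-contained sketch of that flow, assuming the same mcp package layout imported above (the echo tool is a hypothetical placeholder, not a tool from this repo):

import asyncio
from mcp.server import Server
from mcp.server.lowlevel.server import NotificationOptions
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

app = Server("gemini-mcp-server")

@app.list_tools()
async def list_tools() -> list[Tool]:
    # Register at least one tool so get_capabilities() reports tool support
    return [Tool(name="echo", description="Echo text back",
                 inputSchema={"type": "object",
                              "properties": {"text": {"type": "string"}},
                              "required": ["text"]})]

@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    return [TextContent(type="text", text=arguments.get("text", ""))]

async def main():
    async with stdio_server() as (read_stream, write_stream):
        init_options = InitializationOptions(
            server_name="gemini-mcp-server",
            server_version="1.0.0",
            capabilities=app.get_capabilities(
                notification_options=NotificationOptions(),
                experimental_capabilities={},
            ),
        )
        await app.run(read_stream, write_stream, initialization_options=init_options)

if __name__ == "__main__":
    asyncio.run(main())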
 
app.py CHANGED
@@ -281,65 +281,31 @@ async def get_mcp_session():
281
  stdio_ctx = stdio_client(server_params)
282
  read, write = await stdio_ctx.__aenter__()
283
 
284
- # Wait for the server process to fully start
 
 
 
 
285
  # The server needs time to: start Python, import modules, initialize Gemini client, start MCP server
286
  logger.info("⏳ Waiting for MCP server process to start...")
287
- # Increase wait time and add progressive checks
288
- for wait_attempt in range(5):
289
- await asyncio.sleep(1.0) # Check every second
290
- # Try to peek at the read stream to see if server is responding
291
- # (This is a simple check - the actual initialization will happen below)
292
- try:
293
- # Check if the process is still alive by attempting a small read with timeout
294
- # Note: This is a best-effort check
295
- pass
296
- except:
297
- pass
298
- logger.info("⏳ MCP server startup wait complete, proceeding with initialization...")
299
-
300
- # Create ClientSession from the streams
301
- # ClientSession handles initialization automatically when used as context manager
302
- # Use the session as a context manager to ensure proper initialization
303
- logger.info("🔄 Creating MCP client session...")
304
- try:
305
- from mcp.types import ClientInfo
306
- try:
307
- client_info = ClientInfo(
308
- name="medllm-agent",
309
- version="1.0.0"
310
- )
311
- session = ClientSession(read, write, client_info=client_info)
312
- except (TypeError, ValueError):
313
- # Fallback if ClientInfo parameters are incorrect
314
- session = ClientSession(read, write)
315
- except (ImportError, AttributeError):
316
- # Fallback if ClientInfo is not available
317
- session = ClientSession(read, write)
318
 
319
- # Initialize the session using context manager pattern
320
- # This properly handles the initialization handshake
321
- logger.info("🔄 Initializing MCP session...")
322
  try:
323
- # Enter the session context - this triggers initialization
 
324
  await session.__aenter__()
325
  logger.info("✅ MCP session initialized, verifying tools...")
326
  except Exception as e:
327
- logger.error(f"MCP session initialization failed: {e}")
328
- import traceback
329
- logger.debug(traceback.format_exc())
330
- # Clean up and return None
331
- try:
332
- await stdio_ctx.__aexit__(None, None, None)
333
- except:
334
- pass
335
- return None
336
 
337
- # Wait for the server to be fully ready after initialization
338
- await asyncio.sleep(1.0) # Wait after initialization
 
339
 
340
  # Verify the session works by listing tools with retries
341
  # This confirms the server is ready to handle requests
342
- max_init_retries = 15 # Increased retries
343
  tools_listed = False
344
  tools = None
345
  last_error = None
@@ -350,18 +316,6 @@ async def get_mcp_session():
350
  logger.info(f"✅ MCP server ready with {len(tools.tools)} tools: {[t.name for t in tools.tools]}")
351
  tools_listed = True
352
  break
353
- elif tools and hasattr(tools, 'tools'):
354
- # Empty tools list - might be a server issue
355
- logger.warning(f"MCP server returned empty tools list (attempt {init_attempt + 1}/{max_init_retries})")
356
- if init_attempt < max_init_retries - 1:
357
- await asyncio.sleep(1.5) # Slightly longer wait
358
- continue
359
- else:
360
- # Invalid response format
361
- logger.warning(f"MCP server returned invalid tools response (attempt {init_attempt + 1}/{max_init_retries})")
362
- if init_attempt < max_init_retries - 1:
363
- await asyncio.sleep(1.5)
364
- continue
365
  except Exception as e:
366
  last_error = e
367
  error_str = str(e).lower()
@@ -370,10 +324,8 @@ async def get_mcp_session():
370
  # Log the actual error for debugging
371
  if init_attempt == 0:
372
  logger.debug(f"First list_tools attempt failed: {error_msg}")
373
- elif init_attempt < 3:
374
- logger.debug(f"list_tools attempt {init_attempt + 1} failed: {error_msg}")
375
 
376
- # Handle different error types
377
  if "initialization" in error_str or "before initialization" in error_str or "not initialized" in error_str:
378
  if init_attempt < max_init_retries - 1:
379
  wait_time = 0.5 * (init_attempt + 1) # Progressive wait: 0.5s, 1s, 1.5s...
@@ -381,25 +333,18 @@ async def get_mcp_session():
381
  await asyncio.sleep(wait_time)
382
  continue
383
  elif "invalid request" in error_str or "invalid request parameters" in error_str:
384
- # Invalid request might mean the server isn't ready yet or there's a protocol issue
385
  if init_attempt < max_init_retries - 1:
386
- wait_time = 1.0 * (init_attempt + 1) # Longer wait for invalid request errors
387
  logger.debug(f"Invalid request error (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...")
388
  await asyncio.sleep(wait_time)
389
  continue
390
- else:
391
- # Last attempt failed - log detailed error
392
- logger.error(f"❌ Invalid request parameters error persists. This may indicate a protocol mismatch.")
393
- import traceback
394
- logger.debug(traceback.format_exc())
395
  elif init_attempt < max_init_retries - 1:
396
  wait_time = 0.5 * (init_attempt + 1)
397
  logger.debug(f"Tool listing attempt {init_attempt + 1}/{max_init_retries} failed: {error_msg}, waiting {wait_time}s...")
398
  await asyncio.sleep(wait_time)
399
  else:
400
  logger.error(f"❌ Could not list tools after {max_init_retries} attempts. Last error: {error_msg}")
401
- import traceback
402
- logger.debug(traceback.format_exc())
403
  # Don't continue - if we can't list tools, the session is not usable
404
  try:
405
  await session.__aexit__(None, None, None)
@@ -438,15 +383,7 @@ async def get_mcp_session():
438
  return None
439
 
440
  async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
441
- """
442
- Call Gemini MCP generate_content tool via MCP protocol.
443
-
444
- This function uses the MCP (Model Context Protocol) to call Gemini AI,
445
- NOT direct API calls. It connects to the bundled agent.py MCP server
446
- which provides the generate_content tool.
447
-
448
- Used for: translation, summarization, document parsing, transcription, reasoning
449
- """
450
  if not MCP_AVAILABLE:
451
  logger.warning("MCP not available for Gemini call")
452
  return ""
@@ -719,14 +656,9 @@ def generate_speech(text: str):
719
  return None
720
 
721
  def format_prompt_manually(messages: list, tokenizer) -> str:
722
- """Manually format prompt for models without chat template
 
723
 
724
- Following the exact example pattern from MedAlpaca documentation:
725
- - Simple Question/Answer format
726
- - System prompt as instruction context
727
- - Clean formatting without extra special tokens
728
- - Ensure no double special tokens are added
729
- """
730
  # Combine system and user messages into a single instruction
731
  system_content = ""
732
  user_content = ""
@@ -745,17 +677,11 @@ def format_prompt_manually(messages: list, tokenizer) -> str:
745
 
746
  # Format for MedAlpaca/LLaMA-based medical models
747
  # Common format: Instruction + Input -> Response
748
- # Following the exact example pattern - keep it simple and clean
749
- # The tokenizer will add BOS token automatically, so we don't add it here
750
  if system_content:
751
- # Clean format: system instruction, then question, then answer prompt
752
  prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
753
  else:
754
  prompt = f"Question: {user_content}\n\nAnswer:"
755
 
756
- # Ensure prompt is clean (no extra whitespace or special characters)
757
- prompt = prompt.strip()
758
-
759
  return prompt
760
 
761
  def detect_language(text: str) -> str:
@@ -835,12 +761,7 @@ def translate_text(text: str, target_lang: str = "en", source_lang: str = None)
835
  return text
836
 
837
  async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
838
- """
839
- Search web using MCP web search tool (e.g., DuckDuckGo MCP server).
840
-
841
- This function uses MCP tools for web search, NOT direct API calls.
842
- It looks for MCP tools with names containing "search", "duckduckgo", "ddg", or "web".
843
- """
844
  if not MCP_AVAILABLE:
845
  return []
846
 
@@ -1052,12 +973,9 @@ async def summarize_web_content_gemini(content_list: list, query: str) -> str:
1052
  combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
1053
 
1054
  user_prompt = f"""Summarize the following web search results related to the query: "{query}"
1055
-
1056
  Extract key medical information, facts, and insights. Be concise and focus on reliable information.
1057
-
1058
  Search Results:
1059
  {combined_content}
1060
-
1061
  Summary:"""
1062
 
1063
  # Use concise system prompt
@@ -1114,9 +1032,7 @@ async def autonomous_reasoning_gemini(query: str) -> dict:
1114
  """Autonomous reasoning using Gemini MCP"""
1115
  logger.info(f"🧠 [MCP] Analyzing query with Gemini MCP: {query[:100]}...")
1116
  reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
1117
-
1118
  Query: "{query}"
1119
-
1120
  Analyze:
1121
  1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
1122
  2. Complexity: (simple, moderate, complex, multi_faceted)
@@ -1124,7 +1040,6 @@ Analyze:
1124
  4. Requires RAG: (yes/no) - Does this need document context?
1125
  5. Requires Web Search: (yes/no) - Does this need current/updated information?
1126
  6. Sub-questions: Break down into key sub-questions if complex
1127
-
1128
  Respond in JSON format:
1129
  {{
1130
  "query_type": "...",
@@ -1326,17 +1241,14 @@ def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, us
1326
  async def self_reflection_gemini(answer: str, query: str) -> dict:
1327
  """Self-reflection using Gemini MCP"""
1328
  reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
1329
-
1330
  Query: "{query}"
1331
  Answer: "{answer[:1000]}"
1332
-
1333
  Evaluate:
1334
  1. Completeness: Does it address all aspects of the query?
1335
  2. Accuracy: Is the medical information accurate?
1336
  3. Clarity: Is it clear and well-structured?
1337
  4. Sources: Are sources cited appropriately?
1338
  5. Missing Information: What important information might be missing?
1339
-
1340
  Respond in JSON:
1341
  {{
1342
  "completeness_score": 0-10,
@@ -1624,7 +1536,6 @@ def stream_chat(
1624
  use_rag: bool,
1625
  medical_model: str,
1626
  use_web_search: bool,
1627
- disable_agentic_reasoning: bool,
1628
  request: gr.Request
1629
  ):
1630
  if not request:
@@ -1635,57 +1546,36 @@ def stream_chat(
1635
  index_dir = f"./{user_id}_index"
1636
  has_rag_index = os.path.exists(index_dir)
1637
 
1638
- # If agentic reasoning is disabled, use base MedSwin model only
1639
- if disable_agentic_reasoning:
1640
- logger.info("🚫 Agentic reasoning disabled - using base MedSwin model only")
1641
- # Skip all MCP functionality: reasoning, translation, web search, summarization
1642
- original_message = message
1643
- original_lang = "en" # Assume English, no translation
1644
- needs_translation = False
1645
- final_use_rag = use_rag and has_rag_index # Still allow RAG if user wants it
1646
- final_use_web_search = False # Disable web search when agentic reasoning is off
1647
- reasoning_note = ""
1648
-
1649
- # Simple reasoning structure for base model
1650
- reasoning = {
1651
- "query_type": "general_info",
1652
- "complexity": "simple",
1653
- "information_needs": ["direct_answer"],
1654
- "requires_rag": final_use_rag,
1655
- "requires_web_search": False,
1656
- "sub_questions": [message]
1657
- }
1658
- else:
1659
- # ===== AUTONOMOUS REASONING =====
1660
- logger.info("🤔 Starting autonomous reasoning...")
1661
- reasoning = autonomous_reasoning(message, history)
1662
-
1663
- # ===== PLANNING =====
1664
- logger.info("📋 Creating execution plan...")
1665
- plan = create_execution_plan(reasoning, message, has_rag_index)
1666
-
1667
- # ===== AUTONOMOUS EXECUTION STRATEGY =====
1668
- logger.info("🎯 Determining execution strategy...")
1669
- execution_strategy = autonomous_execution_strategy(reasoning, plan, use_rag, use_web_search, has_rag_index)
1670
-
1671
- # Use autonomous strategy decisions (respect user's RAG setting)
1672
- final_use_rag = execution_strategy["use_rag"] and has_rag_index # Only use RAG if enabled AND documents exist
1673
- final_use_web_search = execution_strategy["use_web_search"]
1674
-
1675
- # Show reasoning override message if applicable
1676
- reasoning_note = ""
1677
- if execution_strategy["reasoning_override"]:
1678
- reasoning_note = f"\n\n💡 *Autonomous Reasoning: {execution_strategy['rationale']}*"
1679
-
1680
- # Detect language and translate if needed (Step 1 of plan)
1681
- original_lang = detect_language(message)
1682
- original_message = message
1683
- needs_translation = original_lang != "en"
1684
-
1685
- if needs_translation:
1686
- logger.info(f"Detected non-English language: {original_lang}, translating to English...")
1687
- message = translate_text(message, target_lang="en", source_lang=original_lang)
1688
- logger.info(f"Translated query: {message}")
1689
 
1690
  # Initialize medical model
1691
  medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
@@ -1696,8 +1586,8 @@ def stream_chat(
1696
  else:
1697
  base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only."
1698
 
1699
- # Add reasoning context to system prompt for complex queries (only when agentic reasoning is enabled)
1700
- if not disable_agentic_reasoning and reasoning["complexity"] in ["complex", "multi_faceted"]:
1701
  base_system_prompt += f"\n\nQuery Analysis: This is a {reasoning['complexity']} {reasoning['query_type']} query. Address all sub-questions: {', '.join(reasoning.get('sub_questions', [])[:3])}"
1702
 
1703
  # ===== EXECUTION: RAG Retrieval (Step 2) =====
@@ -1737,12 +1627,10 @@ def stream_chat(
1737
  web_sources = []
1738
  web_urls = [] # Store URLs for citations
1739
  if final_use_web_search:
1740
- logger.info("🌐 Performing web search (using MCP tools, with Gemini MCP for summarization)...")
1741
- # search_web() tries MCP web search tool first, then falls back to direct API
1742
  web_results = search_web(message, max_results=5)
1743
  if web_results:
1744
  logger.info(f"📊 Found {len(web_results)} web search results, now summarizing with Gemini MCP...")
1745
- # summarize_web_content() uses Gemini MCP via call_agent()
1746
  web_summary = summarize_web_content(web_results, message)
1747
  if web_summary and len(web_summary) > 50: # Check if we got a real summary
1748
  logger.info(f"✅ [MCP] Gemini MCP summarization successful ({len(web_summary)} chars)")
@@ -1788,9 +1676,7 @@ def stream_chat(
1788
  max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
1789
  max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers
1790
 
1791
- # Format prompt - MedAlpaca/MedSwin models typically don't have chat templates
1792
- # Use manual formatting for consistent behavior
1793
- # Following the example: check if tokenizer has chat template, otherwise format manually
1794
  if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
1795
  try:
1796
  prompt = medical_tokenizer.apply_chat_template(
@@ -1806,18 +1692,8 @@ def stream_chat(
1806
  # Manual formatting for models without chat template
1807
  prompt = format_prompt_manually(messages, medical_tokenizer)
1808
 
1809
- # Calculate prompt length for stopping criteria
1810
- # Tokenize to get length - use EXACT same tokenization as model.py
1811
- # This ensures consistency and prevents tokenization mismatches
1812
- inputs = medical_tokenizer(
1813
- prompt,
1814
- return_tensors="pt",
1815
- add_special_tokens=True, # Match model.py tokenization
1816
- padding=False,
1817
- truncation=False
1818
- )
1819
  prompt_length = inputs['input_ids'].shape[1]
1820
- logger.debug(f"Prompt length: {prompt_length} tokens")
1821
 
1822
  stop_event = threading.Event()
1823
 
@@ -1830,23 +1706,19 @@ def stream_chat(
1830
  return self.stop_event.is_set()
1831
 
1832
  # Custom stopping criteria that doesn't stop on EOS too early
1833
- # This prevents premature stopping which can cause corrupted outputs
1834
- # Following the example: use min_new_tokens=100 to ensure proper generation
1835
  class MedicalStoppingCriteria(StoppingCriteria):
1836
  def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
1837
  super().__init__()
1838
  self.eos_token_id = eos_token_id
1839
  self.prompt_length = prompt_length
1840
  self.min_new_tokens = min_new_tokens
1841
-
1842
  def __call__(self, input_ids, scores, **kwargs):
1843
  current_length = input_ids.shape[1]
1844
  new_tokens = current_length - self.prompt_length
1845
  last_token = input_ids[0, -1].item()
1846
 
1847
  # Don't stop on EOS if we haven't generated enough new tokens
1848
- # This prevents early stopping that can cause corrupted outputs
1849
- # Following example: require at least min_new_tokens before allowing EOS
1850
  if new_tokens < self.min_new_tokens:
1851
  return False
1852
  # Allow EOS after minimum new tokens have been generated
@@ -1857,13 +1729,10 @@ def stream_chat(
1857
  MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
1858
  ])
1859
 
1860
- # Create streamer with correct settings for LLaMA-based models
1861
- # skip_special_tokens=True ensures clean text output without special token artifacts
1862
  streamer = TextIteratorStreamer(
1863
  medical_tokenizer,
1864
  skip_prompt=True,
1865
- skip_special_tokens=True, # Skip special tokens in output for clean text
1866
- timeout=None # Don't timeout on long generations
1867
  )
1868
 
1869
  temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
@@ -1906,8 +1775,7 @@ def stream_chat(
1906
  yield updated_history
1907
 
1908
  # ===== SELF-REFLECTION (Step 6) =====
1909
- # Skip self-reflection when agentic reasoning is disabled
1910
- if not disable_agentic_reasoning and reasoning["complexity"] in ["complex", "multi_faceted"]:
1911
  logger.info("🔍 Performing self-reflection on answer quality...")
1912
  reflection = self_reflection(partial_response, message, reasoning)
1913
 
@@ -1924,8 +1792,8 @@ def stream_chat(
1924
  partial_response = reasoning_note + "\n\n" + partial_response
1925
  updated_history[-1]["content"] = partial_response
1926
 
1927
- # Translate back if needed (only when agentic reasoning is enabled)
1928
- if not disable_agentic_reasoning and needs_translation and partial_response:
1929
  logger.info(f"Translating response back to {original_lang}...")
1930
  translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
1931
  partial_response = translated_response
@@ -2098,12 +1966,6 @@ def create_demo():
2098
  )
2099
 
2100
  with gr.Accordion("⚙️ Advanced Settings", open=False):
2101
- with gr.Row():
2102
- disable_agentic_reasoning = gr.Checkbox(
2103
- value=False,
2104
- label="Disable Agentic Reasoning",
2105
- info="Use base MedSwin model only, no MCP tools (Gemini, web search, reasoning)"
2106
- )
2107
  with gr.Row():
2108
  use_rag = gr.Checkbox(
2109
  value=False,
@@ -2198,8 +2060,7 @@ def create_demo():
2198
  merge_threshold,
2199
  use_rag,
2200
  medical_model,
2201
- use_web_search,
2202
- disable_agentic_reasoning
2203
  ],
2204
  outputs=chatbot
2205
  )
@@ -2219,8 +2080,7 @@ def create_demo():
2219
  merge_threshold,
2220
  use_rag,
2221
  medical_model,
2222
- use_web_search,
2223
- disable_agentic_reasoning
2224
  ],
2225
  outputs=chatbot
2226
  )
@@ -2251,4 +2111,4 @@ if __name__ == "__main__":
2251
 
2252
  logger.info("Model preloading complete!")
2253
  demo = create_demo()
2254
- demo.launch()
 
281
  stdio_ctx = stdio_client(server_params)
282
  read, write = await stdio_ctx.__aenter__()
283
 
284
+ # Create ClientSession from the streams
285
+ # The __aenter__() method automatically handles the initialization handshake
286
+ session = ClientSession(read, write)
287
+
288
+ # Wait longer for the server process to fully start
289
  # The server needs time to: start Python, import modules, initialize Gemini client, start MCP server
290
  logger.info("⏳ Waiting for MCP server process to start...")
291
+ await asyncio.sleep(3.0) # Increased wait for server process startup
292
 
 
 
 
293
  try:
294
+ # Initialize the session (this sends initialize request and waits for response)
295
+ logger.info("🔄 Initializing MCP session...")
296
  await session.__aenter__()
297
  logger.info("✅ MCP session initialized, verifying tools...")
298
  except Exception as e:
299
+ logger.warning(f"MCP session initialization had an issue (may be expected): {e}")
300
+ # Continue anyway - the session might still work
 
 
 
 
 
 
 
301
 
302
+ # Wait longer for the server to be fully ready after initialization
303
+ # The server needs time to process the initialization and be ready for requests
304
+ await asyncio.sleep(2.0) # Wait after initialization
305
 
306
  # Verify the session works by listing tools with retries
307
  # This confirms the server is ready to handle requests
308
+ max_init_retries = 15
309
  tools_listed = False
310
  tools = None
311
  last_error = None
 
316
  logger.info(f"✅ MCP server ready with {len(tools.tools)} tools: {[t.name for t in tools.tools]}")
317
  tools_listed = True
318
  break
319
  except Exception as e:
320
  last_error = e
321
  error_str = str(e).lower()
 
324
  # Log the actual error for debugging
325
  if init_attempt == 0:
326
  logger.debug(f"First list_tools attempt failed: {error_msg}")
 
 
327
 
328
+ # Ignore initialization-related errors during the handshake phase
329
  if "initialization" in error_str or "before initialization" in error_str or "not initialized" in error_str:
330
  if init_attempt < max_init_retries - 1:
331
  wait_time = 0.5 * (init_attempt + 1) # Progressive wait: 0.5s, 1s, 1.5s...
 
333
  await asyncio.sleep(wait_time)
334
  continue
335
  elif "invalid request" in error_str or "invalid request parameters" in error_str:
336
+ # This might be a timing issue - wait and retry
337
  if init_attempt < max_init_retries - 1:
338
+ wait_time = 0.8 * (init_attempt + 1) # Longer wait for invalid request errors
339
  logger.debug(f"Invalid request error (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...")
340
  await asyncio.sleep(wait_time)
341
  continue
 
 
 
 
 
342
  elif init_attempt < max_init_retries - 1:
343
  wait_time = 0.5 * (init_attempt + 1)
344
  logger.debug(f"Tool listing attempt {init_attempt + 1}/{max_init_retries} failed: {error_msg}, waiting {wait_time}s...")
345
  await asyncio.sleep(wait_time)
346
  else:
347
  logger.error(f"❌ Could not list tools after {max_init_retries} attempts. Last error: {error_msg}")
 
 
348
  # Don't continue - if we can't list tools, the session is not usable
349
  try:
350
  await session.__aexit__(None, None, None)
 
383
  return None
384
 
385
  async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
386
+ """Call Gemini MCP generate_content tool"""
387
  if not MCP_AVAILABLE:
388
  logger.warning("MCP not available for Gemini call")
389
  return ""
 
656
  return None
657
 
658
  def format_prompt_manually(messages: list, tokenizer) -> str:
659
+ """Manually format prompt for models without chat template"""
660
+ prompt_parts = []
661
662
  # Combine system and user messages into a single instruction
663
  system_content = ""
664
  user_content = ""
 
677
 
678
  # Format for MedAlpaca/LLaMA-based medical models
679
  # Common format: Instruction + Input -> Response
 
 
680
  if system_content:
 
681
  prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
682
  else:
683
  prompt = f"Question: {user_content}\n\nAnswer:"
684
 
 
 
 
685
  return prompt
686
 
687
  def detect_language(text: str) -> str:
 
761
  return text
762
 
763
  async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
764
+ """Search web using MCP web search tool (e.g., DuckDuckGo MCP server)"""
 
 
 
 
 
765
  if not MCP_AVAILABLE:
766
  return []
767
 
 
973
  combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
974
 
975
  user_prompt = f"""Summarize the following web search results related to the query: "{query}"
 
976
  Extract key medical information, facts, and insights. Be concise and focus on reliable information.
 
977
  Search Results:
978
  {combined_content}
 
979
  Summary:"""
980
 
981
  # Use concise system prompt
 
1032
  """Autonomous reasoning using Gemini MCP"""
1033
  logger.info(f"🧠 [MCP] Analyzing query with Gemini MCP: {query[:100]}...")
1034
  reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
 
1035
  Query: "{query}"
 
1036
  Analyze:
1037
  1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
1038
  2. Complexity: (simple, moderate, complex, multi_faceted)
 
1040
  4. Requires RAG: (yes/no) - Does this need document context?
1041
  5. Requires Web Search: (yes/no) - Does this need current/updated information?
1042
  6. Sub-questions: Break down into key sub-questions if complex
 
1043
  Respond in JSON format:
1044
  {{
1045
  "query_type": "...",
 
1241
  async def self_reflection_gemini(answer: str, query: str) -> dict:
1242
  """Self-reflection using Gemini MCP"""
1243
  reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
 
1244
  Query: "{query}"
1245
  Answer: "{answer[:1000]}"
 
1246
  Evaluate:
1247
  1. Completeness: Does it address all aspects of the query?
1248
  2. Accuracy: Is the medical information accurate?
1249
  3. Clarity: Is it clear and well-structured?
1250
  4. Sources: Are sources cited appropriately?
1251
  5. Missing Information: What important information might be missing?
 
1252
  Respond in JSON:
1253
  {{
1254
  "completeness_score": 0-10,
 
1536
  use_rag: bool,
1537
  medical_model: str,
1538
  use_web_search: bool,
 
1539
  request: gr.Request
1540
  ):
1541
  if not request:
 
1546
  index_dir = f"./{user_id}_index"
1547
  has_rag_index = os.path.exists(index_dir)
1548
 
1549
+ # ===== AUTONOMOUS REASONING =====
1550
+ logger.info("🤔 Starting autonomous reasoning...")
1551
+ reasoning = autonomous_reasoning(message, history)
1552
+
1553
+ # ===== PLANNING =====
1554
+ logger.info("📋 Creating execution plan...")
1555
+ plan = create_execution_plan(reasoning, message, has_rag_index)
1556
+
1557
+ # ===== AUTONOMOUS EXECUTION STRATEGY =====
1558
+ logger.info("🎯 Determining execution strategy...")
1559
+ execution_strategy = autonomous_execution_strategy(reasoning, plan, use_rag, use_web_search, has_rag_index)
1560
+
1561
+ # Use autonomous strategy decisions (respect user's RAG setting)
1562
+ final_use_rag = execution_strategy["use_rag"] and has_rag_index # Only use RAG if enabled AND documents exist
1563
+ final_use_web_search = execution_strategy["use_web_search"]
1564
+
1565
+ # Show reasoning override message if applicable
1566
+ reasoning_note = ""
1567
+ if execution_strategy["reasoning_override"]:
1568
+ reasoning_note = f"\n\n💡 *Autonomous Reasoning: {execution_strategy['rationale']}*"
1569
+
1570
+ # Detect language and translate if needed (Step 1 of plan)
1571
+ original_lang = detect_language(message)
1572
+ original_message = message
1573
+ needs_translation = original_lang != "en"
1574
+
1575
+ if needs_translation:
1576
+ logger.info(f"Detected non-English language: {original_lang}, translating to English...")
1577
+ message = translate_text(message, target_lang="en", source_lang=original_lang)
1578
+ logger.info(f"Translated query: {message}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1579
 
1580
  # Initialize medical model
1581
  medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
 
1586
  else:
1587
  base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only."
1588
 
1589
+ # Add reasoning context to system prompt for complex queries
1590
+ if reasoning["complexity"] in ["complex", "multi_faceted"]:
1591
  base_system_prompt += f"\n\nQuery Analysis: This is a {reasoning['complexity']} {reasoning['query_type']} query. Address all sub-questions: {', '.join(reasoning.get('sub_questions', [])[:3])}"
1592
 
1593
  # ===== EXECUTION: RAG Retrieval (Step 2) =====
 
1627
  web_sources = []
1628
  web_urls = [] # Store URLs for citations
1629
  if final_use_web_search:
1630
+ logger.info("🌐 Performing web search (will use Gemini MCP for summarization)...")
 
1631
  web_results = search_web(message, max_results=5)
1632
  if web_results:
1633
  logger.info(f"📊 Found {len(web_results)} web search results, now summarizing with Gemini MCP...")
 
1634
  web_summary = summarize_web_content(web_results, message)
1635
  if web_summary and len(web_summary) > 50: # Check if we got a real summary
1636
  logger.info(f"✅ [MCP] Gemini MCP summarization successful ({len(web_summary)} chars)")
 
1676
  max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
1677
  max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers
1678
 
1679
+ # Check if tokenizer has chat template, otherwise format manually
 
 
1680
  if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
1681
  try:
1682
  prompt = medical_tokenizer.apply_chat_template(
 
1692
  # Manual formatting for models without chat template
1693
  prompt = format_prompt_manually(messages, medical_tokenizer)
1694
 
1695
+ inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
1696
  prompt_length = inputs['input_ids'].shape[1]
 
1697
 
1698
  stop_event = threading.Event()
1699
 
 
1706
  return self.stop_event.is_set()
1707
 
1708
  # Custom stopping criteria that doesn't stop on EOS too early
 
 
1709
  class MedicalStoppingCriteria(StoppingCriteria):
1710
  def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
1711
  super().__init__()
1712
  self.eos_token_id = eos_token_id
1713
  self.prompt_length = prompt_length
1714
  self.min_new_tokens = min_new_tokens
1715
+
1716
  def __call__(self, input_ids, scores, **kwargs):
1717
  current_length = input_ids.shape[1]
1718
  new_tokens = current_length - self.prompt_length
1719
  last_token = input_ids[0, -1].item()
1720
 
1721
  # Don't stop on EOS if we haven't generated enough new tokens
 
 
1722
  if new_tokens < self.min_new_tokens:
1723
  return False
1724
  # Allow EOS after minimum new tokens have been generated
 
1729
  MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
1730
  ])
1731
 
 
 
1732
  streamer = TextIteratorStreamer(
1733
  medical_tokenizer,
1734
  skip_prompt=True,
1735
+ skip_special_tokens=True
 
1736
  )
1737
 
1738
  temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
 
1775
  yield updated_history
1776
 
1777
  # ===== SELF-REFLECTION (Step 6) =====
1778
+ if reasoning["complexity"] in ["complex", "multi_faceted"]:
 
1779
  logger.info("🔍 Performing self-reflection on answer quality...")
1780
  reflection = self_reflection(partial_response, message, reasoning)
1781
 
 
1792
  partial_response = reasoning_note + "\n\n" + partial_response
1793
  updated_history[-1]["content"] = partial_response
1794
 
1795
+ # Translate back if needed
1796
+ if needs_translation and partial_response:
1797
  logger.info(f"Translating response back to {original_lang}...")
1798
  translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
1799
  partial_response = translated_response
 
1966
  )
1967
 
1968
  with gr.Accordion("⚙️ Advanced Settings", open=False):
1969
  with gr.Row():
1970
  use_rag = gr.Checkbox(
1971
  value=False,
 
2060
  merge_threshold,
2061
  use_rag,
2062
  medical_model,
2063
+ use_web_search
 
2064
  ],
2065
  outputs=chatbot
2066
  )
 
2080
  merge_threshold,
2081
  use_rag,
2082
  medical_model,
2083
+ use_web_search
 
2084
  ],
2085
  outputs=chatbot
2086
  )
 
2111
 
2112
  logger.info("Model preloading complete!")
2113
  demo = create_demo()
2114
+ demo.launch()
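
On the client side, the reverted app.py wiring reduces to the stock MCP stdio client pattern: spawn the bundled agent.py as a subprocess, open a ClientSession over its stdio streams, initialize, then poll list_tools() until the server answers. A minimal sketch under those assumptions (the command line and backoff values are illustrative, not copied from the repo):

import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def list_agent_tools() -> list[str]:
    params = StdioServerParameters(command="python", args=["agent.py"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()              # MCP handshake
            for attempt in range(15):               # retry until the server is ready
                try:
                    result = await session.list_tools()
                    return [tool.name for tool in result.tools]
                except Exception:
                    await asyncio.sleep(0.5 * (attempt + 1))  # progressive backoff
            return []

if __name__ == "__main__":
    print(asyncio.run(list_agent_tools()))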
model.py CHANGED
@@ -38,51 +38,12 @@ global_medical_tokenizers = {}
38
 
39
 
40
  def initialize_medical_model(model_name: str):
41
- """Initialize medical model (MedSwin) - download on demand
42
-
43
- Following standard MedAlpaca/LLaMA initialization pattern:
44
- - Simple tokenizer loading without over-complication
45
- - Model loading with device_map="auto" for ZeroGPU Spaces
46
- - Proper pad_token setup for LLaMA-based models
47
- - Float16 for memory efficiency
48
- - Ensure tokenizer padding side is set correctly
49
- """
50
  global global_medical_models, global_medical_tokenizers
51
-
52
  if model_name not in global_medical_models or global_medical_models[model_name] is None:
53
  logger.info(f"Initializing medical model: {model_name}...")
54
  model_path = MEDSWIN_MODELS[model_name]
55
-
56
- # Load tokenizer - simple and clean, following example pattern
57
- # Use fast tokenizer if available (default), fallback to slow if needed
58
- try:
59
- tokenizer = AutoTokenizer.from_pretrained(
60
- model_path,
61
- token=HF_TOKEN,
62
- trust_remote_code=True
63
- )
64
- except Exception as e:
65
- logger.warning(f"Failed to load fast tokenizer, trying slow tokenizer: {e}")
66
- tokenizer = AutoTokenizer.from_pretrained(
67
- model_path,
68
- token=HF_TOKEN,
69
- use_fast=False,
70
- trust_remote_code=True
71
- )
72
-
73
- # LLaMA models don't have pad_token by default, set it to eos_token
74
- if tokenizer.pad_token is None:
75
- tokenizer.pad_token = tokenizer.eos_token
76
- tokenizer.pad_token_id = tokenizer.eos_token_id
77
-
78
- # Set padding side to left for generation (LLaMA models expect this)
79
- tokenizer.padding_side = "left"
80
-
81
- # Ensure tokenizer is properly configured
82
- if not hasattr(tokenizer, 'model_max_length') or tokenizer.model_max_length is None:
83
- tokenizer.model_max_length = 4096
84
-
85
- # Load model - use device_map="auto" for ZeroGPU Spaces
86
  model = AutoModelForCausalLM.from_pretrained(
87
  model_path,
88
  device_map="auto",
@@ -90,22 +51,13 @@ def initialize_medical_model(model_name: str):
90
  token=HF_TOKEN,
91
  torch_dtype=torch.float16
92
  )
93
-
94
- # Ensure model is in eval mode
95
- model.eval()
96
-
97
  global_medical_models[model_name] = model
98
  global_medical_tokenizers[model_name] = tokenizer
99
  logger.info(f"Medical model {model_name} initialized successfully")
100
- logger.info(f"Model device: {next(model.parameters()).device}")
101
- logger.info(f"Tokenizer vocab size: {len(tokenizer)}")
102
- logger.info(f"EOS token: {tokenizer.eos_token} (id: {tokenizer.eos_token_id})")
103
- logger.info(f"PAD token: {tokenizer.pad_token} (id: {tokenizer.pad_token_id})")
104
- logger.info(f"Tokenizer padding side: {tokenizer.padding_side}")
105
-
106
  return global_medical_models[model_name], global_medical_tokenizers[model_name]
107
 
108
 
 
109
  def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
110
  """Get LLM for RAG indexing (uses medical model) - GPU only"""
111
  # Use medical model for RAG indexing instead of translation model
@@ -125,125 +77,13 @@ def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
125
  )
126
 
127
 
 
128
  def get_embedding_model():
129
  """Get embedding model for RAG - GPU only"""
130
  return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
131
 
132
- def _generate_with_medswin_internal(
133
- medical_model_obj,
134
- medical_tokenizer,
135
- prompt: str,
136
- max_new_tokens: int,
137
- temperature: float,
138
- top_p: float,
139
- top_k: int,
140
- penalty: float,
141
- eos_token_id: int,
142
- pad_token_id: int,
143
- prompt_length: int,
144
- min_new_tokens: int = 100,
145
- streamer: TextIteratorStreamer = None,
146
- stopping_criteria: StoppingCriteriaList = None
147
- ):
148
- """
149
- Internal generation function that runs directly on GPU.
150
- Model is already on GPU via device_map="auto", so no @spaces.GPU decorator needed.
151
- This avoids pickling issues with streamer and stopping_criteria.
152
- """
153
- # Ensure model is in evaluation mode
154
- medical_model_obj.eval()
155
-
156
- # Get device - handle device_map="auto" case
157
- device = next(medical_model_obj.parameters()).device
158
-
159
- # Tokenize prompt - CRITICAL: use consistent tokenization settings
160
- # For LLaMA-based models, the tokenizer automatically adds BOS token
161
- inputs = medical_tokenizer(
162
- prompt,
163
- return_tensors="pt",
164
- add_special_tokens=True, # Let tokenizer add BOS/EOS as needed
165
- padding=False, # No padding for single sequence generation
166
- truncation=False # Don't truncate - let model handle length
167
- ).to(device)
168
-
169
- # Log tokenization info for debugging
170
- actual_prompt_length = inputs['input_ids'].shape[1]
171
- logger.info(f"Tokenized prompt: {actual_prompt_length} tokens on device {device}")
172
-
173
- # Use provided streamer and stopping_criteria (created in caller to avoid pickling)
174
- if streamer is None:
175
- streamer = TextIteratorStreamer(
176
- medical_tokenizer,
177
- skip_prompt=True,
178
- skip_special_tokens=True,
179
- timeout=None
180
- )
181
-
182
- if stopping_criteria is None:
183
- # Create simple stopping criteria if not provided
184
- class SimpleStoppingCriteria(StoppingCriteria):
185
- def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
186
- super().__init__()
187
- self.eos_token_id = eos_token_id
188
- self.prompt_length = prompt_length
189
- self.min_new_tokens = min_new_tokens
190
-
191
- def __call__(self, input_ids, scores, **kwargs):
192
- current_length = input_ids.shape[1]
193
- new_tokens = current_length - self.prompt_length
194
- last_token = input_ids[0, -1].item()
195
-
196
- # Don't stop on EOS if we haven't generated enough new tokens
197
- if new_tokens < self.min_new_tokens:
198
- return False
199
- # Allow EOS after minimum new tokens have been generated
200
- return last_token == self.eos_token_id
201
-
202
- stopping_criteria = StoppingCriteriaList([
203
- SimpleStoppingCriteria(eos_token_id, actual_prompt_length, min_new_tokens)
204
- ])
205
-
206
- # Prepare generation kwargs - following standard MedAlpaca/LLaMA pattern
207
- # Ensure all parameters are valid and within expected ranges
208
- generation_kwargs = {
209
- **inputs, # Unpack input_ids and attention_mask
210
- "streamer": streamer,
211
- "max_new_tokens": max_new_tokens,
212
- "temperature": max(0.01, min(temperature, 2.0)), # Clamp temperature to valid range
213
- "top_p": max(0.0, min(top_p, 1.0)), # Clamp top_p to valid range
214
- "top_k": max(1, int(top_k)), # Ensure top_k is at least 1
215
- "repetition_penalty": max(1.0, min(penalty, 2.0)), # Clamp repetition_penalty
216
- "do_sample": True,
217
- "stopping_criteria": stopping_criteria,
218
- "eos_token_id": eos_token_id,
219
- "pad_token_id": pad_token_id
220
- }
221
-
222
- # Validate token IDs are valid
223
- if eos_token_id is None or eos_token_id < 0:
224
- logger.warning(f"Invalid EOS token ID: {eos_token_id}, using tokenizer default")
225
- eos_token_id = medical_tokenizer.eos_token_id or medical_tokenizer.pad_token_id
226
- generation_kwargs["eos_token_id"] = eos_token_id
227
-
228
- if pad_token_id is None or pad_token_id < 0:
229
- logger.warning(f"Invalid PAD token ID: {pad_token_id}, using EOS token")
230
- pad_token_id = eos_token_id
231
- generation_kwargs["pad_token_id"] = pad_token_id
232
-
233
- # Run generation on GPU with torch.no_grad() for efficiency
234
- # Model is already on GPU, so this will run on GPU automatically
235
- with torch.no_grad():
236
- try:
237
- logger.debug(f"Starting generation with max_new_tokens={max_new_tokens}, temperature={generation_kwargs['temperature']}, top_p={generation_kwargs['top_p']}, top_k={generation_kwargs['top_k']}")
238
- logger.debug(f"EOS token ID: {eos_token_id}, PAD token ID: {pad_token_id}")
239
- medical_model_obj.generate(**generation_kwargs)
240
- except Exception as e:
241
- logger.error(f"Error during generation: {e}")
242
- import traceback
243
- logger.error(traceback.format_exc())
244
- raise
245
-
246
 
 
247
  def generate_with_medswin(
248
  medical_model_obj,
249
  medical_tokenizer,
@@ -260,46 +100,29 @@ def generate_with_medswin(
260
  stopping_criteria: StoppingCriteriaList
261
  ):
262
  """
263
- Public API function for model generation.
264
 
265
- This function is NOT decorated with @spaces.GPU because:
266
- 1. The model is already on GPU via device_map="auto" during initialization
267
- 2. Generation will automatically run on GPU where the model is located
268
- 3. This avoids pickling issues with streamer, stop_event, and stopping_criteria
269
-
270
- The @spaces.GPU decorator is only needed for model loading, which is handled
271
- separately in initialize_medical_model (though that also doesn't need it since
272
- device_map="auto" handles GPU placement).
273
  """
274
- # Calculate prompt length for stopping criteria (if not already calculated)
275
- inputs = medical_tokenizer(
276
- prompt,
277
- return_tensors="pt",
278
- add_special_tokens=True,
279
- padding=False,
280
- truncation=False
281
- )
282
- prompt_length = inputs['input_ids'].shape[1]
283
-
284
- # Call internal generation function directly
285
- # Model is already on GPU, so generation will happen on GPU automatically
286
- _generate_with_medswin_internal(
287
- medical_model_obj=medical_model_obj,
288
- medical_tokenizer=medical_tokenizer,
289
- prompt=prompt,
290
  max_new_tokens=max_new_tokens,
291
  temperature=temperature,
292
  top_p=top_p,
293
  top_k=top_k,
294
- penalty=penalty,
 
 
295
  eos_token_id=eos_token_id,
296
- pad_token_id=pad_token_id,
297
- prompt_length=prompt_length,
298
- min_new_tokens=100,
299
- streamer=streamer, # Use the provided streamer (created in caller)
300
- stopping_criteria=stopping_criteria # Use the provided stopping criteria
301
  )
302
 
303
- # Function returns immediately - generation happens in background via streamer
304
- return
305
-
 
38
 
39
 
40
  def initialize_medical_model(model_name: str):
41
+ """Initialize medical model (MedSwin) - download on demand"""
42
  global global_medical_models, global_medical_tokenizers
 
43
  if model_name not in global_medical_models or global_medical_models[model_name] is None:
44
  logger.info(f"Initializing medical model: {model_name}...")
45
  model_path = MEDSWIN_MODELS[model_name]
46
+ tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  model = AutoModelForCausalLM.from_pretrained(
48
  model_path,
49
  device_map="auto",
 
51
  token=HF_TOKEN,
52
  torch_dtype=torch.float16
53
  )
 
 
 
 
54
  global_medical_models[model_name] = model
55
  global_medical_tokenizers[model_name] = tokenizer
56
  logger.info(f"Medical model {model_name} initialized successfully")
57
  return global_medical_models[model_name], global_medical_tokenizers[model_name]
58
 
59
 
60
+ @spaces.GPU(max_duration=120)
61
  def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
62
  """Get LLM for RAG indexing (uses medical model) - GPU only"""
63
  # Use medical model for RAG indexing instead of translation model
 
77
  )
78
 
79
 
80
+ @spaces.GPU(max_duration=120)
81
  def get_embedding_model():
82
  """Get embedding model for RAG - GPU only"""
83
  return HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
84
85
 
86
+ @spaces.GPU(max_duration=120)
87
  def generate_with_medswin(
88
  medical_model_obj,
89
  medical_tokenizer,
 
100
  stopping_criteria: StoppingCriteriaList
101
  ):
102
  """
103
+ Generate text with MedSwin model - GPU only
104
 
105
+ This function only performs the actual model inference on GPU.
106
+ All other operations (prompt preparation, post-processing) should be done outside.
 
 
 
 
 
 
107
  """
108
+ # Tokenize prompt (this is a CPU operation but happens here for simplicity)
109
+ # The actual GPU work is in model.generate()
110
+ inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
111
+
112
+ # Prepare generation kwargs
113
+ generation_kwargs = dict(
114
+ **inputs,
115
+ streamer=streamer,
116
  max_new_tokens=max_new_tokens,
117
  temperature=temperature,
118
  top_p=top_p,
119
  top_k=top_k,
120
+ repetition_penalty=penalty,
121
+ do_sample=True,
122
+ stopping_criteria=stopping_criteria,
123
  eos_token_id=eos_token_id,
124
+ pad_token_id=pad_token_id
 
 
 
 
125
  )
126
 
127
+ # Run generation on GPU - this is the only GPU operation
128
+ medical_model_obj.generate(**generation_kwargs)
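
The restored generate_with_medswin keeps the model call inside a single @spaces.GPU function while the caller streams tokens through a TextIteratorStreamer from a background thread. A minimal caller-side sketch, assuming a model/tokenizer pair returned by initialize_medical_model, the MedicalStoppingCriteria class defined in app.py above, and the parameter names shown in this diff (the model key, prompt, and sampling values are illustrative):

import threading
from transformers import TextIteratorStreamer, StoppingCriteriaList

model, tokenizer = initialize_medical_model("medswin")      # key is a placeholder
prompt = "Question: What are common causes of chest pain?\n\nAnswer:"

prompt_length = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
stopping = StoppingCriteriaList([
    MedicalStoppingCriteria(tokenizer.eos_token_id, prompt_length, min_new_tokens=100)
])

# generate_with_medswin blocks until generation finishes, so run it in a worker
# thread and read decoded text chunks from the streamer as they arrive.
worker = threading.Thread(
    target=generate_with_medswin,
    kwargs=dict(
        medical_model_obj=model, medical_tokenizer=tokenizer, prompt=prompt,
        max_new_tokens=1024, temperature=0.7, top_p=0.95, top_k=50, penalty=1.1,
        eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
        streamer=streamer, stopping_criteria=stopping,
    ),
)
worker.start()
answer = "".join(chunk for chunk in streamer)
worker.join()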