Y Phung Nguyen committed on
Commit
dd13e35
·
1 Parent(s): f7415cc

Reduce supervisor processing latency

Browse files
Files changed (2) hide show
  1. client.py +120 -99
  2. pipeline.py +58 -48
client.py CHANGED
@@ -196,7 +196,9 @@ async def test_mcp_connection() -> bool:
196
 
197
 
198
  async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
199
- """Call Gemini API directly without MCP"""
 
 
200
  if not GEMINI_DIRECT_AVAILABLE:
201
  logger.error("❌ google-genai not installed - cannot use direct API")
202
  return ""
@@ -205,106 +207,125 @@ async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, fil
205
  logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
206
  return ""
207
 
208
- try:
209
- gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)
210
- model_name = model or config.GEMINI_MODEL
211
- temp = temperature if temperature is not None else 0.2
212
-
213
- # Prepare content
214
- contents = user_prompt
215
- if system_prompt:
216
- contents = f"{system_prompt}\n\n{user_prompt}"
217
-
218
- gemini_contents = [contents]
219
-
220
- # Handle files if provided
221
- if files:
222
- for file_obj in files:
223
- try:
224
- if "path" in file_obj:
225
- file_path = file_obj["path"]
226
- mime_type = file_obj.get("type")
227
-
228
- if not os.path.exists(file_path):
229
- logger.warning(f"File not found: {file_path}")
230
- continue
231
-
232
- with open(file_path, 'rb') as f:
233
- file_data = f.read()
234
-
235
- if not mime_type:
236
- from mimetypes import guess_type
237
- mime_type, _ = guess_type(file_path)
 
238
  if not mime_type:
239
- mime_type = "application/octet-stream"
240
-
241
- gemini_contents.append({
242
- "inline_data": {
243
- "mime_type": mime_type,
244
- "data": base64.b64encode(file_data).decode('utf-8')
245
- }
246
- })
247
- elif "content" in file_obj:
248
- file_data = base64.b64decode(file_obj["content"])
249
- mime_type = file_obj.get("type", "application/octet-stream")
250
- gemini_contents.append({
251
- "inline_data": {
252
- "mime_type": mime_type,
253
- "data": file_obj["content"]
254
- }
255
- })
256
- except Exception as e:
257
- logger.warning(f"Error processing file: {e}")
258
- continue
259
-
260
- generation_config = {
261
- "temperature": temp,
262
- "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
263
- }
264
-
265
- logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}")
266
-
267
- def generate_sync():
268
- return gemini_client.models.generate_content(
269
- model=model_name,
270
- contents=gemini_contents,
271
- config=generation_config,
 
 
 
 
 
 
 
 
 
272
  )
273
-
274
- timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)
275
- response = await asyncio.wait_for(
276
- asyncio.to_thread(generate_sync),
277
- timeout=timeout_seconds
278
- )
279
-
280
- logger.info(f"✅ Gemini API call completed successfully")
281
-
282
- # Extract text from response
283
- if response and hasattr(response, 'text') and response.text:
284
- return response.text.strip()
285
- elif response and hasattr(response, 'candidates') and response.candidates:
286
- text_parts = []
287
- for candidate in response.candidates:
288
- if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
289
- for part in candidate.content.parts:
290
- if hasattr(part, 'text'):
291
- text_parts.append(part.text)
292
- if text_parts:
293
- return ''.join(text_parts).strip()
294
-
295
- logger.warning("⚠️ Gemini API returned empty response")
296
- return ""
297
-
298
- except asyncio.TimeoutError:
299
- logger.error(f"❌ Gemini API call timed out")
300
- return ""
301
- except Exception as e:
302
- error_type = type(e).__name__
303
- error_msg = str(e)
304
- logger.error(f"❌ Gemini API call error: {error_type}: {error_msg}")
305
- import traceback
306
- logger.error(f"Full traceback: {traceback.format_exc()}")
307
- return ""
 
 
 
 
 
 
 
 
 
308
 
309
 
310
  async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
 
196
 
197
 
198
async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
    """Call Gemini API directly without MCP.

    Retries up to 3 times with exponential backoff (1s, 2s, 4s) on timeouts
    and on transient "GPU task aborted" errors.

    Args:
        user_prompt: The user's prompt text.
        system_prompt: Optional system prompt; prepended to the user prompt.
        files: Optional list of file dicts, each with either "path" (local
            file to read) or "content" (already base64-encoded data), plus an
            optional "type" MIME type.
        model: Model name override; defaults to config.GEMINI_MODEL.
        temperature: Sampling temperature (defaults to 0.2).

    Returns:
        The generated text, or "" on any failure (missing SDK/key, timeout
        after all retries, or a non-retryable API error).
    """
    if not GEMINI_DIRECT_AVAILABLE:
        logger.error("❌ google-genai not installed - cannot use direct API")
        return ""
    if not config.GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
        return ""

    model_name = model or config.GEMINI_MODEL
    temp = temperature if temperature is not None else 0.2

    # Everything below is invariant across retry attempts, so build it once
    # instead of re-creating the client and re-reading files on every retry
    # (the previous version rebuilt all of this inside the retry loop).
    gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)
    gemini_contents = _build_gemini_contents(user_prompt, system_prompt, files)
    generation_config = {
        "temperature": temp,
        "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
    }
    # GEMINI_TIMEOUT is in milliseconds; capped at 20s to bound supervisor latency.
    timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)

    max_retries = 3
    base_delay = 1.0  # Base delay in seconds

    for attempt in range(max_retries):
        try:
            logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}")

            def generate_sync():
                # Blocking SDK call; run in a worker thread so the event loop
                # stays responsive and asyncio.wait_for can enforce the timeout.
                return gemini_client.models.generate_content(
                    model=model_name,
                    contents=gemini_contents,
                    config=generation_config,
                )

            response = await asyncio.wait_for(
                asyncio.to_thread(generate_sync),
                timeout=timeout_seconds
            )

            logger.info(f" Gemini API call completed successfully")

            # Extract text from response: prefer the aggregated .text shortcut,
            # then fall back to walking candidate parts.
            if response and hasattr(response, 'text') and response.text:
                return response.text.strip()
            elif response and hasattr(response, 'candidates') and response.candidates:
                text_parts = []
                for candidate in response.candidates:
                    if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
                        for part in candidate.content.parts:
                            if hasattr(part, 'text'):
                                text_parts.append(part.text)
                if text_parts:
                    return ''.join(text_parts).strip()

            logger.warning("⚠️ Gemini API returned empty response")
            return ""

        except asyncio.TimeoutError:
            if attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
                logger.warning(f" Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
                await asyncio.sleep(delay)
                continue
            logger.error(f"❌ Gemini API call timed out after {max_retries} attempts")
            return ""
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e).lower()
            # Retry only on the transient GPU-abort failure mode; everything
            # else fails fast with a logged traceback on the final attempt.
            is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg)

            if is_gpu_error and attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
                logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
                await asyncio.sleep(delay)
                continue
            logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}")
            if attempt == max_retries - 1:
                import traceback
                logger.error(f"Full traceback: {traceback.format_exc()}")
            return ""
    return ""


def _build_gemini_contents(user_prompt: str, system_prompt: str, files: list) -> list:
    """Assemble the Gemini contents payload: prompt text plus inline file parts.

    Files with a "path" are read from disk and base64-encoded; files with
    "content" are expected to already be base64-encoded. Missing or unreadable
    files are skipped with a warning rather than failing the whole call.
    """
    contents = user_prompt
    if system_prompt:
        contents = f"{system_prompt}\n\n{user_prompt}"

    gemini_contents = [contents]

    for file_obj in files or []:
        try:
            if "path" in file_obj:
                file_path = file_obj["path"]
                mime_type = file_obj.get("type")

                if not os.path.exists(file_path):
                    logger.warning(f"File not found: {file_path}")
                    continue

                with open(file_path, 'rb') as f:
                    file_data = f.read()

                if not mime_type:
                    from mimetypes import guess_type
                    mime_type, _ = guess_type(file_path)
                    if not mime_type:
                        mime_type = "application/octet-stream"

                gemini_contents.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": base64.b64encode(file_data).decode('utf-8')
                    }
                })
            elif "content" in file_obj:
                # Validate the base64 payload; a decode error skips this file
                # via the except below. The original (still-encoded) string is
                # what gets sent — the decoded bytes were never used, so the
                # old unused `file_data` assignment has been removed.
                base64.b64decode(file_obj["content"])
                mime_type = file_obj.get("type", "application/octet-stream")
                gemini_contents.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": file_obj["content"]
                    }
                })
        except Exception as e:
            logger.warning(f"Error processing file: {e}")
            continue

    return gemini_contents
329
 
330
 
331
  async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
pipeline.py CHANGED
@@ -47,54 +47,64 @@ def run_gemini_in_thread(fn, *args, **kwargs):
47
  except concurrent.futures.TimeoutError:
48
  logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
49
  # Return fallback based on function
50
- if "breakdown" in fn.__name__:
51
- return {
52
- "sub_topics": [
53
- {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
54
- ],
55
- "strategy": "Direct answer (timeout fallback)",
56
- "exploration_note": "Gemini supervisor timeout"
57
- }
58
- elif "search_strategies" in fn.__name__:
59
- return {
60
- "search_strategies": [
61
- {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
62
- ],
63
- "max_strategies": 1
64
- }
65
- elif "rag_brainstorm" in fn.__name__:
66
- return {
67
- "contexts": [
68
- {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
69
- ],
70
- "max_contexts": 1
71
- }
72
- elif "synthesize" in fn.__name__:
73
- return "\n\n".join(args[1] if len(args) > 1 else [])
74
- elif "challenge" in fn.__name__:
75
- return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
76
- elif "enhance_answer" in fn.__name__:
77
- return args[1] if len(args) > 1 else ""
78
- elif "check_clarity" in fn.__name__:
79
- return {"is_unclear": False, "needs_search": False, "search_queries": []}
80
- elif "clinical_intake_triage" in fn.__name__:
81
- return {
82
- "needs_additional_info": False,
83
- "decision_reason": "Timeout fallback",
84
- "max_rounds": args[2] if len(args) > 2 else 5,
85
- "questions": [],
86
- "initial_hypotheses": []
87
- }
88
- elif "summarize_clinical_insights" in fn.__name__:
89
- return {
90
- "patient_profile": "",
91
- "refined_problem_statement": args[0] if args else "",
92
- "key_findings": [],
93
- "handoff_note": "Proceed with regular workflow."
94
- }
95
- else:
96
- logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn.__name__}, returning None")
97
- return None
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
  logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
100
  # Return appropriate fallback
 
47
  except concurrent.futures.TimeoutError:
48
  logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
49
  # Return fallback based on function
50
+ return _supervisor_logics(fn.__name__, args)
51
+ except Exception as e:
52
+ logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} failed with error: {type(e).__name__}: {str(e)}")
53
+ # Return fallback based on function
54
+ return _supervisor_logics(fn.__name__, args)
55
+
56
+
57
+ def _supervisor_logics(fn_name: str, args: tuple):
58
+ """Get appropriate fallback value based on function name"""
59
+ if "breakdown" in fn_name:
60
+ return {
61
+ "sub_topics": [
62
+ {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
63
+ ],
64
+ "strategy": "Direct answer (fallback)",
65
+ "exploration_note": "Gemini supervisor error"
66
+ }
67
+ elif "search_strategies" in fn_name:
68
+ return {
69
+ "search_strategies": [
70
+ {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
71
+ ],
72
+ "max_strategies": 1
73
+ }
74
+ elif "rag_brainstorm" in fn_name:
75
+ return {
76
+ "contexts": [
77
+ {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
78
+ ],
79
+ "max_contexts": 1
80
+ }
81
+ elif "synthesize" in fn_name:
82
+ # Return concatenated MedSwin answers as fallback
83
+ return "\n\n".join(args[1] if len(args) > 1 and args[1] else [])
84
+ elif "challenge" in fn_name:
85
+ return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
86
+ elif "enhance_answer" in fn_name:
87
+ return args[1] if len(args) > 1 else ""
88
+ elif "check_clarity" in fn_name:
89
+ return {"is_unclear": False, "needs_search": False, "search_queries": []}
90
+ elif "clinical_intake_triage" in fn_name:
91
+ return {
92
+ "needs_additional_info": False,
93
+ "decision_reason": "Error fallback",
94
+ "max_rounds": args[2] if len(args) > 2 else 5,
95
+ "questions": [],
96
+ "initial_hypotheses": []
97
+ }
98
+ elif "summarize_clinical_insights" in fn_name:
99
+ return {
100
+ "patient_profile": "",
101
+ "refined_problem_statement": args[0] if args else "",
102
+ "key_findings": [],
103
+ "handoff_note": "Proceed with regular workflow."
104
+ }
105
+ else:
106
+ logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn_name}, returning None")
107
+ return None
108
  except Exception as e:
109
  logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
110
  # Return appropriate fallback