Spaces: Running on Zero

Y Phung Nguyen committed · dd13e35
Parent(s): f7415cc

Reduce supervisor processing latency

Browse files:
- client.py +120 -99
- pipeline.py +58 -48
client.py
CHANGED

@@ -196,7 +196,9 @@ async def test_mcp_connection() -> bool:
 
 
 async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
-    """Call Gemini API directly without MCP"""
+    """Call Gemini API directly without MCP
+    Includes retry logic with exponential backoff to handle GPU task aborted errors
+    """
     if not GEMINI_DIRECT_AVAILABLE:
         logger.error("❌ google-genai not installed - cannot use direct API")
         return ""

@@ -205,106 +207,125 @@ async def call_agent_direct_api(user_prompt: str, system_prompt: str = None, fil
         logger.warning("GEMINI_API_KEY not set - cannot use Gemini API")
         return ""
 
-    ... (98 removed lines of the previous implementation; content not preserved in this capture)
+    max_retries = 3
+    base_delay = 1.0  # Base delay in seconds
+
+    for attempt in range(max_retries):
+        try:
+            gemini_client = genai.Client(api_key=config.GEMINI_API_KEY)
+            model_name = model or config.GEMINI_MODEL
+            temp = temperature if temperature is not None else 0.2
+
+            # Prepare content
+            contents = user_prompt
+            if system_prompt:
+                contents = f"{system_prompt}\n\n{user_prompt}"
+
+            gemini_contents = [contents]
+
+            # Handle files if provided
+            if files:
+                for file_obj in files:
+                    try:
+                        if "path" in file_obj:
+                            file_path = file_obj["path"]
+                            mime_type = file_obj.get("type")
+
+                            if not os.path.exists(file_path):
+                                logger.warning(f"File not found: {file_path}")
+                                continue
+
+                            with open(file_path, 'rb') as f:
+                                file_data = f.read()
+
                             if not mime_type:
+                                from mimetypes import guess_type
+                                mime_type, _ = guess_type(file_path)
+                                if not mime_type:
+                                    mime_type = "application/octet-stream"
+
+                            gemini_contents.append({
+                                "inline_data": {
+                                    "mime_type": mime_type,
+                                    "data": base64.b64encode(file_data).decode('utf-8')
+                                }
+                            })
+                        elif "content" in file_obj:
+                            file_data = base64.b64decode(file_obj["content"])
+                            mime_type = file_obj.get("type", "application/octet-stream")
+                            gemini_contents.append({
+                                "inline_data": {
+                                    "mime_type": mime_type,
+                                    "data": file_obj["content"]
+                                }
+                            })
+                    except Exception as e:
+                        logger.warning(f"Error processing file: {e}")
+                        continue
+
+            generation_config = {
+                "temperature": temp,
+                "max_output_tokens": int(os.environ.get("GEMINI_MAX_OUTPUT_TOKENS", "8192"))
+            }
+
+            logger.info(f"🔵 Calling Gemini API directly with model={model_name}, temperature={temp}")
+
+            def generate_sync():
+                return gemini_client.models.generate_content(
+                    model=model_name,
+                    contents=gemini_contents,
+                    config=generation_config,
+                )
+
+            timeout_seconds = min(int(os.environ.get("GEMINI_TIMEOUT", "300000")) / 1000.0, 20.0)
+            response = await asyncio.wait_for(
+                asyncio.to_thread(generate_sync),
+                timeout=timeout_seconds
             )
+
+            logger.info(f"✅ Gemini API call completed successfully")
+
+            # Extract text from response
+            if response and hasattr(response, 'text') and response.text:
+                return response.text.strip()
+            elif response and hasattr(response, 'candidates') and response.candidates:
+                text_parts = []
+                for candidate in response.candidates:
+                    if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                        for part in candidate.content.parts:
+                            if hasattr(part, 'text'):
+                                text_parts.append(part.text)
+                if text_parts:
+                    return ''.join(text_parts).strip()
+
+            logger.warning("⚠️ Gemini API returned empty response")
+            return ""
+
+        except asyncio.TimeoutError:
+            if attempt < max_retries - 1:
+                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
+                logger.warning(f"⏳ Gemini API call timed out (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
+                await asyncio.sleep(delay)
+                continue
+            else:
+                logger.error(f"❌ Gemini API call timed out after {max_retries} attempts")
+                return ""
+        except Exception as e:
+            error_type = type(e).__name__
+            error_msg = str(e).lower()
+            is_gpu_error = 'gpu task aborted' in error_msg or ('gpu' in error_msg and 'abort' in error_msg)
+
+            if is_gpu_error and attempt < max_retries - 1:
+                delay = base_delay * (2 ** attempt)  # Exponential backoff: 1s, 2s, 4s
+                logger.warning(f"⏳ Gemini API GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
+                await asyncio.sleep(delay)
+                continue
+            else:
+                logger.error(f"❌ Gemini API call error after {attempt + 1} attempts: {error_type}: {str(e)}")
+                if attempt == max_retries - 1:
+                    import traceback
+                    logger.error(f"Full traceback: {traceback.format_exc()}")
+                return ""
 
 
 async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
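The new retry loop also bounds worst-case latency: each attempt is capped at min(GEMINI_TIMEOUT / 1000, 20.0) seconds, so three attempts plus backoff delays of 1 s and 2 s come to roughly 3 × 20 + 1 + 2 = 63 s, whereas the GEMINI_TIMEOUT default of 300000 ms would otherwise permit a single 300 s wait. Below is a minimal caller sketch, assuming this module is importable as client and GEMINI_API_KEY is configured; the prompt strings are purely illustrative:

import asyncio

from client import call_agent_direct_api  # assumes client.py is on the import path

async def main():
    # call_agent_direct_api returns "" when all retries fail,
    # so callers can detect failure without catching exceptions.
    answer = await call_agent_direct_api(
        user_prompt="Summarize the key risk factors for sepsis.",
        system_prompt="You are a concise medical assistant.",
        temperature=0.2,
    )
    print(answer or "(empty response after retries)")

if __name__ == "__main__":
    asyncio.run(main())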
pipeline.py
CHANGED

@@ -47,54 +47,64 @@ def run_gemini_in_thread(fn, *args, **kwargs):
     except concurrent.futures.TimeoutError:
         logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} timed out after 30s")
         # Return fallback based on function
-        ... (48 removed lines of the previous inline fallback logic; content not preserved in this capture)
+        return _supervisor_logics(fn.__name__, args)
+    except Exception as e:
+        logger.error(f"[GEMINI SUPERVISOR] Function {fn.__name__} failed with error: {type(e).__name__}: {str(e)}")
+        # Return fallback based on function
+        return _supervisor_logics(fn.__name__, args)
+
+
+def _supervisor_logics(fn_name: str, args: tuple):
+    """Get appropriate fallback value based on function name"""
+    if "breakdown" in fn_name:
+        return {
+            "sub_topics": [
+                {"id": 1, "topic": "Answer", "instruction": args[0] if args else "Address the question", "expected_tokens": 400, "priority": "high", "approach": "direct answer"}
+            ],
+            "strategy": "Direct answer (fallback)",
+            "exploration_note": "Gemini supervisor error"
+        }
+    elif "search_strategies" in fn_name:
+        return {
+            "search_strategies": [
+                {"id": 1, "strategy": args[0] if args else "", "target_sources": 2, "focus": "main query"}
+            ],
+            "max_strategies": 1
+        }
+    elif "rag_brainstorm" in fn_name:
+        return {
+            "contexts": [
+                {"id": 1, "context": args[1][:500] if len(args) > 1 else "", "focus": "retrieved information", "relevance": "high"}
+            ],
+            "max_contexts": 1
+        }
+    elif "synthesize" in fn_name:
+        # Return concatenated MedSwin answers as fallback
+        return "\n\n".join(args[1] if len(args) > 1 and args[1] else [])
+    elif "challenge" in fn_name:
+        return {"is_optimal": True, "completeness_score": 7, "accuracy_score": 7, "clarity_score": 7, "missing_aspects": [], "inaccuracies": [], "improvement_suggestions": [], "needs_more_context": False, "enhancement_instructions": ""}
+    elif "enhance_answer" in fn_name:
+        return args[1] if len(args) > 1 else ""
+    elif "check_clarity" in fn_name:
+        return {"is_unclear": False, "needs_search": False, "search_queries": []}
+    elif "clinical_intake_triage" in fn_name:
+        return {
+            "needs_additional_info": False,
+            "decision_reason": "Error fallback",
+            "max_rounds": args[2] if len(args) > 2 else 5,
+            "questions": [],
+            "initial_hypotheses": []
+        }
+    elif "summarize_clinical_insights" in fn_name:
+        return {
+            "patient_profile": "",
+            "refined_problem_statement": args[0] if args else "",
+            "key_findings": [],
+            "handoff_note": "Proceed with regular workflow."
+        }
+    else:
+        logger.warning(f"[GEMINI SUPERVISOR] Unknown function {fn_name}, returning None")
+        return None
     except Exception as e:
         logger.error(f"[GEMINI SUPERVISOR] Error running {fn.__name__} in thread: {e}")
         # Return appropriate fallback
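For reference, a minimal sketch of how the extracted fallback helper dispatches on a supervisor function's name; the name gemini_breakdown_topics is hypothetical, since any name containing "breakdown" takes the same branch:

from pipeline import _supervisor_logics  # assumes pipeline.py is importable

# Simulate a timed-out "breakdown" supervisor call; args[0] is the user query.
fallback = _supervisor_logics("gemini_breakdown_topics", ("What causes anemia?",))
assert fallback["strategy"] == "Direct answer (fallback)"
assert fallback["sub_topics"][0]["instruction"] == "What causes anemia?"

# Names matching no branch fall through to None (with a warning logged).
assert _supervisor_logics("unrecognized_fn", ()) is None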