Spaces: Running on Zero
Y Phung Nguyen committed
Commit 47e5fb1 · 1 Parent(s): 83a4de1
Enhance Q&A breakdown agent

Files changed:
- pipeline.py +45 -8
- supervisor.py +33 -0
pipeline.py CHANGED

@@ -75,6 +75,21 @@ def _format_intake_question(question: dict, round_idx: int, max_rounds: int, tar
     return prompt_text
 
 
+def _format_qa_transcript(qa_pairs: list) -> str:
+    if not qa_pairs:
+        return ""
+    lines = []
+    for idx, qa in enumerate(qa_pairs, 1):
+        question = qa.get("question", "").strip()
+        answer = qa.get("answer", "").strip()
+        if question:
+            lines.append(f"Q{idx}: {question}")
+        if answer:
+            lines.append(f"A{idx}: {answer}")
+        lines.append("")
+    return "\n".join(lines).strip()
+
+
 def _format_insights_block(insights: dict) -> str:
     if not insights:
         return ""
@@ -160,13 +175,15 @@ def _handle_clinical_answer(session_id: str, answer_text: str):
         insights = gemini_summarize_clinical_insights(state["base_query"], state["answers"])
         insights_block = _format_insights_block(insights)
         refined_query = _build_refined_query(state["base_query"], insights, insights_block)
+        transcript = _format_qa_transcript(state["answers"])
         _clear_clinical_intake_state(session_id)
         return {
             "type": "insights",
             "insights": insights,
             "insights_block": insights_block,
             "refined_query": refined_query,
-            "qa_pairs": state["answers"]
+            "qa_pairs": state["answers"],
+            "qa_transcript": transcript
         }
     state["pending_question_index"] = next_index
     state["current_round"] = len(state["answers"]) + 1
@@ -240,7 +257,11 @@ def stream_chat(
             "activated": False,
             "rounds": 0,
             "reason": "",
-            "insights": []
+            "insights": [],
+            "plan": [],
+            "qa_pairs": [],
+            "transcript": "",
+            "insights_block": ""
         }
     }
     def record_stage(stage_name: str, start_time: float):
@@ -262,6 +283,8 @@ def stream_chat(
         {"role": "assistant", "content": ""}
     ]
 
+    clinical_intake_context_block = ""
+
    if not enable_clinical_intake:
        _clear_clinical_intake_state(user_id)
    else:
@@ -281,11 +304,24 @@ def stream_chat(
             pipeline_diagnostics["clinical_intake"]["activated"] = True
             pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
             pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
-
+            pipeline_diagnostics["clinical_intake"]["qa_pairs"] = intake_result.get("qa_pairs", [])
+            pipeline_diagnostics["clinical_intake"]["transcript"] = intake_result.get("qa_transcript", "")
+            pipeline_diagnostics["clinical_intake"]["insights_block"] = intake_result.get("insights_block", "")
+            base_refined = intake_result.get("refined_query", message)
+            summary_section = ""
+            transcript_section = ""
+            if intake_result.get("insights_block"):
+                summary_section = f"Clinical intake summary:\n{intake_result['insights_block']}"
+            if intake_result.get("qa_transcript"):
+                transcript_section = f"Clinical intake Q&A transcript:\n{intake_result['qa_transcript']}"
+            sections = [base_refined, summary_section, transcript_section]
+            message = "\n\n---\n\n".join([section for section in sections if section])
+            clinical_intake_context_block = "\n\n".join([seg for seg in [summary_section, transcript_section] if seg])
         else:
             history_context = _history_to_text(history)
             triage_plan = gemini_clinical_intake_triage(message, history_context, MAX_CLINICAL_QA_ROUNDS)
             pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
+            pipeline_diagnostics["clinical_intake"]["plan"] = triage_plan.get("questions", [])
             needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
             if needs_intake:
                 first_prompt = _start_clinical_intake_session(
@@ -460,13 +496,14 @@ def stream_chat(
 
     base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
 
-    combined_context = ""
+    context_sections = []
+    if clinical_intake_context_block:
+        context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
     if rag_contexts:
-        combined_context = "Document Context:\n" + "\n\n".join(rag_contexts[:4])
+        context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
     if search_contexts:
-        if combined_context:
-            combined_context += "\n\n"
-        combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)
+        context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
+    combined_context = "\n\n".join(context_sections)
 
     logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
     medswin_answers = []
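For reference, a minimal standalone sketch of what the new _format_qa_transcript helper produces and how stream_chat then splices the result into the outgoing message. The helper body is copied from the diff above; the sample qa_pairs, query, and summary text are invented for illustration.

# Sketch only: helper copied from the diff; sample data is invented.
def _format_qa_transcript(qa_pairs: list) -> str:
    if not qa_pairs:
        return ""
    lines = []
    for idx, qa in enumerate(qa_pairs, 1):
        question = qa.get("question", "").strip()
        answer = qa.get("answer", "").strip()
        if question:
            lines.append(f"Q{idx}: {question}")
        if answer:
            lines.append(f"A{idx}: {answer}")
        lines.append("")  # blank line between Q&A pairs
    return "\n".join(lines).strip()

qa_pairs = [
    {"question": "How long have you had the cough?", "answer": "About two weeks."},
    {"question": "Any fever?", "answer": "Low-grade, evenings only."},
]
transcript = _format_qa_transcript(qa_pairs)
# transcript ==
# "Q1: How long have you had the cough?\nA1: About two weeks.\n\nQ2: Any fever?\nA2: Low-grade, evenings only."

# stream_chat then joins the refined query, intake summary, and transcript
# with a divider before handing the message downstream (same join as the diff):
base_refined = "Adult with a two-week cough and intermittent low-grade fever"
summary_section = "Clinical intake summary:\n- cough x2 weeks\n- evening fevers"
transcript_section = f"Clinical intake Q&A transcript:\n{transcript}"
message = "\n\n---\n\n".join(s for s in [base_refined, summary_section, transcript_section] if s)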
supervisor.py CHANGED

@@ -160,6 +160,38 @@ Keep strategies focused and avoid overlap."""
 }
 
 
+def _prepare_clinical_question_plan(plan: dict, safe_rounds: int) -> dict:
+    """Normalize Gemini question plan to 1-5 sequential prompts."""
+    if not isinstance(plan, dict):
+        return {"questions": []}
+    questions = plan.get("questions", [])
+    if not isinstance(questions, list):
+        questions = []
+    cleaned = []
+    for idx, raw in enumerate(questions):
+        if not isinstance(raw, dict):
+            continue
+        question_text = (raw.get("question") or "").strip()
+        if not question_text:
+            continue
+        entry = dict(raw)
+        entry["question"] = question_text
+        entry["order"] = entry.get("order") or raw.get("id") or (idx + 1)
+        cleaned.append(entry)
+    cleaned.sort(key=lambda item: item.get("order", 0))
+    cleaned = cleaned[:max(1, min(5, safe_rounds))]
+    for idx, item in enumerate(cleaned, 1):
+        item["order"] = idx
+    plan["questions"] = cleaned
+    if cleaned:
+        plan["max_rounds"] = min(len(cleaned), safe_rounds)
+        plan["needs_additional_info"] = bool(plan.get("needs_additional_info", True))
+    else:
+        plan["needs_additional_info"] = False
+        plan["max_rounds"] = 0
+    return plan
+
+
 async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
     """Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts"""
     max_doc_length = 3000
@@ -274,6 +306,7 @@ Guidelines:
         json_end = response.rfind('}') + 1
         if json_start >= 0 and json_end > json_start:
             plan = json.loads(response[json_start:json_end])
+            plan = _prepare_clinical_question_plan(plan, safe_rounds)
             return plan
         raise ValueError("Clinical intake JSON not found")
     except Exception as exc:
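And a standalone demo of the new normalizer. The function body is copied from the diff above, while the raw_plan payload is invented to exercise the edge cases it guards against: a non-dict entry, blank question text, and missing or out-of-order "order" values.

# Demo only: function copied from the diff; raw_plan is invented.
def _prepare_clinical_question_plan(plan: dict, safe_rounds: int) -> dict:
    """Normalize Gemini question plan to 1-5 sequential prompts."""
    if not isinstance(plan, dict):
        return {"questions": []}
    questions = plan.get("questions", [])
    if not isinstance(questions, list):
        questions = []
    cleaned = []
    for idx, raw in enumerate(questions):
        if not isinstance(raw, dict):
            continue
        question_text = (raw.get("question") or "").strip()
        if not question_text:
            continue
        entry = dict(raw)
        entry["question"] = question_text
        entry["order"] = entry.get("order") or raw.get("id") or (idx + 1)
        cleaned.append(entry)
    cleaned.sort(key=lambda item: item.get("order", 0))
    cleaned = cleaned[:max(1, min(5, safe_rounds))]
    for idx, item in enumerate(cleaned, 1):
        item["order"] = idx
    plan["questions"] = cleaned
    if cleaned:
        plan["max_rounds"] = min(len(cleaned), safe_rounds)
        plan["needs_additional_info"] = bool(plan.get("needs_additional_info", True))
    else:
        plan["needs_additional_info"] = False
        plan["max_rounds"] = 0
    return plan

raw_plan = {
    "needs_additional_info": True,
    "questions": [
        "not a dict",                                   # dropped: not a dict
        {"question": "  ", "order": 1},                 # dropped: empty text
        {"question": "Any known allergies?", "order": 3},
        {"question": "Current medications?", "id": 1},  # "id" used as order
    ],
}
normalized = _prepare_clinical_question_plan(raw_plan, safe_rounds=3)
# normalized["questions"] -> [
#   {"question": "Current medications?", "id": 1, "order": 1},
#   {"question": "Any known allergies?", "order": 2},
# ]
# normalized["max_rounds"] -> 2, normalized["needs_additional_info"] -> True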