Y Phung Nguyen committed
Commit 47e5fb1 · 1 Parent(s): 83a4de1

Enhance Q&A breakdown agent

Files changed (2):
  1. pipeline.py +45 -8
  2. supervisor.py +33 -0
pipeline.py CHANGED
@@ -75,6 +75,21 @@ def _format_intake_question(question: dict, round_idx: int, max_rounds: int, tar
     return prompt_text
 
 
+def _format_qa_transcript(qa_pairs: list) -> str:
+    if not qa_pairs:
+        return ""
+    lines = []
+    for idx, qa in enumerate(qa_pairs, 1):
+        question = qa.get("question", "").strip()
+        answer = qa.get("answer", "").strip()
+        if question:
+            lines.append(f"Q{idx}: {question}")
+        if answer:
+            lines.append(f"A{idx}: {answer}")
+        lines.append("")
+    return "\n".join(lines).strip()
+
+
 def _format_insights_block(insights: dict) -> str:
     if not insights:
         return ""
@@ -160,13 +175,15 @@ def _handle_clinical_answer(session_id: str, answer_text: str):
         insights = gemini_summarize_clinical_insights(state["base_query"], state["answers"])
         insights_block = _format_insights_block(insights)
         refined_query = _build_refined_query(state["base_query"], insights, insights_block)
+        transcript = _format_qa_transcript(state["answers"])
         _clear_clinical_intake_state(session_id)
         return {
             "type": "insights",
             "insights": insights,
             "insights_block": insights_block,
             "refined_query": refined_query,
-            "qa_pairs": state["answers"]
+            "qa_pairs": state["answers"],
+            "qa_transcript": transcript
         }
     state["pending_question_index"] = next_index
     state["current_round"] = len(state["answers"]) + 1
@@ -240,7 +257,11 @@ def stream_chat(
             "activated": False,
             "rounds": 0,
             "reason": "",
-            "insights": []
+            "insights": [],
+            "plan": [],
+            "qa_pairs": [],
+            "transcript": "",
+            "insights_block": ""
         }
     }
     def record_stage(stage_name: str, start_time: float):
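
After a completed intake, the diagnostics entry would plausibly be populated like this (hypothetical values):

# Sketch of pipeline_diagnostics["clinical_intake"] once intake finishes:
{
    "activated": True,
    "rounds": 2,
    "reason": "symptom duration and red flags unclear",
    "insights": ["two-week headache history"],
    "plan": [{"order": 1, "question": "How long have the headaches lasted?"}],
    "qa_pairs": [{"question": "...", "answer": "..."}],
    "transcript": "Q1: ...\nA1: ...",
    "insights_block": "Key findings:\n- ...",
}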
@@ -262,6 +283,8 @@ def stream_chat(
         {"role": "assistant", "content": ""}
     ]
 
+    clinical_intake_context_block = ""
+
     if not enable_clinical_intake:
         _clear_clinical_intake_state(user_id)
     else:
@@ -281,11 +304,24 @@ def stream_chat(
281
  pipeline_diagnostics["clinical_intake"]["activated"] = True
282
  pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
283
  pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
284
- message = intake_result.get("refined_query", message)
 
 
 
 
 
 
 
 
 
 
 
 
285
  else:
286
  history_context = _history_to_text(history)
287
  triage_plan = gemini_clinical_intake_triage(message, history_context, MAX_CLINICAL_QA_ROUNDS)
288
  pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
 
289
  needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
290
  if needs_intake:
291
  first_prompt = _start_clinical_intake_session(
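
The net effect on message, assuming both optional sections are non-empty (a sketch; the refined query and findings are hypothetical):

sections = [
    "Refined query: 45-year-old with two-week headache history ...",
    "Clinical intake summary:\n- two-week headache history",
    "Clinical intake Q&A transcript:\nQ1: How long have the headaches lasted?\nA1: About two weeks.",
]
message = "\n\n---\n\n".join([section for section in sections if section])
# The blocks end up separated by "---" dividers; empty sections drop out of the join.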
@@ -460,13 +496,14 @@ def stream_chat(
 
     base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers. Use Markdown format with bullet points. Do not use tables."
 
-    combined_context = ""
+    context_sections = []
+    if clinical_intake_context_block:
+        context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
     if rag_contexts:
-        combined_context += "Document Context:\n" + "\n\n".join(rag_contexts[:4])
+        context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
     if search_contexts:
-        if combined_context:
-            combined_context += "\n\n"
-        combined_context += "Web Search Context:\n" + "\n\n".join(search_contexts)
+        context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
+    combined_context = "\n\n".join(context_sections)
 
     logger.info(f"[MEDSWIN] Executing {len(breakdown.get('sub_topics', []))} tasks sequentially...")
     medswin_answers = []
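
One consequence of this refactor: empty sections simply drop out of the join, so combined_context never carries stray separators or headers. A minimal sketch with hypothetical context strings:

# Names mirror the diff; the context values are made up for illustration.
clinical_intake_context_block = "Clinical intake summary:\n- two-week headache history"
rag_contexts = []                       # empty, so no "Document Context" section
search_contexts = ["NICE headache guidance excerpt ..."]
context_sections = []
if clinical_intake_context_block:
    context_sections.append("Clinical Intake Context:\n" + clinical_intake_context_block)
if rag_contexts:
    context_sections.append("Document Context:\n" + "\n\n".join(rag_contexts[:4]))
if search_contexts:
    context_sections.append("Web Search Context:\n" + "\n\n".join(search_contexts))
combined_context = "\n\n".join(context_sections)
# -> "Clinical Intake Context:...\n\nWeb Search Context:..."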
 
supervisor.py CHANGED
@@ -160,6 +160,38 @@ Keep strategies focused and avoid overlap."""
     }
 
 
+def _prepare_clinical_question_plan(plan: dict, safe_rounds: int) -> dict:
+    """Normalize Gemini question plan to 1-5 sequential prompts."""
+    if not isinstance(plan, dict):
+        return {"questions": []}
+    questions = plan.get("questions", [])
+    if not isinstance(questions, list):
+        questions = []
+    cleaned = []
+    for idx, raw in enumerate(questions):
+        if not isinstance(raw, dict):
+            continue
+        question_text = (raw.get("question") or "").strip()
+        if not question_text:
+            continue
+        entry = dict(raw)
+        entry["question"] = question_text
+        entry["order"] = entry.get("order") or raw.get("id") or (idx + 1)
+        cleaned.append(entry)
+    cleaned.sort(key=lambda item: item.get("order", 0))
+    cleaned = cleaned[:max(1, min(5, safe_rounds))]
+    for idx, item in enumerate(cleaned, 1):
+        item["order"] = idx
+    plan["questions"] = cleaned
+    if cleaned:
+        plan["max_rounds"] = min(len(cleaned), safe_rounds)
+        plan["needs_additional_info"] = bool(plan.get("needs_additional_info", True))
+    else:
+        plan["needs_additional_info"] = False
+        plan["max_rounds"] = 0
+    return plan
+
+
 async def gemini_supervisor_rag_brainstorm_async(query: str, retrieved_docs: str, time_elapsed: float) -> dict:
     """Gemini Supervisor: In RAG mode, brainstorm retrieved documents into 1-4 short contexts"""
     max_doc_length = 3000
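
A quick sketch of how the new normalization behaves on a deliberately messy plan (hypothetical payload; safe_rounds caps how many prompts survive):

plan = {
    "needs_additional_info": True,
    "questions": [
        {"question": "  Any fever?  ", "order": 2},
        {"question": ""},                              # dropped: empty question text
        "not a dict",                                  # dropped: wrong type
        {"question": "Onset of symptoms?", "id": 1},   # "id" used as the order
    ],
}
result = _prepare_clinical_question_plan(plan, safe_rounds=3)
# result["questions"] keeps two entries, stripped and re-numbered in order:
#   [{"question": "Onset of symptoms?", "id": 1, "order": 1},
#    {"question": "Any fever?", "order": 2}]
# result["max_rounds"] == 2, result["needs_additional_info"] is True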
@@ -274,6 +306,7 @@ Guidelines:
         json_end = response.rfind('}') + 1
         if json_start >= 0 and json_end > json_start:
             plan = json.loads(response[json_start:json_end])
+            plan = _prepare_clinical_question_plan(plan, safe_rounds)
             return plan
         raise ValueError("Clinical intake JSON not found")
     except Exception as exc: