KeenWoo committed (verified)
Commit 1e85018 · Parent(s): c5d9793

Update alz_companion/agent.py

Files changed (1): alz_companion/agent.py (+161 −232)

alz_companion/agent.py CHANGED
@@ -8,6 +8,7 @@ import re
 import random  # for random song selection
 
 from typing import List, Dict, Any, Optional
+from sentence_transformers import CrossEncoder
 
 try:
     from openai import OpenAI
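Note: this hunk makes sentence-transformers a hard dependency at module import time. The file already guards the OpenAI import with try/except; a guarded variant of the new import (a sketch, not what the commit does) would keep the module importable when the package is absent:

    try:
        from sentence_transformers import CrossEncoder
    except ImportError:
        CrossEncoder = None  # rerank_documents would then need to check for None and skip re-ranking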
@@ -107,6 +108,25 @@ MULTI_HOP_KEYPHRASES = [
 _MH_PATTERNS = [re.compile(p, re.IGNORECASE) for p in MULTI_HOP_KEYPHRASES]
 
 
+
+FACTUAL_KEYPHRASES = [
+    r"\b(what is|what was) my\b",
+    r"\b(who is|who was) my\b",
+    r"\b(where is|where was) my\b",
+    r"\b(how old am i)\b",
+    # r"\b(when did|what did) the journal say\b"
+]
+_FQ_PATTERNS = [re.compile(p, re.IGNORECASE) for p in FACTUAL_KEYPHRASES]
+
+def _pre_router_factual(query: str) -> str | None:
+    """Checks for patterns common in direct factual questions about personal memory."""
+    q = (query or "")
+    for pat in _FQ_PATTERNS:
+        if re.search(pat, q):
+            return "factual_question"
+    return None
+
+
 # Add this near the top of agent.py with the other keyphrase lists
 SUMMARIZATION_KEYPHRASES = [
     r"^\b(summarize|summarise|recap)\b", r"^\b(give me a summary|create a short summary)\b"
@@ -259,11 +279,13 @@ def detect_tags_from_query(
         print(f"ERROR parsing NLU Specialist JSON: {e}")
     return result_dict
 
+
 def _default_embeddings():
     # This function remains unchanged from agent_work.py
     model_name = os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
     return HuggingFaceEmbeddings(model_name=model_name)
 
+
 def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
     # This function remains unchanged from agent_work.py
     os.makedirs(os.path.dirname(index_path), exist_ok=True)
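_default_embeddings reads EMBEDDINGS_MODEL from the environment on each call, so the embedding model can be swapped without touching code. One caveat worth noting: a FAISS index built with one embedding model cannot be queried with another, since vector dimensions and geometry differ. Illustrative override (a sketch):

    import os

    # Set before any index is built or loaded; the saved index is tied to this model.
    os.environ["EMBEDDINGS_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"  # the default shown above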
@@ -277,6 +299,23 @@ def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal
         vs.save_local(index_path)
     return vs
 
+
+def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
+    # This function remains unchanged from agent_work.py
+    docs: List[Document] = []
+    for p in (sample_paths or []):
+        try:
+            if p.lower().endswith(".jsonl"):
+                docs.extend(texts_from_jsonl(p))
+            else:
+                with open(p, "r", encoding="utf-8", errors="ignore") as fh:
+                    docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
+        except Exception: continue
+    if not docs:
+        docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
+    return build_or_load_vectorstore(docs, index_path=index_path)
+
+
 def texts_from_jsonl(path: str) -> List[Document]:
     # This function remains unchanged from agent_work.py
     out: List[Document] = []
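bootstrap_vectorstore (moved, not rewritten, in this commit) accepts a mix of .jsonl and plain-text paths, skips unreadable files, and indexes a placeholder document when nothing loads, so FAISS always has content. An illustrative call, with hypothetical file names:

    vs = bootstrap_vectorstore(
        sample_paths=["data/memories.jsonl", "data/care_notes.txt"],  # hypothetical paths
        index_path="data/faiss_index",
    )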
@@ -293,6 +332,36 @@ def texts_from_jsonl(path: str) -> List[Document]:
     except Exception: return []
     return out
 
+
+def rerank_documents(query: str, documents: list[tuple[Document, float]]) -> list[tuple[tuple[Document, float], float]]:
+    """
+    Re-ranks a list of retrieved documents against a query using a CrossEncoder model.
+    Returns the original document tuples along with their new re-ranker score.
+    """
+    if not documents or not query:
+        return []
+
+    model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+    doc_contents = [doc.page_content for doc, score in documents]
+    query_doc_pairs = [[query, doc_content] for doc_content in doc_contents]
+
+    scores = model.predict(query_doc_pairs)
+
+    reranked_results = list(zip(documents, scores))
+    reranked_results.sort(key=lambda x: x[1], reverse=True)
+
+    print(f"\n[DEBUG] Re-ranked Top 3 Sources:")
+    for doc_tuple, score in reranked_results[:3]:
+        doc, _ = doc_tuple
+        # --- MODIFICATION: Add score to debug log ---
+        print(f" - New Rank | Source: {doc.metadata.get('source')} | Score: {score:.4f}")
+
+    # --- MODIFICATION: Return the results with scores ---
+    return reranked_results
+
+
 # Some vectorstores might return duplicates.
 # This is useful when top-k cutoff might otherwise include near-duplicates from query expansion
 def dedup_docs(scored_docs):
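rerank_documents constructs a new CrossEncoder on every call, which re-loads model weights per query. A common variant (a sketch, not part of this commit) caches the model at process scope:

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def _get_reranker():
        # Loaded once per process; later calls reuse the same instance.
        return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

Inside rerank_documents, `model = CrossEncoder(...)` would then become `model = _get_reranker()`.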
@@ -305,21 +374,6 @@ def dedup_docs(scored_docs):
         seen.add(uid)
     return unique
 
-
-def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
-    # This function remains unchanged from agent_work.py
-    docs: List[Document] = []
-    for p in (sample_paths or []):
-        try:
-            if p.lower().endswith(".jsonl"):
-                docs.extend(texts_from_jsonl(p))
-            else:
-                with open(p, "r", encoding="utf-8", errors="ignore") as fh:
-                    docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
-        except Exception: continue
-    if not docs:
-        docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
-    return build_or_load_vectorstore(docs, index_path=index_path)
 
 def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Optional[List[str]] = None, response_format: Optional[dict] = None) -> str:
     # This function remains unchanged from agent_work.py
@@ -398,6 +452,7 @@ def route_query_type(query: str, severity: str = "Normal / Unspecified"):
         print(f"Query classified as: {sum_hit} (summarization pre-router)")
         return sum_hit
 
+    # Priority 4: Check for music requests.
    # NEW: music support is checked before care_hit = _pre_router(query), so the
    # specific "play music" checker (_pre_router_music) runs before the general
    # "caregiving" keyword checker (_pre_router).
@@ -405,8 +460,8 @@ def route_query_type(query: str, severity: str = "Normal / Unspecified"):
     if music_hit:
         print(f"Query classified as: {music_hit} (music re-router)")
         return music_hit
-
-    # Priority 3: Check for general caregiving keywords.
+
+    # Priority 5: Check for general caregiving keywords.
     care_hit = _pre_router(query)
     if care_hit:
         print(f"Query classified as: {care_hit} (caregiving pre-router)")
@@ -421,7 +476,6 @@
 
 # helper: put near other small utils in agent.py
 # In agent.py, replace the _source_ids_for_eval function
-
 def _source_ids_for_eval(docs, cap=5):
     """
     Return the source identifiers for evaluation.
@@ -584,15 +638,13 @@ def make_rag_chain(vs_general: FAISS, vs_personal: FAISS, *, for_evaluation: boo
         # --- END OF REVISED MUSIC LOGIC ---
         # END --- MUSIC PLAYBACK LOGIC ---
 
-
         p_name = patient_name or "the patient"
         c_name = caregiver_name or "the caregiver"
-
         perspective_line = (f"You are speaking directly to {p_name}, who is the patient...") if role == "patient" else (f"You are communicating with {c_name}, the caregiver, about {p_name}.")
         system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, perspective_line=perspective_line, guardrails=SAFETY_GUARDRAILS)
         messages = [{"role": "system", "content": system_message}]
         messages.extend(chat_history)
-
+
         if "general_knowledge_question" in query_type or "general_conversation" in query_type:
             template = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE if "general_knowledge" in query_type else ANSWER_TEMPLATE_GENERAL
             user_prompt = template.format(question=query, language=language)
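For reference, the message list assembled above has this shape by the time it reaches call_llm (illustrative, following the code in this hunk):

    messages = [
        {"role": "system", "content": system_message},  # tone, language, perspective, guardrails
        *chat_history,                                  # prior turns, already role-tagged
        {"role": "user", "content": user_prompt},       # appended later, once a template is filled
    ]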
@@ -601,232 +653,109 @@
             answer = _clean_surface_text(raw_answer)
             sources = ["General Knowledge"] if "general_knowledge" in query_type else []
             return {"answer": answer, "sources": sources, "source_documents": []}
+        # --- END: Non-RAG Route Handling ---
 
-        expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
-        expansion_response = call_llm([{"role": "user", "content": expansion_prompt}], temperature=0.1)
-        try:
-            search_queries = [query] + json.loads(expansion_response.strip().replace("```json", "").replace("```", ""))
-        except json.JSONDecodeError:
-            search_queries = [query]
+        all_retrieved_docs = []
+        is_personal_route = "factual" in query_type or "summarization" in query_type or "multi_hop" in query_type
 
-        # NEW: Determine sourcing weight
-        if disease_stage in ["Moderate Stage", "Advanced Stage"]:
-            top_k_general = 5
-            top_k_personal = 1
-        else:  # current default
-            top_k_general = 2
-            top_k_personal = 3
-
-        # NEW: pass top_k_personal and top_k_general parameters
-        personal_results_with_scores = [
-            result for q in search_queries for result in vs_personal.similarity_search_with_score(q, k=top_k_personal)
-        ]
-        general_results_with_scores = [
-            result for q in search_queries for result in vs_general.similarity_search_with_score(q, k=top_k_general)
-        ]
-
-        # NEW: Remove duplicates
-        personal_results_with_scores = dedup_docs(personal_results_with_scores)
-        general_results_with_scores = dedup_docs(general_results_with_scores)
-
-        ## BEGIN DEBUGGING
-        print(f"[DEBUG] Retrieved {len(personal_results_with_scores)} personal, {len(general_results_with_scores)} general results")
-        if personal_results_with_scores:
-            print(f"Top personal score: {max([s for _, s in personal_results_with_scores]):.3f}")
-        if general_results_with_scores:
-            print(f"Top general score: {max([s for _, s in general_results_with_scores]):.3f}")
-
-        print("\n--- DEBUG: Personal Search Results with Scores (Before Filtering) ---")
-        if personal_results_with_scores:
-            for doc, score in personal_results_with_scores:
-                print(f" - Score: {score:.4f} | Source: {doc.metadata.get('source', 'N/A')}")
-        else:
-            print(" - No results found.")
-        print("-----------------------------------------------------------------")
-
-        print("\n--- DEBUG: General Search Results with Scores (Before Filtering) ----")
-        if general_results_with_scores:
-            for doc, score in general_results_with_scores:
-                print(f" - Score: {score:.4f} | Source: {doc.metadata.get('source', 'N/A')}")
+
+        # --- NEW: DEDICATED LOGIC PATHS FOR RETRIEVAL ---
+        if is_personal_route:
+            # For personal queries, semantic search is unreliable. We retrieve ALL personal documents.
+            print("[DEBUG] Personal Memory Route Activated. Retrieving all personal documents.")
+            if vs_personal and vs_personal._collection.count() > 0:
+                personal_docs_data = vs_personal.get(include=["metadatas", "documents"])
+                all_retrieved_docs.extend([
+                    Document(page_content=doc_content, metadata=metadata)
+                    for doc_content, metadata in zip(personal_docs_data['documents'], personal_docs_data['metadatas'])
+                ])
         else:
-            print(" - No results found.")
-        print("-----------------------------------------------------------------")
-        ## END DEBUGGING
-
-        # Return docs under the relevance threshold; if none qualify, fall back to the single best-scoring doc. Placeholder docs are stripped out.
-        def get_best_docs_with_fallback(results_with_scores: list[tuple[Document, float]]) -> (list[Document], float):
-            valid_results = [res for res in results_with_scores if res[0].metadata.get("source") != "placeholder"]
-            if not valid_results:
-                return [], float('inf')
-
-            best_score = sorted(valid_results, key=lambda x: x[1])[0][1]
-            filtered_docs = [doc for doc, score in valid_results if score < RELEVANCE_THRESHOLD]
+            # For caregiving scenarios, use our Multi-Stage Retrieval algorithm.
+            print("[DEBUG] Using Multi-Stage Retrieval for caregiving scenario...")
+            print("[DEBUG] Expanding query...")
+            search_queries = [query]
+            try:
+                expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
+                expansion_messages = [{"role": "user", "content": expansion_prompt}]
+                raw_expansion = call_llm(expansion_messages, temperature=0.0)
+                expanded = json.loads(raw_expansion)
+                if isinstance(expanded, list):
+                    search_queries.extend(expanded)
+            except Exception as e:
+                print(f"[DEBUG] Query expansion failed: {e}")
+
+            scenario_tags = kwargs.get("scenario_tag")
+            if isinstance(scenario_tags, str): scenario_tags = [scenario_tags]
+            primary_behavior = (scenario_tags or [None])[0]
 
-            if not filtered_docs:
-                return [sorted(valid_results, key=lambda x: x[1])[0][0]], best_score
+            candidate_docs = []
+            if primary_behavior and primary_behavior != "None":
+                print(f" - Stage 1a: High-precision search for behavior: '{primary_behavior}'")
+                for q in search_queries:
+                    candidate_docs.extend(vs_general.similarity_search_with_score(q, k=10, filter={"behaviors": primary_behavior}))
+
+            print(" - Stage 1b: High-recall semantic search (k=20)")
+            for q in search_queries:
+                candidate_docs.extend(vs_general.similarity_search_with_score(q, k=20))
+
+            all_candidate_docs = dedup_docs(candidate_docs)
+            print(f"[DEBUG] Total unique candidates for re-ranking: {len(all_candidate_docs)}")
+            reranked_docs_with_scores = rerank_documents(query, all_candidate_docs) if all_candidate_docs else []
+
+            # --- BEST method code: Recall 90% and Precision 73% ---
+            final_docs_with_scores = []
+            if reranked_docs_with_scores:
+                RELATIVE_SCORE_MARGIN = 3.0
+                top_doc_tuple, top_score = reranked_docs_with_scores[0]
+                final_docs_with_scores.append(top_doc_tuple)
+                for doc_tuple, score in reranked_docs_with_scores[1:]:
+                    if score > (top_score - RELATIVE_SCORE_MARGIN):
+                        final_docs_with_scores.append(doc_tuple)
+                    else: break
 
-            return filtered_docs, best_score
-        # END def get_best_docs_with_fallback
-
-        if disease_stage in ["Moderate Stage", "Advanced Stage"]:
-            # Use top-k selection (e.g. top 5 for general, top 1 for personal)
-            filtered_general_docs = [doc for doc, score in general_results_with_scores[:top_k_general]]
-            best_general_score = general_results_with_scores[0][1] if general_results_with_scores else 0.0
-
-            filtered_personal_docs = [doc for doc, score in personal_results_with_scores[:top_k_personal]]
-            best_personal_score = personal_results_with_scores[0][1] if personal_results_with_scores else 0.0
-        else:
-            # Use standard fallback-based scoring
-            filtered_personal_docs, best_personal_score = get_best_docs_with_fallback(personal_results_with_scores)
-            filtered_general_docs, best_general_score = get_best_docs_with_fallback(general_results_with_scores)
-
-        print("\n--- DEBUG: Filtered Personal Docs (After Threshold/Fallback) ---")
-        if filtered_personal_docs:
-            for doc in filtered_personal_docs:
-                print(f" - Source: {doc.metadata.get('source', 'N/A')}")
-        else:
-            print(" - No documents met the criteria.")
+            limit = 5 if disease_stage in ["Moderate Stage", "Advanced Stage"] else 3
+            final_docs_with_scores = final_docs_with_scores[:limit]
+            all_retrieved_docs = [doc for doc, score in final_docs_with_scores]
+
+        # --- FINAL PROCESSING (Applies to all RAG routes) ---
+        print("\n--- DEBUG: Final Selected Docs ---")
+        for doc in all_retrieved_docs:
+            print(f" - Source: {doc.metadata.get('source', 'N/A')}")
         print("----------------------------------------------------------------")
 
-        print("\n--- DEBUG: Filtered General Docs (After Threshold/Fallback) ----")
-        if filtered_general_docs:
-            for doc in filtered_general_docs:
-                print(f" - Source: {doc.metadata.get('source', 'N/A')}")
-        else:
-            print(" - No documents met the criteria.")
-        print("----------------------------------------------------------------")
-
-        personal_memory_routes = ["factual", "multi_hop", "summarization"]
-        is_personal_route = any(route_keyword in query_type for route_keyword in personal_memory_routes)
-
-        all_retrieved_docs = []
+        personal_sources_set = {'1 Complaints of a Dutiful Daughter.txt', 'Saved Chat', 'Text Input'}
+        personal_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') in personal_sources_set], "(No relevant personal memories found.)")
+        general_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') not in personal_sources_set], "(No general guidance found.)")
+
         if is_personal_route:
-            # --- MODIFIED AS PER YOUR SPECIFICATION ---
-            # Implements the simple fallback logic for personal routes:
-            # it always returns personal docs unless no personal memory is loaded.
-            if filtered_personal_docs:
-                print("[DEBUG] Factual/Sum/Multi: Prioritizing personal docs.")
-                all_retrieved_docs = filtered_personal_docs
-            else:
-                print("[DEBUG] Factual/Sum/Multi: Prioritizing general docs.")
-                all_retrieved_docs = filtered_general_docs
-            # --- END OF MODIFICATION ---
+            template = ANSWER_TEMPLATE_SUMMARIZE if "summarization" in query_type else ANSWER_TEMPLATE_FACTUAL_MULTI if "multi_hop" in query_type else ANSWER_TEMPLATE_FACTUAL
+            user_prompt = template.format(personal_context=personal_context, general_context=general_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name, context=personal_context, role=role)
+            print("[DEBUG] Personal Route Factual / Sum / Multi PROMPT")
         else: # caregiving_scenario
-            if disease_stage in ["Moderate Stage", "Advanced Stage"]:
-                # --- STAGE-AWARE LOGIC FOR CAREGIVING SCENARIOS ---
-                if filtered_general_docs:
-                    print("[DEBUG] Moderate / Advanced: Prioritizing general documents.")
-                    all_retrieved_docs = filtered_general_docs
-                elif filtered_personal_docs:
-                    print("[DEBUG] Moderate / Advanced: Falling back to personal documents.")
-                    all_retrieved_docs = filtered_personal_docs
-                else:
-                    print("[DEBUG] Moderate / Advanced: No relevant documents found.")
-                    all_retrieved_docs = []
-                # --- END STAGE-AWARE BLOCK ---
-            else:
-                # --- NORMAL ROUTING LOGIC ---
-                # Conditional blending logic for caregiving remains.
-                if abs(best_personal_score - best_general_score) <= SCORE_MARGIN:
-                    print("[DEBUG] Caregiving Case: Blending personal and general docs (scores are close).")
-                    all_retrieved_docs = list({doc.page_content: doc for doc in filtered_personal_docs + filtered_general_docs}.values())[:4]
-                elif best_personal_score < best_general_score:
-                    print("[DEBUG] Caregiving Case: Prioritizing personal docs (better score).")
-                    all_retrieved_docs = filtered_personal_docs
-                else:
-                    print("[DEBUG] Caregiving Case: Prioritizing general docs (better score).")
-                    all_retrieved_docs = filtered_general_docs
+            if disease_stage == "Advanced Stage": template = ANSWER_TEMPLATE_ADQ_ADVANCED
+            elif disease_stage == "Moderate Stage": template = ANSWER_TEMPLATE_ADQ_MODERATE
+            else: template = ANSWER_TEMPLATE_ADQ
+            emotions_context = render_emotion_guidelines(kwargs.get("emotion_tag"))
+            user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=kwargs.get("scenario_tag"), emotions_context=emotions_context, role=role, language=language, patient_name=p_name, caregiver_name=c_name, emotion_tag=kwargs.get("emotion_tag"))
+            print("[DEBUG] Caregiving Scenario PROMPT")
 
-        # --- Prompt Generation and LLM Call ---
-        answer = ""
-        if is_personal_route:
-            personal_context = _format_docs(all_retrieved_docs, "(No relevant personal memories found.)")
-            # Modified for test evaluation: general_context is empty, but live chat still uses general context
-            general_context = _format_docs([], "") if for_evaluation else _format_docs(filtered_general_docs, "(No general information found.)")
-            # End
-
-            print(f"[DEBUG] Personal Context: {personal_context}")
-            print(f"[DEBUG] General Context: {general_context}")
-
-            template = ANSWER_TEMPLATE_SUMMARIZE if "summarization" in query_type else ANSWER_TEMPLATE_FACTUAL
-            user_prompt = ""
-            if "summarization" in query_type:
-                if for_evaluation: # for evaluation, use only personal context
-                    user_prompt = template.format(context=personal_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name, role=role)
-                else: # for live chat, use more context
-                    combined_context = f"{personal_context}\n{general_context}".strip()
-                    user_prompt = template.format(context=combined_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name, role=role)
-
-            else: # ANSWER_TEMPLATE_FACTUAL
-                user_prompt = template.format(personal_context=personal_context, general_context=general_context, question=query, language=language, patient_name=p_name, caregiver_name=c_name)
-
-            messages.append({"role": "user", "content": user_prompt})
-            if for_evaluation: # for evaluation tests, lower the temperature (creativity) from the 0.6 input
-                test_temperature = 0.0 # Modify the local variable
-            raw_answer = call_llm(messages, temperature=test_temperature)
-            answer = _clean_surface_text(raw_answer)
-            print("[DEBUG] Factual / Sum / Multi LLM Answer: ", {answer})
+        messages.append({"role": "user", "content": user_prompt})
 
-        else: # caregiving_scenario
-            # --- MODIFICATION START: Integrate the severity-based logic ---
-            # The disease_stage variable is available here from the outer function's scope
-
-            # 1. Select the appropriate template based on the disease stage setting.
-            if disease_stage == "Advanced Stage":
-                template = ANSWER_TEMPLATE_ADQ_ADVANCED
-            elif disease_stage == "Moderate Stage":
-                template = ANSWER_TEMPLATE_ADQ_MODERATE
-            else: # Normal / Unspecified or Mild Stage
-                template = ANSWER_TEMPLATE_ADQ
-
-            # 2. The rest of the logic remains the same. It will use the 'template' variable
-            # that was just selected above.
-            personal_sources = {'1 Complaints of a Dutiful Daughter.txt', 'Saved Chat', 'Text Input'}
-            personal_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') in personal_sources], "(No relevant personal memories found.)")
-            general_context = _format_docs([d for d in all_retrieved_docs if d.metadata.get('source') not in personal_sources], "(No general guidance found.)")
-
-            print(f"[DEBUG] Personal Context: {personal_context}")
-            print(f"[DEBUG] General Context: {general_context}")
-
-            first_emotion = next((d.metadata.get("emotion") for d in all_retrieved_docs if d.metadata.get("emotion")), None)
-            emotions_context = render_emotion_guidelines(first_emotion or kwargs.get("emotion_tag"))
-
-            # NEW: Add Emotion Tag
-            user_prompt = template.format(general_context=general_context, personal_context=personal_context,
-                                          question=query, scenario_tag=kwargs.get("scenario_tag"),
-                                          emotions_context=emotions_context, role=role, language=language,
-                                          patient_name=p_name, caregiver_name=c_name,
-                                          emotion_tag=kwargs.get("emotion_tag"))
-            messages.append({"role": "user", "content": user_prompt})
-            # --- MODIFICATION END ---
-
-            # OLD
-            # template = ANSWER_TEMPLATE_ADQ
-            # user_prompt = template.format(general_context=general_context, personal_context=personal_context,
-            #                               question=query, scenario_tag=kwargs.get("scenario_tag"),
-            #                               emotions_context=emotions_context, role=role, language=language,
-            #                               patient_name=p_name, caregiver_name=c_name)
-            # messages.append({"role": "user", "content": user_prompt})
-
-            if for_evaluation: # for evaluation tests, lower the temperature (creativity) from the 0.6 input
-                test_temperature = 0.0 # Modify the local variable
-            raw_answer = call_llm(messages, temperature=test_temperature)
-            answer = _clean_surface_text(raw_answer)
-            print("[DEBUG] Caregiving Case LLM Answer: ", {answer})
-            high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
-            if kwargs.get("scenario_tag") and kwargs["scenario_tag"].lower() in high_risk_scenarios:
-                answer += f"\n\n---\n{RISK_FOOTER}"
-
-        if for_evaluation:
-            sources = _source_ids_for_eval(all_retrieved_docs)
-        else:
-            sources = sorted(list(set(d.metadata.get("source", "unknown") for d in all_retrieved_docs if d.metadata.get("source") != "placeholder")))
+        raw_answer = call_llm(messages, temperature=0.0 if for_evaluation else temperature)
+        answer = _clean_surface_text(raw_answer)
+        print("[DEBUG] LLM Answer", {answer})
+
+        if (kwargs.get("scenario_tag") or "").lower() in ["exit_seeking", "wandering"]:
+            answer += f"\n\n---\n{RISK_FOOTER}"
 
+        sources = _source_ids_for_eval(all_retrieved_docs) if for_evaluation else sorted(list(set(d.metadata.get("source", "unknown") for d in all_retrieved_docs if d.metadata.get("source") != "placeholder")))
         print("DEBUG Sources (After Filtering):", sources)
         return {"answer": answer, "sources": sources, "source_documents": all_retrieved_docs}
 
-    return _answer_fn
-
+    return _answer_fn
 # END of make_rag_chain
 
 def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
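Two details in the new retrieval paths are worth sketches. First, the personal-memory branch calls vs_personal._collection.count() and vs_personal.get(include=[...]), which are Chroma-style APIs, while the signature types vs_personal as a LangChain FAISS store; if the store really is FAISS, a fetch-everything equivalent would read the docstore directly (a sketch relying on LangChain FAISS internals). Second, the "BEST method" block keeps every re-ranked document whose CrossEncoder score falls within a fixed margin of the top score; the same rule as a standalone helper (names are illustrative, not from the commit):

    def fetch_all_docs_faiss(vs):
        # LangChain's FAISS keeps every Document in an in-memory docstore dict.
        return list(vs.docstore._dict.values())

    def select_by_relative_margin(reranked, margin=3.0, limit=3):
        """Keep the top doc plus any doc scoring within `margin` of it, capped at `limit`.

        `reranked` is rerank_documents' output, sorted descending:
        [((Document, retrieval_score), rerank_score), ...]
        """
        if not reranked:
            return []
        top_score = reranked[0][1]
        selected = [reranked[0][0]]
        for doc_tuple, score in reranked[1:]:
            if score > top_score - margin:
                selected.append(doc_tuple)
            else:
                break  # sorted descending, so no later doc can qualify
        return [doc for doc, _ in selected[:limit]]

Calling select_by_relative_margin(reranked, margin=3.0, limit=5 if stage in ("Moderate Stage", "Advanced Stage") else 3) reproduces the hunk's selection rule.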