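"""Core RAG utilities: NLU tag classification, FAISS vector store construction,
OpenAI chat calls, prompt assembly, and text-to-speech / transcription helpers."""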
from __future__ import annotations
import os
import json
import time
import tempfile
from typing import List, Dict, Any, Optional
# OpenAI for LLM (optional)
try:
    from openai import OpenAI
except Exception:  # pragma: no cover
    OpenAI = None  # type: ignore

# LangChain & RAG
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# TTS
try:
    from gtts import gTTS
except Exception:  # pragma: no cover
    gTTS = None  # type: ignore

# --- INTEGRATION: Import the new, sophisticated prompts from prompts.py ---
from .prompts import (
    SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM, ANSWER_TEMPLATE_ADQ,
    SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines, CLASSIFICATION_PROMPT
)

# -----------------------------
# NLU Classification Function (NEW)
# -----------------------------
def detect_tags_from_query(query: str, behavior_options: list, emotion_options: list) -> Dict[str, Optional[str]]:
"""Uses an LLM call to classify the user's query into a behavior and emotion tag."""
# Format the options for the prompt
behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
emotion_str = ", ".join(f'"{opt}"' for opt in emotion_options if opt != "None")
# Build the classification prompt
prompt = CLASSIFICATION_PROMPT.format(
behavior_options=behavior_str,
emotion_options=emotion_str,
query=query
)
messages = [
{"role": "system", "content": "You are a helpful NLU classification assistant. Respond only with the JSON object requested."},
{"role": "user", "content": prompt}
]
# Call the LLM with low temperature for a deterministic response
response_str = call_llm(messages, temperature=0.1)
# Safely parse the JSON response
try:
# The LLM might return the JSON inside a markdown block
clean_response = response_str.strip().replace("```json", "").replace("```", "")
result = json.loads(clean_response)
# Validate the response
behavior = result.get("detected_behavior")
emotion = result.get("detected_emotion")
return {
"detected_behavior": behavior if behavior in behavior_options else "None",
"detected_emotion": emotion if emotion in emotion_options else "None"
}
except (json.JSONDecodeError, AttributeError):
# Fallback if the LLM response is not valid JSON
return {"detected_behavior": "None", "detected_emotion": "None"}
# -----------------------------
# Embeddings & VectorStore
# -----------------------------
# (This entire section remains unchanged)
def _default_embeddings():
"""Lightweight, widely available model."""
model_name = os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
return HuggingFaceEmbeddings(model_name=model_name)
def build_or_load_vectorstore(docs: List[Document], index_path: str) -> FAISS:
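    """Load a persisted FAISS index from index_path if one exists; otherwise build it from docs and save it."""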
    os.makedirs(os.path.dirname(index_path), exist_ok=True)
    if os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss")):
        try:
            return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
        except Exception:
            pass
    vs = FAISS.from_documents(docs, _default_embeddings())
    vs.save_local(index_path)
    return vs

def texts_from_jsonl(path: str) -> List[Document]:
"""Load a JSONL file, parsing text and all relevant metadata."""
out: List[Document] = []
try:
with open(path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
line = line.strip()
if not line: continue
try:
obj = json.loads(line)
except Exception:
obj = {"text": line}
txt = obj.get("text") or obj.get("content") or obj.get("dialogue") or ""
if not isinstance(txt, str) or not txt.strip(): continue
md = {"source": os.path.basename(path), "chunk": i}
if "metadata" in obj and isinstance(obj["metadata"], dict):
md.update(obj["metadata"])
for k in ("scene_description", "tags", "theme", "behaviors", "role", "emotion"):
if k in obj:
if k == 'behaviors' and isinstance(obj[k], str):
md[k] = [tag.strip() for tag in obj[k].split(',')]
else:
md[k] = obj[k]
out.append(Document(page_content=txt, metadata=md))
except Exception:
return []
return out
def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
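    """Build (or load) a FAISS index from the given sample files, falling back to a placeholder document if none load."""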
    docs: List[Document] = []
    for p in (sample_paths or []):
        try:
            if p.lower().endswith(".jsonl"):
                docs.extend(texts_from_jsonl(p))
            else:
                with open(p, "r", encoding="utf-8", errors="ignore") as fh:
                    docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
        except Exception:
            continue
    if not docs:
        docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
    return build_or_load_vectorstore(docs, index_path=index_path)

# -----------------------------
# LLM Call
# -----------------------------
# (This entire section remains unchanged)
def _openai_client() -> Optional[OpenAI]:
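    """Return an OpenAI client when OPENAI_API_KEY is set and the SDK is importable; otherwise None."""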
    api_key = os.getenv("OPENAI_API_KEY", "").strip()
    return OpenAI(api_key=api_key) if api_key and OpenAI else None

def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6) -> str:
"""Call OpenAI Chat Completions if available; else return a fallback."""
client = _openai_client()
model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
if not client:
return "(Offline Mode: OpenAI API key not configured.)"
try:
resp = client.chat.completions.create(model=model, messages=messages, temperature=float(temperature))
return (resp.choices[0].message.content or "").strip()
except Exception as e:
return f"[LLM API Error: {e}]"
# -----------------------------
# Prompting & RAG Chain
# -----------------------------
# (This section is unchanged as the logic now lives in _answer_fn)
def _format_sources(docs: List[Document]) -> List[str]:
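    """Return the de-duplicated source names of the retrieved documents."""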
    return list(set(d.metadata.get("source", "unknown") for d in docs))

def make_rag_chain(
    vs: FAISS,
    *,
    role: str = "patient",
    temperature: float = 0.6,
    language: str = "English",
    patient_name: str = "the patient",
    caregiver_name: str = "the caregiver",
    tone: str = "warm",
):
"""Returns a callable that performs the complete, context-aware RAG process."""
retriever = vs.as_retriever(search_kwargs={"k": 5})
def _format_docs(docs: List[Document]) -> str:
if not docs: return "(No relevant information found in the knowledge base.)"
return "\n".join([f"- {d.page_content.strip()}" for d in docs])
    def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None) -> Dict[str, Any]:
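        """Retrieve context (optionally filtered by behavior/emotion tags), build the prompt, and call the LLM."""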
        # Build an optional metadata filter from the detected tags
        search_filter = {}
        if scenario_tag and scenario_tag != "None":
            search_filter["behaviors"] = scenario_tag.lower()
        if emotion_tag and emotion_tag != "None":
            search_filter["emotion"] = emotion_tag.lower()

        if search_filter:
            docs = vs.similarity_search(query, k=5, filter=search_filter)
        else:
            docs = retriever.invoke(query)
        context = _format_docs(docs)

        # Use the emotion tag from the first retrieved document that carries one
        first_emotion = None
        for doc in docs:
            if "emotion" in doc.metadata and doc.metadata["emotion"]:
                emotion_data = doc.metadata["emotion"]
                if isinstance(emotion_data, list):
                    first_emotion = emotion_data[0]
                else:
                    first_emotion = emotion_data
                break
        emotions_context = render_emotion_guidelines(first_emotion)

        # Pick the scenario-specific template when any tag or retrieved emotion is present
        is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
        template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM

        user_prompt = template.format(
            context=context,
            question=query,
            scenario_tag=scenario_tag,
            emotions_context=emotions_context,
            role=role,
            language=language,
        )
        system_message = SYSTEM_TEMPLATE.format(
            tone=tone, language=language, patient_name=patient_name or "the patient",
            caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS,
        )
        messages = [{"role": "system", "content": system_message}]
        messages.extend(chat_history)
        messages.append({"role": "user", "content": user_prompt})

        answer = call_llm(messages, temperature=temperature)

        # Append a safety footer for high-risk scenarios
        high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
        if scenario_tag and scenario_tag.lower() in high_risk_scenarios:
            answer += f"\n\n---\n{RISK_FOOTER}"
        return {"answer": answer, "sources": _format_sources(docs)}

    return _answer_fn

def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
"""A clean wrapper to pass arguments from the UI to the RAG chain."""
if not callable(chain):
return {"answer": "[Error: RAG chain is not callable]", "sources": []}
chat_history = kwargs.get("chat_history", [])
scenario_tag = kwargs.get("scenario_tag")
emotion_tag = kwargs.get("emotion_tag")
try:
return chain(question, chat_history=chat_history, scenario_tag=scenario_tag, emotion_tag=emotion_tag)
except Exception as e:
print(f"ERROR in answer_query: {e}")
return {"answer": f"[Error executing chain: {e}]", "sources": []}
# -----------------------------
# TTS & Transcription
# -----------------------------
# (This entire section remains unchanged)
def synthesize_tts(text: str, lang: str = "en"):
"""Returns a path to a temporary audio file."""
if not text or gTTS is None: return None
try:
fd, path = tempfile.mkstemp(suffix=".mp3")
os.close(fd)
tts = gTTS(text=text, lang=(lang or "en"))
tts.save(path)
return path
except Exception:
return None
def transcribe_audio(filepath: str, lang: str = "en"):
"""Transcribes an audio file using OpenAI's Whisper API."""
client = _openai_client()
if not client:
return "[Transcription failed: API key not configured]"
api_args = {
"model": "whisper-1",
}
if lang and lang != "auto":
api_args["language"] = lang
with open(filepath, "rb") as audio_file:
transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
return transcription.text
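

# -----------------------------
# Usage sketch (illustrative)
# -----------------------------
# A minimal, hedged example of how these helpers fit together. The sample path
# "data/sample_dialogues.jsonl" below is hypothetical and only for illustration;
# substitute whatever JSONL/text files your app actually ships with.
if __name__ == "__main__":  # pragma: no cover
    vs = bootstrap_vectorstore(["data/sample_dialogues.jsonl"], index_path="data/faiss_index")
    chain = make_rag_chain(vs, role="caregiver", language="English", tone="warm")
    result = answer_query(
        chain,
        "How should I respond when my mother keeps asking to go home?",
        chat_history=[],
        scenario_tag="None",
        emotion_tag="None",
    )
    print(result["answer"])
    print("Sources:", result["sources"])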