Niranjan Sathish
committed on
Commit
·
40d2f99
1
Parent(s):
a337894
Initial Commit
Browse files- .gitattributes +13 -34
- .gitignore +52 -0
- Data/Dataset.json +3 -0
- Data/doc_metadata.pkl +3 -0
- Data/doc_vectors.npy +3 -0
- Data/faiss_index.idx +3 -0
- Data/flattened_drug_dataset_cleaned.csv +3 -0
- Evaluation/Evaluation_metrics_score.py +96 -0
- Evaluation/custom_drug_eval_set_id.csv +3 -0
- README.md +53 -7
- Scripts/Answer_Generation.py +125 -0
- Scripts/Query_processing.py +182 -0
- Scripts/Retrieval.py +175 -0
- Scripts/app.py +246 -0
- Scripts/demo.py +53 -0
- requirements.txt +22 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,14 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.idx filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
Data/* filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
Data/*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
Data/*.npy filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
Data/*.idx filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
Data/*.csv filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
Data/*.npygit filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
Data/*.idxgit filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
Data/*.pklgit filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
lfs filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
track filter=lfs diff=lfs merge=lfs -text
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
build/
|
| 11 |
+
develop-eggs/
|
| 12 |
+
dist/
|
| 13 |
+
downloads/
|
| 14 |
+
eggs/
|
| 15 |
+
.eggs/
|
| 16 |
+
lib/
|
| 17 |
+
lib64/
|
| 18 |
+
parts/
|
| 19 |
+
sdist/
|
| 20 |
+
var/
|
| 21 |
+
wheels/
|
| 22 |
+
*.egg-info/
|
| 23 |
+
.installed.cfg
|
| 24 |
+
*.egg
|
| 25 |
+
|
| 26 |
+
# Environment
|
| 27 |
+
.env
|
| 28 |
+
Chatbot.venv/
|
| 29 |
+
|
| 30 |
+
# IDE
|
| 31 |
+
.vscode/
|
| 32 |
+
.idea/
|
| 33 |
+
*.swp
|
| 34 |
+
*.swo
|
| 35 |
+
|
| 36 |
+
# OS
|
| 37 |
+
.DS_Store
|
| 38 |
+
Thumbs.db
|
| 39 |
+
|
| 40 |
+
# Data (don't commit large files)
|
| 41 |
+
*.pkl
|
| 42 |
+
*.npy
|
| 43 |
+
*.idx
|
| 44 |
+
!Data/*.pkl
|
| 45 |
+
!Data/*.npy
|
| 46 |
+
!Data/*.idx
|
| 47 |
+
|
| 48 |
+
# Model cache
|
| 49 |
+
.cache/
|
| 50 |
+
model_cache/
|
| 51 |
+
|
| 52 |
+
---
|
Data/Dataset.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc38f7e5bfad6d7c2865ed7c94d483c8b9b887a47853e4a3c16ce957ce1f06a0
|
| 3 |
+
size 35120734
|
Data/doc_metadata.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:800157a95b50080634fdce730014af49a8e0cf01d2dbb484785b15936dc9abff
|
| 3 |
+
size 53368209
|
Data/doc_vectors.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f54da3cd890cf384fdc3b7abcd6ed5f840c0f53da30615fd417fc8256fd1b5ca
|
| 3 |
+
size 70190720
|
Data/faiss_index.idx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58d68a5ccb27c94e357ab12eec21d5d54d903949ae37648202643eb33387156b
|
| 3 |
+
size 70190637
|
Data/flattened_drug_dataset_cleaned.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0669d5d7366973a342a3cc35321366a02837c66ac5e7c28c3bf0569897db5b84
|
| 3 |
+
size 31338099
|
Evaluation/Evaluation_metrics_score.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation Script for Retrieval-based QA Chatbot
|
| 3 |
+
=================================================
|
| 4 |
+
|
| 5 |
+
This module handles:
|
| 6 |
+
1. Loading evaluation questions and expected chunk IDs
|
| 7 |
+
2. Preprocessing queries and retrieving top chunks
|
| 8 |
+
3. Calculating Precision@3, Recall@3, F1-Score@3, and Success Rate@3
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from Query_processing import preprocess_query
|
| 13 |
+
from Retrieval import Retrieval_averagedQP
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
# -------------------------------
|
| 17 |
+
# File Paths
|
| 18 |
+
# -------------------------------
|
| 19 |
+
|
| 20 |
+
# Get the directory of the current script
|
| 21 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))

# The evaluation set is read from the same directory as this script.
csv_path = os.path.join(script_dir, 'custom_drug_eval_set_id.csv')

# -------------------------------
# Load Evaluation Dataset
# -------------------------------

df = pd.read_csv(csv_path)

# -------------------------------
# Evaluation Storage
# -------------------------------

# One entry per evaluated question; averaged after the loop.
all_precisions = []
all_recalls = []
all_f1s = []
all_successes = []


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty."""
    # An empty evaluation CSV would otherwise raise ZeroDivisionError below.
    return sum(values) / len(values) if values else 0.0


# -------------------------------
# Evaluation Loop
# -------------------------------

for _, row in df.iterrows():
    question = row['question']
    # 'relevant_chunk' is split on ';' into integer chunk IDs; empty tokens are dropped.
    expected_ids = set(map(int, filter(None, str(row['relevant_chunk']).split(';'))))

    print(f"\n[Evaluation] Question: {question}")
    print(f"[Expected Chunk IDs] {expected_ids}")

    # Preprocess the query
    intent, entities = preprocess_query(question)

    # Retrieve 10 candidates, then keep only the first 3 for the @3 metrics.
    retrieved_df = Retrieval_averagedQP(question, intent, entities, top_k=10, alpha=0.8)
    retrieved_df = retrieved_df.head(3)
    retrieved_ids = set(retrieved_df['chunk_id'].astype(int).tolist())

    print(f"[Retrieved Chunk IDs] {retrieved_ids}")

    # Set-based confusion counts for this question.
    tp = len(retrieved_ids & expected_ids)
    fp = len(retrieved_ids - expected_ids)
    fn = len(expected_ids - retrieved_ids)

    print(f"[Metrics] TP: {tp}, FP: {fp}, FN: {fn}")

    # Success means at least one expected chunk was retrieved.
    success = 1 if tp > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    all_precisions.append(precision)
    all_recalls.append(recall)
    all_f1s.append(f1)
    all_successes.append(success)

# -------------------------------
# Aggregate Results
# -------------------------------

mean_precision = _mean(all_precisions)
mean_recall = _mean(all_recalls)
mean_f1 = _mean(all_f1s)
mean_success = _mean(all_successes)

# -------------------------------
# Display Final Metrics
# -------------------------------

print("\n========= Final Evaluation Metrics =========")
print(f"Success Rate@3: {mean_success:.4f}")
print(f"Precision@3: {mean_precision:.4f}")
print(f"Recall@3: {mean_recall:.4f}")
print(f"F1 Score@3: {mean_f1:.4f}")
|
Evaluation/custom_drug_eval_set_id.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a32b1282d7fd5e6d55b73499ee314410cffa69b456a7372983225a71da6b5674
|
| 3 |
+
size 4001
|
README.md
CHANGED
|
@@ -1,12 +1,58 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Medical Drug QA Chatbot
|
| 3 |
+
emoji: 💊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: Scripts/app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# 💊 Medical Drug QA Chatbot
|
| 14 |
+
|
| 15 |
+
An intelligent chatbot that answers questions about medications using advanced NLP techniques.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- 🔍 **Smart Query Processing**: BioBERT-based NER for drug entity extraction
|
| 20 |
+
- 📚 **Hybrid Retrieval**: FAISS + BioBERT semantic reranking
|
| 21 |
+
- 🤖 **AI-Powered Answers**: Groq Llama-4 for natural language generation
|
| 22 |
+
- 💾 **Comprehensive Database**: Mayo Clinic drug information
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
Simply ask questions about:
|
| 27 |
+
- Side effects and warnings
|
| 28 |
+
- Dosage and usage instructions
|
| 29 |
+
- Drug interactions
|
| 30 |
+
- Storage guidelines
|
| 31 |
+
- Precautions for specific conditions
|
| 32 |
+
|
| 33 |
+
## Example Questions
|
| 34 |
+
|
| 35 |
+
- "What are the side effects of Aspirin?"
|
| 36 |
+
- "How should I store Insulin?"
|
| 37 |
+
- "What precautions should I take with Lisinopril?"
|
| 38 |
+
- "Can I take Metformin with alcohol?"
|
| 39 |
+
|
| 40 |
+
## Tech Stack
|
| 41 |
+
|
| 42 |
+
- **Frontend**: Gradio
|
| 43 |
+
- **NER**: BioBERT (alvaroalon2/biobert_chemical_ner)
|
| 44 |
+
- **Embeddings**: all-MiniLM-L6-v2, BioBERT
|
| 45 |
+
- **Vector DB**: FAISS
|
| 46 |
+
- **LLM**: Llama-4 via Groq API
|
| 47 |
+
|
| 48 |
+
## ⚠️ Disclaimer
|
| 49 |
+
|
| 50 |
+
This chatbot provides educational information only. Always consult healthcare professionals for medical advice.
|
| 51 |
+
|
| 52 |
+
## Setup
|
| 53 |
+
|
| 54 |
+
1. Clone the repository
|
| 55 |
+
2. Install dependencies: `pip install -r requirements.txt`
|
| 56 |
+
3. Set `GROQ_API_KEY` environment variable
|
| 57 |
+
4. Build FAISS index: `python Scripts/Retrieval.py`
|
| 58 |
+
5. Run: `python Scripts/app.py`
|
Scripts/Answer_Generation.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Answer Generation Module for Retrieval-based Medical QA Chatbot
|
| 3 |
+
=================================================================
|
| 4 |
+
This module handles answer generation using Groq API with proper error handling.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
|
| 10 |
+
# Get API key from environment
|
| 11 |
+
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# `client` remains None when no key is configured.
client = None
if GROQ_API_KEY is None:
    print("[Warning] GROQ_API_KEY not set!")
else:
    # The OpenAI client is pointed at Groq's OpenAI-compatible endpoint.
    client = OpenAI(api_key=GROQ_API_KEY, base_url="https://api.groq.com/openai/v1")
|
| 21 |
+
|
| 22 |
+
# -------------------------------
|
| 23 |
+
# Function: Query Groq API
|
| 24 |
+
# -------------------------------
|
| 25 |
+
|
| 26 |
+
def query_groq(prompt, model="meta-llama/llama-4-scout-17b-16e-instruct", max_tokens=300):
    """
    Send ``prompt`` to the chat completions endpoint and return the reply text.

    Parameters:
        prompt (str): Text sent as the user message.
        model (str): Model identifier passed to the completions call.
        max_tokens (int): Upper bound on generated tokens.

    Returns:
        str: Stripped content of the first choice, or a "⚠️ Error ..." string
        when no client is configured or the API call raises.
    """
    if client is None:
        return "⚠️ Error: API key not configured. Please contact the administrator."

    conversation = [
        {"role": "system", "content": "You are a helpful biomedical assistant providing accurate drug information."},
        {"role": "user", "content": prompt},
    ]

    try:
        completion = client.chat.completions.create(
            model=model,
            messages=conversation,
            temperature=0.7,
            max_tokens=max_tokens,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content.strip()
    except Exception as err:
        # Any failure is reported to the caller as text instead of being raised.
        print(f"[Answer Generation] Error calling Groq API: {err}")
        return f"⚠️ Error generating answer: {str(err)}"
|
| 55 |
+
|
| 56 |
+
# -------------------------------
|
| 57 |
+
# Function: Build Prompt
|
| 58 |
+
# -------------------------------
|
| 59 |
+
|
| 60 |
+
def build_prompt(question, context):
    """
    Assemble the model prompt from the user question and the retrieved context.

    Parameters:
        question (str): User's question.
        context (str): Retrieved text chunks, already joined into one string.

    Returns:
        str: Prompt consisting of a lead-in sentence, the question, the context,
        and a fixed list of answering instructions, ending with a newline.
    """
    prompt_lines = [
        "Based strictly on the following medical information, answer the question clearly and concisely.",
        "",
        f"Question: {question}",
        "",
        "Context:",
        f"{context}",
        "",
        "Instructions:",
        "- Provide a direct, accurate answer based only on the context",
        "- Use clear, simple language",
        "- If the context doesn't contain enough information, say so",
        "- Do not add information not present in the context",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)
|
| 84 |
+
|
| 85 |
+
# -------------------------------
|
| 86 |
+
# Function: Answer Generation
|
| 87 |
+
# -------------------------------
|
| 88 |
+
|
| 89 |
+
def answer_generation(question, top_chunks, top_k=3):
    """
    Produce an answer to ``question`` from the first ``top_k`` retrieved chunks.

    Parameters:
        question (str): User's question.
        top_chunks (DataFrame): Retrieved chunks; the 'drug_name', 'section'
            and 'chunk_text' columns are read.
        top_k (int): Number of leading rows used as context.

    Returns:
        str: The generated answer, a "no relevant information" notice when
        there are no rows, or a "⚠️ Error ..." string if anything raises.
    """
    try:
        selected = top_chunks.head(top_k)
        print(f"[Answer Generation] Using top {len(selected)} chunks")

        if selected.empty:
            return "⚠️ No relevant information found. Please try rephrasing your question."

        # One "Drug / Section / Info" paragraph per chunk, separated by blank lines.
        sections = []
        for _, chunk in selected.iterrows():
            sections.append(
                f"Drug: {chunk['drug_name']}\n"
                f"Section: {chunk['section']}\n"
                f"Info: {chunk['chunk_text']}"
            )
        context = "\n\n".join(sections)

        return query_groq(build_prompt(question, context))

    except Exception as err:
        print(f"[Answer Generation] Error: {err}")
        return f"⚠️ Error generating answer: {str(err)}"
|
Scripts/Query_processing.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Processing Pipeline for Retrieval-based QA Chatbot
|
| 3 |
+
========================================================
|
| 4 |
+
|
| 5 |
+
This module handles:
|
| 6 |
+
1. Query preprocessing
|
| 7 |
+
2. Intent and sub-intent classification
|
| 8 |
+
3. Named Entity Recognition (NER) using lightweight BioBERT
|
| 9 |
+
|
| 10 |
+
Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
from typing import List, Tuple
|
| 15 |
+
from transformers import pipeline
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
# -------------------------------
|
| 19 |
+
# Initialize Lightweight NER Model
|
| 20 |
+
# -------------------------------
|
| 21 |
+
|
| 22 |
+
print("[NER] Loading lightweight BioBERT NER model...")

# `ner_model` is None whenever the pipeline cannot be built.
ner_model = None
try:
    # This checkpoint is specifically trained for chemical/drug entity recognition.
    ner_model = pipeline(
        "ner",
        model="alvaroalon2/biobert_chemical_ner",
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1,
    )
    print("[NER] ✓ Model loaded successfully\n")
except Exception as e:
    print(f"[NER] ✗ Failed to load model: {e}")
    ner_model = None
|
| 36 |
+
|
| 37 |
+
# -------------------------------
|
| 38 |
+
# Named Entity Extraction
|
| 39 |
+
# -------------------------------
|
| 40 |
+
|
| 41 |
+
def extract_entities_BERT(question: str) -> List[str]:
    """
    Run the NER pipeline on ``question`` and return the cleaned entity strings.

    Parameters:
        question (str): User query

    Returns:
        List[str]: Entities with score above 0.7, '##' markers removed, longer
        than two characters, not in a small stop-word list, and de-duplicated
        case-insensitively in first-seen order. Empty when the model is
        unavailable or extraction raises.
    """
    if ner_model is None:
        print("[NER] Model not available, returning empty list")
        return []

    ignored_words = ['the', 'and', 'for', 'with']

    try:
        candidates = []
        for prediction in ner_model(question):
            # Keep only high-confidence predictions.
            if not prediction['score'] > 0.7:
                continue

            text = prediction['word'].replace('##', '').strip()

            # Drop very short fragments and common words.
            if not len(text) > 2 or text.lower() in ignored_words:
                continue
            candidates.append(text)

        # Case-insensitive de-duplication that keeps the first spelling seen.
        seen_lower = set()
        deduped = []
        for text in candidates:
            key = text.lower()
            if key in seen_lower:
                continue
            seen_lower.add(key)
            deduped.append(text)

        return deduped

    except Exception as err:
        print(f"[NER] Extraction failed: {err}")
        return []
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# -------------------------------
|
| 88 |
+
# Rule-Based Intent Classification
|
| 89 |
+
# -------------------------------
|
| 90 |
+
|
| 91 |
+
def classify_intent(question: str) -> str:
    """
    Classify the user's query into a high-level intent based on keywords.

    The question is lowercased and tested against the keyword patterns below
    in order; the first matching rule wins.

    Parameters:
        question (str): The user's question.

    Returns:
        str: One of ['description', 'before_using', 'proper_use',
        'precautions', 'side_effects']; 'description' when nothing matches.
    """
    q = question.lower()

    # Patterns must be lowercase because they are matched against `q`.
    # Stems such as "pregnan" have no trailing \b so that "pregnant" and
    # "pregnancy" match; plural forms ("side effects", "risks") are accepted.
    rules = (
        ("description", r"\bwhat is\b|\bused for\b|\bdefine\b"),
        ("before_using", r"\bbefore using\b|\bshould i tell\b|\bdoctor know\b"),
        ("proper_use", r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b"),
        ("precautions", r"\bprecaution|\bpregnan|\bbreastfeed|\brisks?\b"),
        ("side_effects", r"\bside effects?\b|\badverse\b|\bnausea\b|\bdizziness\b"),
    )

    for intent, pattern in rules:
        if re.search(pattern, q):
            return intent

    return "description"  # default fallback
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# -------------------------------
|
| 118 |
+
# Query Preprocessing Wrapper
|
| 119 |
+
# -------------------------------
|
| 120 |
+
|
| 121 |
+
def preprocess_query(raw_query: str) -> Tuple[str, List[str]]:
    """
    Classify the intent of a raw question and extract its named entities.

    Parameters:
        raw_query (str): The raw user question.

    Returns:
        Tuple[str, List[str]]: (intent, entities). The intent is a plain
        string, not a nested (intent, sub_intent) tuple. Entities is an empty
        list when none are found. ("", []) is returned if preprocessing raises.
    """
    try:
        intent = classify_intent(raw_query)
        entities = extract_entities_BERT(raw_query)

        if not entities:
            print("[NER fallback] No entities found. Using raw query.")
            return (intent or ""), []

        print(f"[Query Processed] Intent = {intent}| Entities = {entities}")
        return (intent or ""), entities

    except Exception as e:
        # Failures are logged and reported as an empty result.
        print(f"[Preprocessing failed] {e}")
        return "", []
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# -------------------------------
|
| 151 |
+
# Optional: Test Function
|
| 152 |
+
# -------------------------------
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
    # Manual smoke test: run sample questions through preprocess_query and
    # print the detected intent and entities for each.
    test_queries = [
        "What are the side effects of Azithromycin?",
        "How much dosage of aspirin should I take for headache?",
        "Can I take Lisinopril during pregnancy?",
        "What is Metformin used for?",
        "Are there interactions between Warfarin and Ibuprofen?",
        "How should I store insulin?",
    ]

    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print("\n" + heavy_rule)
    print("TESTING LIGHTWEIGHT TRANSFORMER NER")
    print(heavy_rule + "\n")

    for i, query in enumerate(test_queries, 1):
        print(f"[Test {i}] Query: {query}")
        print(light_rule)

        intent, entities = preprocess_query(query)

        print(f" Intent: {intent}")
        print(f" Entities: {entities if entities else 'None detected'}")
        print(light_rule + "\n")

    print(heavy_rule)
    print("TESTING COMPLETE")
    print(heavy_rule)
|
Scripts/Retrieval.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Retrieval and FAISS Embedding Module for Medical QA Chatbot
|
| 3 |
+
============================================================
|
| 4 |
+
|
| 5 |
+
This module handles:
|
| 6 |
+
1. Embedding documents
|
| 7 |
+
2. Building and saving FAISS index
|
| 8 |
+
3. Retrieval with initial FAISS search + reranking using BioBERT similarity
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import faiss
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from sentence_transformers import SentenceTransformer, util
|
| 16 |
+
from sklearn.preprocessing import normalize
|
| 17 |
+
from Query_processing import preprocess_query
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# -------------------------------
|
| 21 |
+
# File Paths
|
| 22 |
+
# -------------------------------
|
| 23 |
+
|
| 24 |
+
# Get the directory of the current script
|
| 25 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))

# Data/ sits next to the directory containing this script.
PROJECT_ROOT = os.path.dirname(script_dir)
DATA_FOLDER = os.path.join(PROJECT_ROOT, 'Data')

# Dataset and index artefact locations inside Data/.
csv_path, faiss_index_path, doc_metadata_path, doc_vectors_path = (
    os.path.join(DATA_FOLDER, filename)
    for filename in (
        'flattened_drug_dataset_cleaned.csv',
        'faiss_index.idx',
        'doc_metadata.pkl',
        'doc_vectors.npy',
    )
)

# Load the dataset, discarding rows that have no chunk text.
df = pd.read_csv(csv_path)
df = df.dropna(subset=['chunk_text'])

# -------------------------------
# Model Initialization
# -------------------------------

# Two sentence encoders are loaded at import time.
fast_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
biobert = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
|
| 46 |
+
|
| 47 |
+
# -------------------------------
|
| 48 |
+
# Function: Embed and Build FAISS Index
|
| 49 |
+
# -------------------------------
|
| 50 |
+
|
| 51 |
+
def Embed_and_FAISS():
    """
    Embed the drug dataset and build a FAISS inner-product index over it.

    Adds a 'full_text' column to the module-level ``df``, encodes it with
    ``fast_embedder``, L2-normalizes the vectors, and adds them to a
    ``faiss.IndexFlatIP``.

    Writes three files: the index (``faiss_index_path``), the metadata
    DataFrame (``doc_metadata_path``) and the normalized vectors
    (``doc_vectors_path``). Returns nothing.
    """
    print("Embedding document chunks using fast embedder...")

    # Build the text that is embedded for every chunk.
    df['full_text'] = df.apply(
        lambda x: f"{x['drug_name']} | {x['section']} > {x['subsection']} | {x['chunk_text']}",
        axis=1,
    )

    full_texts = df['full_text'].tolist()
    doc_embeddings = fast_embedder.encode(full_texts, convert_to_numpy=True, show_progress_bar=True)

    # Normalize embeddings and build index
    doc_embeddings = normalize(doc_embeddings, axis=1, norm='l2')
    dimension = doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(doc_embeddings)

    # Save index and metadata
    faiss.write_index(index, faiss_index_path)
    # Vectors are added in row order, and the retriever in this file looks
    # rows up with `.loc[I[0]]`. `df` is loaded with dropna(), which can leave
    # gaps in the index labels, so the metadata is saved with a fresh 0..n-1
    # index to keep labels equal to row positions.
    df.reset_index(drop=True).to_pickle(doc_metadata_path)
    np.save(doc_vectors_path, doc_embeddings)

    print("FAISS index built and saved successfully.")
|
| 76 |
+
|
| 77 |
+
# -------------------------------
|
| 78 |
+
# Function: Retrieve with Context and Averaged Embeddings
|
| 79 |
+
# -------------------------------
|
| 80 |
+
|
| 81 |
+
def retrieve_with_context_averagedembeddings(query, top_k=10, predicted_intent=None, detected_entities=None, alpha=0.8):
    """
    Retrieve top chunks using FAISS followed by reranking with BioBERT similarity.

    Pipeline:
        1. Embed the query with ``fast_embedder`` (optionally blended with the
           intent embedding) and L2-normalise it.
        2. Search the FAISS index stored at ``faiss_index_path``.
        3. Score the candidates against the query with ``biobert`` cosine
           similarity, add small boosts for intent/section, intent/subsection
           and entity matches, and order by the boosted score.

    Parameters:
        query (str): User query text.
        top_k (int): Number of top results to retrieve.
        predicted_intent (str | tuple, optional): Detected intent. A tuple is
            reduced to its first element.
        detected_entities (list, optional): List of named entities.
        alpha (float): Weight of the query embedding when blending with the
            intent embedding (the intent gets ``1 - alpha``).

    Returns:
        pd.DataFrame: One row per retrieved chunk with the columns
        ``chunk_id, drug_name, section, subsection, chunk_text, faiss_score,
        semantic_similarity_score``, sorted by ``semantic_similarity_score``
        in descending order. Empty (but with those columns) when FAISS
        returns no hit.
    """
    def _as_text(value):
        """Return the first element of a tuple cell; '' for non-string cells such as NaN."""
        if isinstance(value, tuple):
            value = value[0] if value else ""
        return value if isinstance(value, str) else ""

    result_columns = [
        'chunk_id', 'drug_name', 'section', 'subsection',
        'chunk_text', 'faiss_score', 'semantic_similarity_score',
    ]

    print(f"[Retrieval Pipeline Started] Query: {query}")

    # Unwrap a tuple intent once, *before* embedding, so the embedder receives
    # a plain string and the loop below does not repeat the check per row.
    if isinstance(predicted_intent, tuple):
        predicted_intent = predicted_intent[0] if predicted_intent else None
    intent_key = predicted_intent.strip().lower() if isinstance(predicted_intent, str) else ""

    # Embed and normalize the query
    query_vec = fast_embedder.encode([query], convert_to_numpy=True)
    if predicted_intent:
        intent_vec = fast_embedder.encode([predicted_intent], convert_to_numpy=True)
        query_vec = alpha * query_vec + (1 - alpha) * intent_vec
    # Normalise unconditionally: the document vectors are unit length, so the
    # inner-product score is only a cosine when the query is unit length too.
    query_vec = normalize(query_vec, axis=1)

    # Load FAISS index and search
    index = faiss.read_index(faiss_index_path)
    D, I = index.search(query_vec, top_k)

    # FAISS pads the id row with -1 when fewer than top_k vectors are available.
    valid = I[0] >= 0
    hit_positions = I[0][valid]
    hit_scores = D[0][valid]

    df_meta = pd.read_pickle(doc_metadata_path)
    # FAISS ids are insertion positions, so select rows by position (iloc),
    # not by index label (loc) - the labels may have gaps after dropna.
    retrieved_df = df_meta.iloc[hit_positions].copy()
    retrieved_df['faiss_score'] = hit_scores

    if retrieved_df.empty:
        return pd.DataFrame(columns=result_columns)

    # BioBERT reranking
    query_emb = biobert.encode(query, convert_to_tensor=True)
    chunk_embs = biobert.encode(retrieved_df['full_text'].tolist(), convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_emb, chunk_embs)[0]
    reranked_idx = torch.argsort(cos_scores, descending=True).tolist()

    # Lower-case the entities once instead of once per candidate row.
    entity_keys = [
        ent.lower() for ent in (detected_entities or [])
        if isinstance(ent, str) and ent
    ]

    # Boost scores based on intent, subsection match, or entity presence
    results = []
    for idx in reranked_idx:
        row = retrieved_df.iloc[idx]
        score = cos_scores[idx].item()

        section = _as_text(row['section']).strip().lower()
        subsection = _as_text(row['subsection']).strip().lower()
        chunk_lower = str(row['chunk_text']).lower()

        if intent_key and section == intent_key:
            score += 0.05
        if intent_key and intent_key in subsection:
            score += 0.03
        if any(ent in chunk_lower for ent in entity_keys):
            score += 0.1

        results.append({
            'chunk_id': row['chunk_id'],
            'drug_name': row['drug_name'],
            'section': row['section'],
            'subsection': row['subsection'],
            'chunk_text': row['chunk_text'],
            'faiss_score': row['faiss_score'],
            'semantic_similarity_score': score
        })

    # The boosts can change the ranking, so order by the final score. The sort
    # is stable: ties keep the raw cosine order established above.
    results.sort(key=lambda item: item['semantic_similarity_score'], reverse=True)

    return pd.DataFrame(results, columns=result_columns)
|
| 149 |
+
|
| 150 |
+
# -------------------------------
|
| 151 |
+
# Function: Retrieval Wrapper
|
| 152 |
+
# -------------------------------
|
| 153 |
+
|
| 154 |
+
def Retrieval_averagedQP(raw_query, intent, entities, top_k=10, alpha=0.8):
    """
    Run the retrieval pipeline for a raw user query and return the result table.

    Parameters:
        raw_query (str): The user query.
        intent (str): Predicted intent from query processing.
        entities (list): Detected biomedical entities.
        top_k (int): Number of top results to return.
        alpha (float): Weighting between query and intent embeddings.

    Returns:
        pd.DataFrame: The retrieved chunks restricted to the identifying,
        text and score columns listed below.
    """
    output_columns = [
        'chunk_id',
        'drug_name',
        'section',
        'subsection',
        'chunk_text',
        'faiss_score',
        'semantic_similarity_score',
    ]

    # Delegate all the work; this wrapper only fixes the column selection.
    retrieved = retrieve_with_context_averagedembeddings(
        raw_query,
        top_k=top_k,
        predicted_intent=intent,
        detected_entities=entities,
        alpha=alpha,
    )
    return retrieved[output_columns]
|
Scripts/app.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Medical Drug QA Chatbot - Gradio Interface
|
| 3 |
+
Optimized for Hugging Face Spaces Deployment
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
Medical Drug QA Chatbot - Gradio Interface
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
# This ensures the imports work correctly
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
sys.path.insert(0, current_dir)
|
| 17 |
+
|
| 18 |
+
from Query_processing import preprocess_query
|
| 19 |
+
from Retrieval import Retrieval_averagedQP
|
| 20 |
+
from Answer_Generation import answer_generation
|
| 21 |
+
|
| 22 |
+
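# NOTE: the three imports above run when this module is imported, so any module-level work in them (e.g. Retrieval loading its embedding models) happens at startup.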
|
| 23 |
+
|
| 24 |
+
# Cached references to the pipeline callables; filled in by initialize_models() on first use
|
| 25 |
+
_query_processor = None
|
| 26 |
+
_retrieval_system = None
|
| 27 |
+
_answer_generator = None
|
| 28 |
+
|
| 29 |
+
def initialize_models():
    """
    Populate the module-level pipeline callables on first use and return them.

    Each of ``_query_processor``, ``_retrieval_system`` and
    ``_answer_generator`` is filled in only while it is still ``None``, so
    repeated calls reuse the references stored by the first call.

    Returns:
        tuple: ``(preprocess_query, Retrieval_averagedQP, answer_generation)``.
    """
    global _query_processor, _retrieval_system, _answer_generator

    # Stage 1: query preprocessing
    if _query_processor is None:
        print("[App] Loading query processor...")
        from Query_processing import preprocess_query as query_fn
        _query_processor = query_fn

    # Stage 2: chunk retrieval
    if _retrieval_system is None:
        print("[App] Loading retrieval system...")
        from Retrieval import Retrieval_averagedQP as retrieval_fn
        _retrieval_system = retrieval_fn

    # Stage 3: answer generation
    if _answer_generator is None:
        print("[App] Loading answer generator...")
        from Answer_Generation import answer_generation as answer_fn
        _answer_generator = answer_fn

    return _query_processor, _retrieval_system, _answer_generator
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def chat_agent(message: str, history: list) -> tuple:
    """
    Answer one user message and append the exchange to the chat history.

    The message goes through query preprocessing, retrieval and answer
    generation (the callables returned by ``initialize_models``). On success
    three entries are appended to ``history``: the user message, the
    generated answer, and a collapsible block with the top three source
    chunks. When retrieval returns nothing, or any step raises, the user
    message plus an explanatory assistant message are appended instead.

    Parameters:
        message (str): User's question.
        history (list): Chat history as ``{"role": ..., "content": ...}``
            dicts; it is extended in place. ``None`` is treated as an empty
            history.

    Returns:
        tuple: ``("", history)`` - the empty string clears the input box.
    """
    # Guard against a missing history so the appends below - including the
    # ones in the error handler - cannot fail on None.
    if history is None:
        history = []

    # Nothing to do for empty or whitespace-only input.
    if not message or not message.strip():
        return "", history

    try:
        # Initialize models (local names avoid shadowing the module-level imports)
        preprocess, retrieve, generate = initialize_models()

        # Step 1: Query Processing
        print(f"[Chat] Processing query: {message}")
        intent, entities = preprocess(message)

        # Step 2: Retrieval
        print("[Chat] Retrieving relevant chunks...")
        chunks = retrieve(message, intent, entities, top_k=10, alpha=0.8)

        if chunks.empty:
            error_msg = "⚠️ Sorry, I couldn't find relevant information in the database. Please try rephrasing your question."
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": error_msg})
            return "", history

        # Step 3: Answer Generation
        print("[Chat] Generating answer...")
        answer = generate(message, chunks, top_k=3)

        # Format context for display: header, first 200 characters, score
        context = "\n\n".join(
            f"**{row['drug_name']} | {row['section']} > {row['subsection']}**\n"
            f"{row['chunk_text'][:200]}{'...' if len(row['chunk_text']) > 200 else ''}\n"
            f"*Relevance Score: {round(row['semantic_similarity_score'], 3)}*"
            for _, row in chunks.head(3).iterrows()
        )

        # Add to history
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": answer})
        history.append({
            "role": "assistant",
            "content": f"<details><summary>📚 View Source Chunks</summary>\n\n{context}\n\n</details>"
        })

        print("[Chat] ✓ Response generated successfully")
        return "", history

    except Exception as e:
        # Top-level UI boundary: log the full traceback and show the user a
        # message instead of letting the exception escape the handler.
        print(f"[Chat] ERROR: {e}")
        import traceback
        traceback.print_exc()

        error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try again or rephrase your question."
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": error_msg})
        return "", history
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Build Gradio Interface
|
| 118 |
+
# Stylesheet for the info containers, collapsible sections and the disclaimer box.
_CUSTOM_CSS = """
    .info-container, .info-footer {
        width: 90%;
        max-width: 1000px;
        margin: 0 auto;
    }
    details.info-section, details.about-section {
        background: white;
        border-radius: 12px;
        box-shadow: 0 2px 8px rgba(0,0,0,0.1);
        margin: 1em 0;
        padding: 0;
    }
    details > summary {
        padding: 1em 1.5em;
        font-size: 1.1em;
        font-weight: bold;
        color: #00838f;
        cursor: pointer;
        border-radius: 12px;
        transition: background-color 0.3s ease;
    }
    details > summary:hover {
        background-color: #e0f7fa;
    }
    .disclaimer {
        background: #fff3cd;
        border: 1px solid #ffc107;
        border-radius: 8px;
        padding: 1em;
        margin: 1em 0;
    }
    """


def _reset_chat():
    """Return the values that blank the input box and empty the chat window."""
    return None, []


# Build Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="cyan"),
    title="Medical Drug QA Chatbot",
    css=_CUSTOM_CSS
) as demo:

    # Page title and tagline
    gr.Markdown("# 💊 Medical Drug QA Chatbot")
    gr.Markdown("_Ask questions about medications and get reliable answers from trusted medical sources._")

    # Collapsed usage instructions
    with gr.Accordion("🤔 How to Use", open=False):
        gr.Markdown("""
        Simply type your question about any medication. You can ask about:
        - **Side effects** and warnings
        - **Dosage** and usage instructions
        - **Drug interactions**
        - **Storage** and handling
        - **Precautions** for specific conditions

        ### 💡 Example Questions:
        - "What are the common side effects of Aspirin?"
        - "How should I store Insulin?"
        - "What precautions should I take with Lisinopril?"
        - "Can I drink alcohol while taking Metformin?"
        """)

    # Conversation window (role/content message dicts)
    chatbot = gr.Chatbot(
        type="messages",
        height=500,
        label="Chat",
        show_label=False,
        avatar_images=(None, "🤖")
    )

    # Question box with its send button
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Ask your medical question here...",
            scale=9,
            container=False,
            show_label=False
        )
        submit = gr.Button("Send", scale=1, variant="primary")

    with gr.Row():
        clear = gr.Button("🗑️ Clear Chat", scale=1)

    # Event handlers: pressing Enter and clicking "Send" both run chat_agent
    for trigger in (msg.submit, submit.click):
        trigger(
            fn=chat_agent,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )

    clear.click(
        fn=_reset_chat,
        inputs=None,
        outputs=[msg, chatbot],
    )

    # Collapsed project description
    with gr.Accordion("📚 About This Project", open=False):
        gr.Markdown("""
        This Medical Drug QA system uses advanced NLP technologies:

        - **Data Source**: Mayo Clinic's comprehensive drug database
        - **NER**: BioBERT for chemical/drug entity recognition
        - **Retrieval**: Hybrid system with MiniLM-V6 + BioBERT reranking
        - **Answer Generation**: Llama-4 via Groq API

        **Technologies**: Transformers, FAISS, Sentence-BERT, Gradio
        """)

    # Medical disclaimer, styled by the .disclaimer CSS rule
    gr.HTML("""
    <div class="disclaimer">
    <strong>⚠️ Medical Disclaimer</strong>: This chatbot provides educational information only.
    It should NOT be used as a substitute for professional medical advice, diagnosis, or treatment.
    Always consult a qualified healthcare provider for medical decisions.
    </div>
    """)

# Launch only when executed as a script
if __name__ == "__main__":
    demo.queue()  # Enable queuing for better performance
    demo.launch(
        share=False,  # Set to False for HF Spaces
        show_error=True
    )
|
Scripts/demo.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Command-line demo of the retrieval-based medical QA chatbot.

Runs one hard-coded question through the three pipeline stages:

1. Query preprocessing   (intent + entities)
2. Information retrieval (top chunks)
3. Answer generation
"""

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from Query_processing import preprocess_query
from Retrieval import Embed_and_FAISS, Retrieval_averagedQP
from Answer_Generation import answer_generation

# --- Optional: rebuild the FAISS index ---------------------------------
# Uncomment the call below to regenerate the embeddings and the FAISS index
# when they have not been built yet.
# Embed_and_FAISS()

# --- The question to answer --------------------------------------------
Question = "how much dosage of ibuprofen should be taken for treatment of fever?"

# --- Step 1: query preprocessing ---------------------------------------
intent, entities = preprocess_query(Question)

# --- Step 2: retrieve relevant chunks ----------------------------------
top_chunks = Retrieval_averagedQP(
    Question,
    intent,
    entities,
    top_k=10,
    alpha=0.8,
)

# --- Step 3: answer generation -----------------------------------------
Generated_answer = answer_generation(Question, top_chunks, top_k=3)

# --- Show the result ----------------------------------------------------
print("Generated Answer:", Generated_answer)
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web Framework
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
|
| 4 |
+
# Data Processing
|
| 5 |
+
pandas>=2.0.0
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
|
| 8 |
+
# NLP & ML
|
| 9 |
+
torch>=2.0.0
|
| 10 |
+
transformers>=4.35.0
|
| 11 |
+
sentence-transformers>=2.2.0
|
| 12 |
+
scikit-learn>=1.3.0
|
| 13 |
+
|
| 14 |
+
# Vector Search
|
| 15 |
+
faiss-cpu>=1.7.4
|
| 16 |
+
|
| 17 |
+
# API Client
|
| 18 |
+
openai>=1.0.0
|
| 19 |
+
|
| 20 |
+
# Optional Performance
|
| 21 |
+
accelerate>=0.24.0
|
| 22 |
+
sentencepiece>=0.1.99
|