Niranjan Sathish committed on
Commit 40d2f99 · 1 Parent(s): a337894

Initial Commit
.gitattributes CHANGED
@@ -1,35 +1,9 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
  *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.idx filter=lfs diff=lfs merge=lfs -text
+ *.csv filter=lfs diff=lfs merge=lfs -text
+ Data/* filter=lfs diff=lfs merge=lfs -text
+ Data/*.pkl filter=lfs diff=lfs merge=lfs -text
+ Data/*.npy filter=lfs diff=lfs merge=lfs -text
+ Data/*.idx filter=lfs diff=lfs merge=lfs -text
+ Data/*.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,50 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ venv/
+ ENV/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Environment
+ .env
+ Chatbot.venv/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Data (don't commit large files)
+ *.pkl
+ *.npy
+ *.idx
+ !Data/*.pkl
+ !Data/*.npy
+ !Data/*.idx
+
+ # Model cache
+ .cache/
+ model_cache/
Data/Dataset.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc38f7e5bfad6d7c2865ed7c94d483c8b9b887a47853e4a3c16ce957ce1f06a0
+ size 35120734
Data/doc_metadata.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:800157a95b50080634fdce730014af49a8e0cf01d2dbb484785b15936dc9abff
+ size 53368209
Data/doc_vectors.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f54da3cd890cf384fdc3b7abcd6ed5f840c0f53da30615fd417fc8256fd1b5ca
+ size 70190720
Data/faiss_index.idx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:58d68a5ccb27c94e357ab12eec21d5d54d903949ae37648202643eb33387156b
+ size 70190637
Data/flattened_drug_dataset_cleaned.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0669d5d7366973a342a3cc35321366a02837c66ac5e7c28c3bf0569897db5b84
+ size 31338099
Evaluation/Evaluation_metrics_score.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Evaluation Script for Retrieval-based QA Chatbot
+ =================================================
+
+ This module handles:
+ 1. Loading evaluation questions and expected chunk IDs
+ 2. Preprocessing queries and retrieving top chunks
+ 3. Calculating Precision@3, Recall@3, F1-Score@3, and Success Rate@3
+ """
+
+ import os
+ import sys
+ import pandas as pd
+
+ # Make the modules in Scripts/ importable when this file is run from Evaluation/
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Scripts'))
+
+ from Query_processing import preprocess_query
+ from Retrieval import Retrieval_averagedQP
+
+ # -------------------------------
+ # File Paths
+ # -------------------------------
+
+ # Get the directory of the current script
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # Path to evaluation dataset
+ csv_path = os.path.join(script_dir, 'custom_drug_eval_set_id.csv')
+
+ # -------------------------------
+ # Load Evaluation Dataset
+ # -------------------------------
+
+ df = pd.read_csv(csv_path)
+
+ # -------------------------------
+ # Evaluation Storage
+ # -------------------------------
+
+ all_precisions = []
+ all_recalls = []
+ all_f1s = []
+ all_successes = []
+
+ # -------------------------------
+ # Evaluation Loop
+ # -------------------------------
+
+ for _, row in df.iterrows():
+     question = row['question']
+     expected_ids = set(map(int, filter(None, str(row['relevant_chunk']).split(';'))))
+
+     print(f"\n[Evaluation] Question: {question}")
+     print(f"[Expected Chunk IDs] {expected_ids}")
+
+     # Preprocess the query
+     intent, entities = preprocess_query(question)
+
+     # Retrieve top-k chunk predictions
+     retrieved_df = Retrieval_averagedQP(question, intent, entities, top_k=10, alpha=0.8)
+     retrieved_df = retrieved_df.head(3)  # Limit to top 3 results
+     retrieved_ids = set(retrieved_df['chunk_id'].astype(int).tolist())
+
+     print(f"[Retrieved Chunk IDs] {retrieved_ids}")
+
+     # Evaluation Metrics Calculation
+     tp = len(retrieved_ids & expected_ids)
+     fp = len(retrieved_ids - expected_ids)
+     fn = len(expected_ids - retrieved_ids)
+
+     print(f"[Metrics] TP: {tp}, FP: {fp}, FN: {fn}")
+
+     success = 1 if tp > 0 else 0
+     precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+     recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+     f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+     all_precisions.append(precision)
+     all_recalls.append(recall)
+     all_f1s.append(f1)
+     all_successes.append(success)
+
+ # -------------------------------
+ # Aggregate Results
+ # -------------------------------
+
+ mean_precision = sum(all_precisions) / len(all_precisions)
+ mean_recall = sum(all_recalls) / len(all_recalls)
+ mean_f1 = sum(all_f1s) / len(all_f1s)
+ mean_success = sum(all_successes) / len(all_successes)
+
+ # -------------------------------
+ # Display Final Metrics
+ # -------------------------------
+
+ print("\n========= Final Evaluation Metrics =========")
+ print(f"Success Rate@3: {mean_success:.4f}")
+ print(f"Precision@3: {mean_precision:.4f}")
+ print(f"Recall@3: {mean_recall:.4f}")
+ print(f"F1 Score@3: {mean_f1:.4f}")
Evaluation/custom_drug_eval_set_id.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a32b1282d7fd5e6d55b73499ee314410cffa69b456a7372983225a71da6b5674
+ size 4001
README.md CHANGED
@@ -1,12 +1,76 @@
  ---
- title: DrugBot Retrieval Based
- emoji: 🐠
- colorFrom: gray
- colorTo: green
+ title: Medical Drug QA Chatbot
+ emoji: 💊
+ colorFrom: blue
+ colorTo: indigo
  sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
+ sdk_version: 4.0.0
+ app_file: Scripts/app.py
  pinned: false
+ license: mit
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 💊 Medical Drug QA Chatbot
+
+ An intelligent chatbot that answers questions about medications using advanced NLP techniques.
+
+ ## Features
+
+ - 🔍 **Smart Query Processing**: BioBERT-based NER for drug entity extraction
+ - 📚 **Hybrid Retrieval**: FAISS + BioBERT semantic reranking
+ - 🤖 **AI-Powered Answers**: Groq Llama-4 for natural language generation
+ - 💾 **Comprehensive Database**: Mayo Clinic drug information
+
+ ## Usage
+
+ Simply ask questions about:
+ - Side effects and warnings
+ - Dosage and usage instructions
+ - Drug interactions
+ - Storage guidelines
+ - Precautions for specific conditions
+
+ ## Example Questions
+
+ - "What are the side effects of Aspirin?"
+ - "How should I store Insulin?"
+ - "What precautions should I take with Lisinopril?"
+ - "Can I take Metformin with alcohol?"
+
+ ## Tech Stack
+
+ - **Frontend**: Gradio
+ - **NER**: BioBERT (alvaroalon2/biobert_chemical_ner)
+ - **Embeddings**: MiniLM-V6, BioBERT
+ - **Vector DB**: FAISS
+ - **LLM**: Llama-4 via Groq API
+
+ ## ⚠️ Disclaimer
+
+ This chatbot provides educational information only. Always consult healthcare professionals for medical advice.
+
+ ## Setup
+
+ 1. Clone the repository
+ 2. Install dependencies: `pip install -r requirements.txt`
+ 3. Set the `GROQ_API_KEY` environment variable
+ 4. Build the FAISS index: `python Scripts/Retrieval.py`
+ 5. Run: `python Scripts/app.py`
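+
+ ## Quick Start (Python)
+
+ A minimal sketch of the full pipeline (it mirrors `Scripts/demo.py`), assuming the FAISS index in `Data/` is already built and `GROQ_API_KEY` is set:
+
+ ```python
+ import sys
+ sys.path.insert(0, "Scripts")  # make the project modules importable
+
+ from Query_processing import preprocess_query
+ from Retrieval import Retrieval_averagedQP
+ from Answer_Generation import answer_generation
+
+ question = "What are the side effects of Aspirin?"
+ intent, entities = preprocess_query(question)
+ chunks = Retrieval_averagedQP(question, intent, entities, top_k=10, alpha=0.8)
+ print(answer_generation(question, chunks, top_k=3))
+ ```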
Scripts/Answer_Generation.py ADDED
@@ -0,0 +1,140 @@
+ """
+ Answer Generation Module for Retrieval-based Medical QA Chatbot
+ =================================================================
+ This module handles answer generation using Groq API with proper error handling.
+ """
+
+ import os
+ from openai import OpenAI
+
+ # Get API key from environment
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
+
+ if GROQ_API_KEY is None:
+     print("[Warning] GROQ_API_KEY not set!")
+     client = None
+ else:
+     client = OpenAI(
+         api_key=GROQ_API_KEY,
+         base_url="https://api.groq.com/openai/v1"
+     )
+
+ # -------------------------------
+ # Function: Query Groq API
+ # -------------------------------
+
+ def query_groq(prompt, model="meta-llama/llama-4-scout-17b-16e-instruct", max_tokens=300):
+     """
+     Sends a prompt to Groq API and returns the generated response.
+
+     Parameters:
+     prompt (str): The text prompt for the model.
+     model (str): Model name deployed on Groq API.
+     max_tokens (int): Maximum tokens allowed in the output.
+
+     Returns:
+     str: Model-generated response text.
+     """
+     if client is None:
+         return "⚠️ Error: API key not configured. Please contact the administrator."
+
+     try:
+         response = client.chat.completions.create(
+             model=model,
+             messages=[
+                 {"role": "system", "content": "You are a helpful biomedical assistant providing accurate drug information."},
+                 {"role": "user", "content": prompt}
+             ],
+             temperature=0.7,
+             max_tokens=max_tokens
+         )
+         return response.choices[0].message.content.strip()
+     except Exception as e:
+         print(f"[Answer Generation] Error calling Groq API: {e}")
+         return f"⚠️ Error generating answer: {str(e)}"
+
+ # -------------------------------
+ # Function: Build Prompt
+ # -------------------------------
+
+ def build_prompt(question, context):
+     """
+     Constructs a prompt for the model combining the user question and retrieved context.
+
+     Parameters:
+     question (str): User's question.
+     context (str): Retrieved relevant text chunks.
+
+     Returns:
+     str: Complete prompt text.
+     """
+     return f"""Based strictly on the following medical information, answer the question clearly and concisely.
+
+ Question: {question}
+
+ Context:
+ {context}
+
+ Instructions:
+ - Provide a direct, accurate answer based only on the context
+ - Use clear, simple language
+ - If the context doesn't contain enough information, say so
+ - Do not add information not present in the context
+ """
+
+ # -------------------------------
+ # Function: Answer Generation
+ # -------------------------------
+
+ def answer_generation(question, top_chunks, top_k=3):
+     """
+     Generates an answer based on retrieved top chunks.
+
+     Parameters:
+     question (str): User's question.
+     top_chunks (DataFrame): Retrieved top chunks with context.
+     top_k (int): Number of top chunks to use for answer generation.
+
+     Returns:
+     str: Final generated answer.
+     """
+     try:
+         # Select top-k chunks
+         top_chunks = top_chunks.head(top_k)
+         print(f"[Answer Generation] Using top {len(top_chunks)} chunks")
+
+         if top_chunks.empty:
+             return "⚠️ No relevant information found. Please try rephrasing your question."
+
+         # Join context
+         context = "\n\n".join([
+             f"Drug: {row['drug_name']}\n"
+             f"Section: {row['section']}\n"
+             f"Info: {row['chunk_text']}"
+             for _, row in top_chunks.iterrows()
+         ])
+
+         # Build prompt and query Groq
+         prompt = build_prompt(question, context)
+         answer = query_groq(prompt)
+
+         return answer
+
+     except Exception as e:
+         print(f"[Answer Generation] Error: {e}")
+         return f"⚠️ Error generating answer: {str(e)}"
Scripts/Query_processing.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Query Processing Pipeline for Retrieval-based QA Chatbot
+ ========================================================
+
+ This module handles:
+ 1. Query preprocessing
+ 2. Intent classification
+ 3. Named Entity Recognition (NER) using lightweight BioBERT
+
+ Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals)
+ """
+
+ import re
+ from typing import List, Tuple
+ from transformers import pipeline
+ import torch
+
+ # -------------------------------
+ # Initialize Lightweight NER Model
+ # -------------------------------
+
+ print("[NER] Loading lightweight BioBERT NER model...")
+
+ try:
+     # This model is specifically trained for chemical/drug entity recognition
+     ner_model = pipeline(
+         "ner",
+         model="alvaroalon2/biobert_chemical_ner",
+         aggregation_strategy="simple",
+         device=0 if torch.cuda.is_available() else -1
+     )
+     print("[NER] ✓ Model loaded successfully\n")
+ except Exception as e:
+     print(f"[NER] ✗ Failed to load model: {e}")
+     ner_model = None
+
+ # -------------------------------
+ # Named Entity Extraction
+ # -------------------------------
+
+ def extract_entities_BERT(question: str) -> List[str]:
+     """
+     Extract biomedical entities using lightweight BioBERT NER.
+
+     Parameters:
+     question (str): User query
+
+     Returns:
+     List[str]: Extracted entities (drugs, chemicals, etc.)
+     """
+     if ner_model is None:
+         print("[NER] Model not available, returning empty list")
+         return []
+
+     try:
+         # Run NER pipeline
+         entities = ner_model(question)
+
+         # Filter and clean entities
+         extracted = []
+         for ent in entities:
+             # Only keep high-confidence entities (>70%)
+             if ent['score'] > 0.7:
+                 # Clean up subword tokens (remove ##)
+                 entity_text = ent['word'].replace('##', '').strip()
+
+                 # Filter out very short entities and common words
+                 if len(entity_text) > 2 and entity_text.lower() not in ['the', 'and', 'for', 'with']:
+                     extracted.append(entity_text)
+
+         # Remove duplicates while preserving order
+         unique_entities = []
+         seen = set()
+         for ent in extracted:
+             ent_lower = ent.lower()
+             if ent_lower not in seen:
+                 seen.add(ent_lower)
+                 unique_entities.append(ent)
+
+         return unique_entities
+
+     except Exception as e:
+         print(f"[NER] Extraction failed: {e}")
+         return []
+
+
+ # -------------------------------
+ # Rule-Based Intent Classification
+ # -------------------------------
+
+ def classify_intent(question: str) -> str:
+     """
+     Classify the user's query into a high-level intent based on keywords.
+
+     Parameters:
+     question (str): The user's question.
+
+     Returns:
+     str: One of ['description', 'before_using', 'proper_use', 'precautions', 'side_effects']
+     """
+     q = question.lower()
+
+     if re.search(r"\bwhat is\b|\bused for\b|\bdefine\b", q):
+         return "description"
+     elif re.search(r"\bbefore using\b|\bshould I tell\b|\bdoctor know\b", q):
+         return "before_using"
+     elif re.search(r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b", q):
+         return "proper_use"
+     elif re.search(r"\bprecaution\b|\bpregnan\b|\bbreastfeed\b|\brisk\b", q):
+         return "precautions"
+     elif re.search(r"\bside effect\b|\badverse\b|\bnausea\b|\bdizziness\b", q):
+         return "side_effects"
+     else:
+         return "description"  # default fallback
+
+
+ # -------------------------------
+ # Query Preprocessing Wrapper
+ # -------------------------------
+
+ def preprocess_query(raw_query: str) -> Tuple[str, List[str]]:
+     """
+     Main preprocessing function that extracts:
+     - Intent (rule-based keyword classification)
+     - Named entities (BioBERT NER)
+
+     Parameters:
+     raw_query (str): The raw user question.
+
+     Returns:
+     Tuple[str, List[str]]: (intent, list of entities)
+     """
+     try:
+         intent = classify_intent(raw_query)
+         entities = extract_entities_BERT(raw_query)
+
+         if not entities:
+             print("[NER fallback] No entities found. Using raw query.")
+             return (intent or ""), []
+
+         print(f"[Query Processed] Intent = {intent} | Entities = {entities}")
+         return (intent or ""), entities
+
+     except Exception as e:
+         print(f"[Preprocessing failed] {e}")
+         return "", []
+
+
+ # -------------------------------
+ # Optional: Test Function
+ # -------------------------------
+
+ if __name__ == "__main__":
+     """Test the NER with sample queries."""
+
+     test_queries = [
+         "What are the side effects of Azithromycin?",
+         "How much dosage of aspirin should I take for headache?",
+         "Can I take Lisinopril during pregnancy?",
+         "What is Metformin used for?",
+         "Are there interactions between Warfarin and Ibuprofen?",
+         "How should I store insulin?",
+     ]
+
+     print("\n" + "="*70)
+     print("TESTING LIGHTWEIGHT TRANSFORMER NER")
+     print("="*70 + "\n")
+
+     for i, query in enumerate(test_queries, 1):
+         print(f"[Test {i}] Query: {query}")
+         print("-" * 70)
+
+         intent, entities = preprocess_query(query)
+
+         print(f" Intent: {intent}")
+         print(f" Entities: {entities if entities else 'None detected'}")
+         print("-" * 70 + "\n")
+
+     print("="*70)
+     print("TESTING COMPLETE")
+     print("="*70)
Scripts/Retrieval.py ADDED
@@ -0,0 +1,190 @@
+ """
+ Retrieval and FAISS Embedding Module for Medical QA Chatbot
+ ============================================================
+
+ This module handles:
+ 1. Embedding documents
+ 2. Building and saving FAISS index
+ 3. Retrieval with initial FAISS search + reranking using BioBERT similarity
+ """
+
+ import faiss
+ import pandas as pd
+ import numpy as np
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+ from sklearn.preprocessing import normalize
+ from Query_processing import preprocess_query
+ import os
+
+ # -------------------------------
+ # File Paths
+ # -------------------------------
+
+ # Get the directory of the current script
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # Go up one level to project root, then into Data folder
+ PROJECT_ROOT = os.path.dirname(script_dir)  # Go up from Scripts/ to project root
+ DATA_FOLDER = os.path.join(PROJECT_ROOT, 'Data')
+
+ # Define all paths
+ csv_path = os.path.join(DATA_FOLDER, 'flattened_drug_dataset_cleaned.csv')
+ faiss_index_path = os.path.join(DATA_FOLDER, 'faiss_index.idx')
+ doc_metadata_path = os.path.join(DATA_FOLDER, 'doc_metadata.pkl')
+ doc_vectors_path = os.path.join(DATA_FOLDER, 'doc_vectors.npy')
+
+ # Load the dataset
+ df = pd.read_csv(csv_path).dropna(subset=['chunk_text'])
+
+ # -------------------------------
+ # Model Initialization
+ # -------------------------------
+
+ fast_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+ biobert = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
+
+ # -------------------------------
+ # Function: Embed and Build FAISS Index
+ # -------------------------------
+
+ def Embed_and_FAISS():
+     """
+     Embeds the drug dataset and builds a FAISS index for fast retrieval.
+     Saves the index, metadata, and document vectors to disk.
+     """
+     print("Embedding document chunks using fast embedder...")
+
+     # Build full context strings
+     df['full_text'] = df.apply(lambda x: f"{x['drug_name']} | {x['section']} > {x['subsection']} | {x['chunk_text']}", axis=1)
+
+     full_texts = df['full_text'].tolist()
+     doc_embeddings = fast_embedder.encode(full_texts, convert_to_numpy=True, show_progress_bar=True)
+
+     # Normalize embeddings and build index
+     doc_embeddings = normalize(doc_embeddings, axis=1, norm='l2')
+     dimension = doc_embeddings.shape[1]
+     index = faiss.IndexFlatIP(dimension)
+     index.add(doc_embeddings)
+
+     # Save index and metadata
+     faiss.write_index(index, faiss_index_path)
+     df.to_pickle(doc_metadata_path)
+     np.save(doc_vectors_path, doc_embeddings)
+
+     print("FAISS index built and saved successfully.")
+
+ # -------------------------------
+ # Function: Retrieve with Context and Averaged Embeddings
+ # -------------------------------
+
+ def retrieve_with_context_averagedembeddings(query, top_k=10, predicted_intent=None, detected_entities=None, alpha=0.8):
+     """
+     Retrieve top chunks using FAISS followed by reranking with BioBERT similarity.
+
+     Parameters:
+     query (str): User query text.
+     top_k (int): Number of top results to retrieve.
+     predicted_intent (str, optional): Detected intent to adjust retrieval.
+     detected_entities (list, optional): List of named entities.
+     alpha (float): Weight for combining query and intent embeddings.
+
+     Returns:
+     pd.DataFrame: Retrieved chunks with metadata and reranked scores.
+     """
+     print(f"[Retrieval Pipeline Started] Query: {query}")
+
+     # Embed and normalize the query (the index stores L2-normalized vectors,
+     # so the query is normalized too, keeping inner products on the cosine scale)
+     query_vec = fast_embedder.encode([query], convert_to_numpy=True)
+     query_vec = normalize(query_vec, axis=1)
+
+     if predicted_intent:
+         intent_vec = fast_embedder.encode([predicted_intent], convert_to_numpy=True)
+         query_vec = normalize((alpha * query_vec + (1 - alpha) * intent_vec), axis=1)
+
+     # Load FAISS index and search
+     index = faiss.read_index(faiss_index_path)
+     D, I = index.search(query_vec, top_k)
+
+     df_meta = pd.read_pickle(doc_metadata_path)
+     retrieved_df = df_meta.iloc[I[0]].copy()  # FAISS ids are positional, so use iloc
+     retrieved_df['faiss_score'] = D[0]
+
+     # BioBERT reranking
+     query_emb = biobert.encode(query, convert_to_tensor=True)
+     chunk_embs = biobert.encode(retrieved_df['full_text'].tolist(), convert_to_tensor=True)
+     cos_scores = util.pytorch_cos_sim(query_emb, chunk_embs)[0]
+     reranked_idx = torch.argsort(cos_scores, descending=True)
+
+     # Boost scores based on intent, subsection match, or entity presence
+     results = []
+     for idx in reranked_idx:
+         idx = int(idx)
+         row = retrieved_df.iloc[idx]
+         score = cos_scores[idx].item()
+
+         section = row['section'][0] if isinstance(row['section'], tuple) else row['section']
+         subsection = row['subsection'][0] if isinstance(row['subsection'], tuple) else row['subsection']
+         if isinstance(predicted_intent, tuple):
+             predicted_intent = predicted_intent[0]
+
+         if predicted_intent and section.strip().lower() == predicted_intent.strip().lower():
+             score += 0.05
+         if predicted_intent and predicted_intent.lower() in subsection.strip().lower():
+             score += 0.03
+         if detected_entities:
+             if any(ent.lower() in row['chunk_text'].lower() for ent in detected_entities):
+                 score += 0.1
+
+         results.append({
+             'chunk_id': row['chunk_id'],
+             'drug_name': row['drug_name'],
+             'section': row['section'],
+             'subsection': row['subsection'],
+             'chunk_text': row['chunk_text'],
+             'faiss_score': row['faiss_score'],
+             'semantic_similarity_score': score
+         })
+
+     return pd.DataFrame(results)
+
+ # -------------------------------
+ # Function: Retrieval Wrapper
+ # -------------------------------
+
+ def Retrieval_averagedQP(raw_query, intent, entities, top_k=10, alpha=0.8):
+     """
+     Wrapper to retrieve top-k chunks given a raw user query.
+
+     Parameters:
+     raw_query (str): The user query.
+     intent (str): Predicted intent from query processing.
+     entities (list): Detected biomedical entities.
+     top_k (int): Number of top results to return.
+     alpha (float): Weighting between query and intent embeddings.
+
+     Returns:
+     pd.DataFrame: Top retrieved chunks with scores.
+     """
+     results_df = retrieve_with_context_averagedembeddings(
+         raw_query,
+         top_k=top_k,
+         predicted_intent=intent,
+         detected_entities=entities,
+         alpha=alpha
+     )
+     return results_df[['chunk_id', 'drug_name', 'section', 'subsection', 'chunk_text', 'faiss_score', 'semantic_similarity_score']]
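+
+ # -------------------------------
+ # Optional: Build Index / Usage Sketch
+ # -------------------------------
+
+ if __name__ == "__main__":
+     # Build the index on first run (README step 4), then run a small retrieval
+     # smoke test; assumes the cleaned CSV is present in Data/.
+     if not os.path.exists(faiss_index_path):
+         Embed_and_FAISS()
+     q = "What are the side effects of Aspirin?"
+     intent, entities = preprocess_query(q)
+     print(Retrieval_averagedQP(q, intent, entities, top_k=5, alpha=0.8).head(3))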
Scripts/app.py ADDED
@@ -0,0 +1,236 @@
+ """
+ Medical Drug QA Chatbot - Gradio Interface
+ Optimized for Hugging Face Spaces Deployment
+ """
+
+ import gradio as gr
+ import os
+ import sys
+
+ # Make the sibling modules in Scripts/ importable
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ sys.path.insert(0, current_dir)
+
+ # Lazy imports - only load when needed
+ _query_processor = None
+ _retrieval_system = None
+ _answer_generator = None
+
+ def initialize_models():
+     """Lazy loading of models to speed up startup."""
+     global _query_processor, _retrieval_system, _answer_generator
+
+     if _query_processor is None:
+         print("[App] Loading query processor...")
+         from Query_processing import preprocess_query
+         _query_processor = preprocess_query
+
+     if _retrieval_system is None:
+         print("[App] Loading retrieval system...")
+         from Retrieval import Retrieval_averagedQP
+         _retrieval_system = Retrieval_averagedQP
+
+     if _answer_generator is None:
+         print("[App] Loading answer generator...")
+         from Answer_Generation import answer_generation
+         _answer_generator = answer_generation
+
+     return _query_processor, _retrieval_system, _answer_generator
+
+
+ def chat_agent(message: str, history: list) -> tuple:
+     """
+     Main chat function with error handling and loading states.
+
+     Parameters:
+     message (str): User's question
+     history (list): Chat history
+
+     Returns:
+     tuple: (empty string, updated history)
+     """
+     if not message or message.strip() == "":
+         return "", history
+
+     try:
+         # Initialize models
+         preprocess_query, Retrieval_averagedQP, answer_generation = initialize_models()
+
+         # Step 1: Query Processing
+         print(f"[Chat] Processing query: {message}")
+         intent, entities = preprocess_query(message)
+
+         # Step 2: Retrieval
+         print(f"[Chat] Retrieving relevant chunks...")
+         chunks = Retrieval_averagedQP(message, intent, entities, top_k=10, alpha=0.8)
+
+         if chunks.empty:
+             error_msg = "⚠️ Sorry, I couldn't find relevant information in the database. Please try rephrasing your question."
+             history.append({"role": "user", "content": message})
+             history.append({"role": "assistant", "content": error_msg})
+             return "", history
+
+         # Step 3: Answer Generation
+         print(f"[Chat] Generating answer...")
+         answer = answer_generation(message, chunks, top_k=3)
+
+         # Format context for display
+         context = "\n\n".join([
+             f"**{row['drug_name']} | {row['section']} > {row['subsection']}**\n"
+             f"{row['chunk_text'][:200]}{'...' if len(row['chunk_text']) > 200 else ''}\n"
+             f"*Relevance Score: {round(row['semantic_similarity_score'], 3)}*"
+             for i, row in chunks.head(3).iterrows()
+         ])
+
+         # Add to history
+         history.append({"role": "user", "content": message})
+         history.append({"role": "assistant", "content": answer})
+         history.append({
+             "role": "assistant",
+             "content": f"<details><summary>📚 View Source Chunks</summary>\n\n{context}\n\n</details>"
+         })
+
+         print(f"[Chat] ✓ Response generated successfully")
+         return "", history
+
+     except Exception as e:
+         print(f"[Chat] ERROR: {e}")
+         import traceback
+         traceback.print_exc()
+
+         error_msg = f"❌ An error occurred: {str(e)}\n\nPlease try again or rephrase your question."
+         history.append({"role": "user", "content": message})
+         history.append({"role": "assistant", "content": error_msg})
+         return "", history
+
+
+ # Build Gradio Interface
+ with gr.Blocks(
+     theme=gr.themes.Soft(primary_hue="cyan"),
+     title="Medical Drug QA Chatbot",
+     css="""
+     .info-container, .info-footer {
+         width: 90%;
+         max-width: 1000px;
+         margin: 0 auto;
+     }
+     details.info-section, details.about-section {
+         background: white;
+         border-radius: 12px;
+         box-shadow: 0 2px 8px rgba(0,0,0,0.1);
+         margin: 1em 0;
+         padding: 0;
+     }
+     details > summary {
+         padding: 1em 1.5em;
+         font-size: 1.1em;
+         font-weight: bold;
+         color: #00838f;
+         cursor: pointer;
+         border-radius: 12px;
+         transition: background-color 0.3s ease;
+     }
+     details > summary:hover {
+         background-color: #e0f7fa;
+     }
+     .disclaimer {
+         background: #fff3cd;
+         border: 1px solid #ffc107;
+         border-radius: 8px;
+         padding: 1em;
+         margin: 1em 0;
+     }
+     """
+ ) as demo:
+
+     # Header
+     gr.Markdown("# 💊 Medical Drug QA Chatbot")
+     gr.Markdown("_Ask questions about medications and get reliable answers from trusted medical sources._")
+
+     # Instructions
+     with gr.Accordion("🤔 How to Use", open=False):
+         gr.Markdown("""
+         Simply type your question about any medication. You can ask about:
+         - **Side effects** and warnings
+         - **Dosage** and usage instructions
+         - **Drug interactions**
+         - **Storage** and handling
+         - **Precautions** for specific conditions
+
+         ### 💡 Example Questions:
+         - "What are the common side effects of Aspirin?"
+         - "How should I store Insulin?"
+         - "What precautions should I take with Lisinopril?"
+         - "Can I drink alcohol while taking Metformin?"
+         """)
+
+     # Chatbot
+     chatbot = gr.Chatbot(
+         type="messages",
+         height=500,
+         label="Chat",
+         show_label=False,
+         avatar_images=(None, "🤖")
+     )
+
+     # Input
+     with gr.Row():
+         msg = gr.Textbox(
+             placeholder="Ask your medical question here...",
+             scale=9,
+             container=False,
+             show_label=False
+         )
+         submit = gr.Button("Send", scale=1, variant="primary")
+
+     with gr.Row():
+         clear = gr.Button("🗑️ Clear Chat", scale=1)
+
+     # Event handlers
+     msg.submit(
+         fn=chat_agent,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+     )
+
+     submit.click(
+         fn=chat_agent,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+     )
+
+     clear.click(
+         fn=lambda: (None, []),
+         inputs=None,
+         outputs=[msg, chatbot],
+     )
+
+     # About section
+     with gr.Accordion("📚 About This Project", open=False):
+         gr.Markdown("""
+         This Medical Drug QA system uses advanced NLP technologies:
+
+         - **Data Source**: Mayo Clinic's comprehensive drug database
+         - **NER**: BioBERT for chemical/drug entity recognition
+         - **Retrieval**: Hybrid system with MiniLM-V6 + BioBERT reranking
+         - **Answer Generation**: Llama-4 via Groq API
+
+         **Technologies**: Transformers, FAISS, Sentence-BERT, Gradio
+         """)
+
+     # Disclaimer
+     gr.HTML("""
+     <div class="disclaimer">
+         <strong>⚠️ Medical Disclaimer</strong>: This chatbot provides educational information only.
+         It should NOT be used as a substitute for professional medical advice, diagnosis, or treatment.
+         Always consult a qualified healthcare provider for medical decisions.
+     </div>
+     """)
+
+ # Launch
+ if __name__ == "__main__":
+     demo.queue()  # Enable queuing for better performance
+     demo.launch(
+         share=False,  # Set to False for HF Spaces
+         show_error=True
+     )
Scripts/demo.py ADDED
@@ -0,0 +1,53 @@
+ """
+ Main Execution Script for Retrieval-based Medical QA Chatbot
+ ============================================================
+
+ This script handles:
+ 1. Query preprocessing
+ 2. Information retrieval
+ 3. Answer generation
+ """
+
+ import warnings
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ from Query_processing import preprocess_query
+ from Retrieval import Retrieval_averagedQP
+ from Answer_Generation import answer_generation
+ from Retrieval import Embed_and_FAISS
+
+ # -------------------------------
+ # Optional: Embed and Store FAISS Index
+ # -------------------------------
+ # Uncomment the line below to generate embeddings and build the FAISS index if not already done.
+ # Embed_and_FAISS()
+
+ # -------------------------------
+ # Define User Question
+ # -------------------------------
+
+ Question = "how much dosage of ibuprofen should be taken for treatment of fever?"
+
+ # -------------------------------
+ # Step 1: Query Preprocessing
+ # -------------------------------
+
+ intent, entities = preprocess_query(Question)
+
+ # -------------------------------
+ # Step 2: Retrieve Relevant Chunks
+ # -------------------------------
+
+ top_chunks = Retrieval_averagedQP(Question, intent, entities, top_k=10, alpha=0.8)
+
+ # -------------------------------
+ # Step 3: Answer Generation
+ # -------------------------------
+
+ Generated_answer = answer_generation(Question, top_chunks, top_k=3)
+
+ # -------------------------------
+ # Display Generated Answer
+ # -------------------------------
+
+ print("Generated Answer:", Generated_answer)
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # Web Framework
+ gradio>=4.0.0
+
+ # Data Processing
+ pandas>=2.0.0
+ numpy>=1.24.0
+
+ # NLP & ML
+ torch>=2.0.0
+ transformers>=4.35.0
+ sentence-transformers>=2.2.0
+ scikit-learn>=1.3.0
+
+ # Vector Search
+ faiss-cpu>=1.7.4
+
+ # API Client
+ openai>=1.0.0
+
+ # Optional Performance
+ accelerate>=0.24.0
+ sentencepiece>=0.1.99