timothytzkung committed
Commit e93d19d · verified · 1 Parent(s): f129043

Update app.py


- Added optimizations: cache the documents in Parquet so the corpus is not re-embedded on every restart, replace row-by-row cosine similarity with a single normalized dot product over a NumPy matrix (see the sketch below), switch from gemma-3-4b-it to the smaller gemma-3-1b-it, and cut max_new_tokens from 500 to 300 for faster responses.
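The key retrieval idea: once the document matrix is unit-normalized, cosine similarity against a normalized query is a single matrix-vector product. A minimal sketch of that equivalence with random stand-in vectors (384-dim to match all-MiniLM-L6-v2; not the Space's real data):

import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(100, 384))   # stand-in document embeddings
query = rng.normal(size=384)         # stand-in query embedding

# Old path: cosine similarity computed row by row.
row_by_row = np.array([
    np.dot(query, d) / (np.linalg.norm(query) * np.linalg.norm(d))
    for d in docs
])

# New path: normalize once, then one matrix-vector product.
docs_unit = docs / np.linalg.norm(docs, axis=1, keepdims=True)
vectorized = docs_unit @ (query / np.linalg.norm(query))

assert np.allclose(row_by_row, vectorized)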

Files changed (1)
  1. app.py +70 -104
app.py CHANGED
@@ -1,158 +1,124 @@
 import json
 import numpy as np
 import pandas as pd
-
-from transformers import pipeline
 from sentence_transformers import SentenceTransformer
 import gradio as gr
 import torch
 from huggingface_hub import login
 import os

-# Sanity Check
 hf_token = os.getenv("V2_TOKEN")
 if hf_token is None:
-    raise RuntimeError("V2_TOKEN environment variable is not set in this Space.")

-# Explicit login
 login(token=hf_token)

-# --- Configuration ---
-print("Loading RAG system on your device...")

-# Load Knowledge base
 FILE_PATH = "data.jsonl"
 PRELOAD_FILE_PATH = "preload-data.json"

-# Load data
-print(f"Found Preloaded Data! Using {PRELOAD_FILE_PATH}...")
-with open(PRELOAD_FILE_PATH, "r", encoding="utf-8") as f:
-    data = json.load(f)
-
-# Set data
-documents = data

-# Embeddings
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
-embeddings = embedding_model.encode(documents, convert_to_numpy=True)
-
-# Use pandas dataframe
-df = pd.DataFrame(
-    {
-        "Document": documents,
-        "Embedding": list(embeddings),  # store as list
-    }
-)

-# Load LLM Pipeline
-llm = pipeline(
-    "text-generation",
-    model="google/gemma-3-4b-it",  # Might not have enough storage ngl
-    token=hf_token
-)

-def clean_query_with_llm(query):
-    prompt_content = f"""
-    Below is a new question asked by the user that needs to be answered by searching in a knowledge base.
-    You have access to SFU IT Knowledge Base index with 100's of chunked documents.
-    Generate a search question based the user's question.
-    If you cannot generate a search query, return just the number 0.

-    User's Question:
-    {query}

-    Search Query:
-    """

-    response = llm(
-        prompt_content,
-        max_new_tokens=100,
-        do_sample=False,
-        return_full_text=False
-    )
-    return response[0]["generated_text"].strip()


-# Retrieve w Pandas
-def retrieve_with_pandas(query: str, top_k: int = 5):
     """
-    Embed the query, compute cosine similarity to each document,
-    and return the top_k most similar documents (as a DataFrame).
     """
     query_embedding = embedding_model.encode([query])[0]

-    def cosine_sim(x):
-        x = np.array(x)
-        return float(
-            np.dot(query_embedding, x)
-            / (np.linalg.norm(query_embedding) * np.linalg.norm(x))
-        )
-
-    df["Similarity"] = df["Embedding"].apply(cosine_sim)
-    results = df.sort_values(by="Similarity", ascending=False).head(top_k)
-    return results[["Document", "Similarity"]]
-
-
-def generate_with_rag(query, top_k=5):
-    # goSFU specific cleaning
     if "gosfu" in query.lower():
         query = query.replace("gosfu", "goSFU")

     # Retrieve
-    search_query = clean_query_with_llm(query)
-    results = retrieve_with_pandas(search_query)

-    # Turn the Series into a single string of text
-    # (each doc separated by a divider)
-    context_str = "\n\n---\n\n".join(results["Document"].tolist())
-
-    # Build a clean prompt
     prompt_content = f"""
     You are a SFU IT helpdesk chatbot.
-    Your task is to answer SFU IT related questions such as accessing various technology services or general troubleshooting.
-    Below is new question asked by the user, and related article chunks to the user question.
-    If the user asked a question, answer the user's question with short step by step instructions: consider all the articles below.
-    If there are links in the articles, provide those links in your answer.
-    If the user asked a question and the answer is not in the contexts, say that you're sorry that you can't help them and suggest contacting SFU IT at 778-782-8888 or by submitting an inquiry ticket at https://www.sfu.ca/information-systems/get-help.html
-    If the user DID NOT ask a question, be friendly and ask how you can help them.
-    Do not recommend, suggest, or provide any advice on anything that is not related to SFU or SFU IT.
-    If the user asked something relating to mental health or is seeking medical advice, redirect them to SFU Health & Counselling at https://www.sfu.ca/students/health.html
-    Do not ask the user any follow-up questions after answering them.
-
-    Question:
-    {query}
-
-    -- Start of Articles --
-    {context_str}
-
-    -- End of Articles --
-
-    Answer:"""
-
-    # Call the LLM
     response = llm(
         prompt_content,
-        max_new_tokens=500,
         do_sample=False,
         return_full_text=False
     )
     return response[0]["generated_text"].strip()
-

 def chat_fn(message, history):
-    """
-    Chat Interface callback
-    """
-    answer = generate_with_rag(message, top_k=5)
-    return answer
-

 demo = gr.ChatInterface(
     fn=chat_fn,
-    title="SFU IT Chatbot",
     description="Enter your question and the SFU IT Chatbot will try to answer using retrieved SFU IT knowledge.",
 )

-# share=True
 if __name__ == "__main__":
     demo.launch()
 
 import json
 import numpy as np
 import pandas as pd
+from transformers import pipeline, BitsAndBytesConfig
 from sentence_transformers import SentenceTransformer
 import gradio as gr
 import torch
 from huggingface_hub import login
 import os

+# --- Setup & Configuration ---
 hf_token = os.getenv("V2_TOKEN")
 if hf_token is None:
+    raise RuntimeError("V2_TOKEN environment variable is not set.")

 login(token=hf_token)
+PRELOAD_PARQUET = "preload.parquet"
+PRELOAD_EMBED = "preload-embeddings.npy"  # companion cache so the embedding matrix survives restarts

+print("Loading RAG system...")

+# --- Load Data & Pre-compute Embeddings ---
+# Optimization: avoid re-embedding the corpus on every restart.
 FILE_PATH = "data.jsonl"
 PRELOAD_FILE_PATH = "preload-data.json"

+# Load embedding model
 embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

+# Pre-compute embeddings once and stack them into a NumPy matrix for vectorized retrieval.
+if not (os.path.exists(PRELOAD_PARQUET) and os.path.exists(PRELOAD_EMBED)):
+    print(f"Loading data from {PRELOAD_FILE_PATH}...")
+    with open(PRELOAD_FILE_PATH, "r", encoding="utf-8") as f:
+        documents = json.load(f)

+    print("Generating embeddings...")
+    doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
+
+    # Normalize embeddings up front so similarity later reduces to a single dot product.
+    doc_embeddings = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
+
+    # The DataFrame is text storage only; the math runs on the NumPy matrix.
+    df = pd.DataFrame({"Document": documents})
+    df.to_parquet(PRELOAD_PARQUET)
+    np.save(PRELOAD_EMBED, doc_embeddings)
+else:
+    print("Parquet found!")
+    df = pd.read_parquet(PRELOAD_PARQUET, engine="fastparquet")
+    # Reload the cached matrix; otherwise doc_embeddings would be undefined on this branch.
+    doc_embeddings = np.load(PRELOAD_EMBED)

+print("Parquet cache ready.")
+
+print("Loading LLM...")
+llm = pipeline(
+    "text-generation",
+    model="google/gemma-3-1b-it",
+    token=hf_token,
+)

+# --- Optimized Retrieval Function ---
+def retrieve_vectorized(query: str, top_k: int = 5):
     """
+    Score every document with one matrix-vector product instead of row-by-row iteration.
     """
+    # Encode the query
     query_embedding = embedding_model.encode([query])[0]
+
+    # Normalize the query; document embeddings are already unit-length.
+    query_norm = query_embedding / np.linalg.norm(query_embedding)
+    scores = np.dot(doc_embeddings, query_norm)
+    top_indices = np.argsort(scores)[::-1][:top_k]
+
+    # Retrieve the matching documents
+    results = df.iloc[top_indices].copy()
+    return results["Document"].tolist()

+# --- Main Generation Function ---
+def generate_with_rag(query):
+    # goSFU-specific cleaning
     if "gosfu" in query.lower():
         query = query.replace("gosfu", "goSFU")

     # Retrieve
+    retrieved_docs = retrieve_vectorized(query, top_k=5)
+    context_str = "\n\n---\n\n".join(retrieved_docs)

+    # Prompt
     prompt_content = f"""
     You are a SFU IT helpdesk chatbot.
+    Your task is to answer SFU IT related questions.
+
+    Context Articles:
+    {context_str}
+
+    User Question: {query}
+
+    Instructions:
+    1. Answer the question using ONLY the Context Articles above.
+    2. Provide step-by-step instructions and include relevant links found in the text.
+    3. If the answer is not in the context, suggest contacting SFU IT at 778-782-8888.
+    4. If the user is asking about mental health, redirect to SFU Health & Counselling.
+
+    Answer:"""
+
     response = llm(
         prompt_content,
+        max_new_tokens=300,  # Reduced from 500 for faster responses
         do_sample=False,
         return_full_text=False
     )
     return response[0]["generated_text"].strip()

 def chat_fn(message, history):
+    return generate_with_rag(message)

 demo = gr.ChatInterface(
     fn=chat_fn,
+    title="SFU IT Chatbot (Optimized)",
     description="Enter your question and the SFU IT Chatbot will try to answer using retrieved SFU IT knowledge.",
 )

 if __name__ == "__main__":
     demo.launch()
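
For reviewers: the cached branch in the new app.py relies on the Parquet file plus the companion .npy matrix. A quick local round-trip check with toy data (file names below are stand-ins, not the ones app.py writes):

import numpy as np
import pandas as pd

docs = ["How to reset a goSFU password", "How to connect to SFU wifi"]
emb = np.random.default_rng(0).normal(size=(len(docs), 384)).astype(np.float32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit-normalize, as app.py does

# Write both cache files, then read them back the way the else-branch does.
pd.DataFrame({"Document": docs}).to_parquet("toy-preload.parquet")
np.save("toy-preload-embeddings.npy", emb)

assert pd.read_parquet("toy-preload.parquet")["Document"].tolist() == docs
assert np.allclose(np.load("toy-preload-embeddings.npy"), emb)
print("cache round-trip OK")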