kirubel1738 committed
Commit 7e75a84 · verified · 1 Parent(s): ac9c332

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +96 -144
src/streamlit_app.py CHANGED
@@ -1,156 +1,108 @@
- # streamlit_app.py
  import os
- import json
- import time
-
- # -----------------------------
- # IMPORTANT: set cache dirs BEFORE importing transformers/huggingface_hub
- # -----------------------------
- os.environ.setdefault("HF_HOME", os.environ.get("HF_HOME", "/tmp/huggingface"))
- os.environ.setdefault("TRANSFORMERS_CACHE", os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers"))
- os.environ.setdefault("HF_DATASETS_CACHE", os.environ.get("HF_DATASETS_CACHE", "/tmp/huggingface/datasets"))
- os.environ.setdefault("HUGGINGFACE_HUB_CACHE", os.environ.get("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub"))
- os.environ.setdefault("XDG_CACHE_HOME", os.environ.get("XDG_CACHE_HOME", "/tmp/huggingface"))
- os.environ.setdefault("HOME", os.environ.get("HOME", "/tmp"))
-
- # create cache dirs (best-effort)
- for d in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["HF_DATASETS_CACHE"], os.environ["HUGGINGFACE_HUB_CACHE"]]:
      try:
-         os.makedirs(d, exist_ok=True)
-         os.chmod(d, 0o777)
-     except Exception:
-         pass
-
- import streamlit as st
- import requests
-
- # Optional heavy imports will be inside local-model branch
- LOCAL_MODE = os.environ.get("USE_LOCAL_MODEL", "0") == "1"
-
- # default model id the user provided; keep as-is
- DEFAULT_MODEL_ID = "kirubel1738/biogpt-pubmedqa-finetuned"
-
- st.set_page_config(page_title="BioGPT (PubMedQA) demo", layout="centered")
-
- st.title("BioGPT — PubMedQA demo")
- st.caption("Defaults to the Hugging Face Inference API (recommended for Spaces / CPU).")
-
- st.markdown(
-     """
- **How it works**
- - By default the app will call Hugging Face's Inference API for the model you specify (fast and avoids memory issues).
- - If you set `USE_LOCAL_MODEL=1` in your environment, the app will attempt to load the model locally using `transformers` (only for GPUs/large memory machines).
-     """
- )
-
- col1, col2 = st.columns([3,1])
-
- with col1:
-     model_id = st.text_input("Model repo id", value=DEFAULT_MODEL_ID, help="Hugging Face repo id (e.g. username/modelname).")
-     prompt = st.text_area("Question / prompt", height=180, placeholder="Enter a PubMed-style question or prompt...")
- with col2:
-     max_new_tokens = st.slider("Max new tokens", 16, 1024, 128)
-     temperature = st.slider("Temperature", 0.0, 1.5, 0.0, step=0.05)
-     method = st.radio("Run method", ("Inference API (recommended)", "Local model (heavy)"), index=0)

- # override radio if user set USE_LOCAL_MODEL env var
- if LOCAL_MODE:
-     method = "Local model (heavy)"

- hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_TOKEN")

- def call_inference_api(model_id: str, prompt: str, max_new_tokens: int, temperature: float):
      """
-     Simple POST to Hugging Face Inference API.
-     If you want to use the InferenceClient from huggingface_hub you can swap this.
      """
-     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
-     headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
-     payload = {
-         "inputs": prompt,
-         "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
-         "options": {"wait_for_model": True}
-     }
      try:
-         r = requests.post(api_url, headers=headers, json=payload, timeout=120)
      except Exception as e:
-         return False, f"Request failed: {e}"
-     if r.status_code != 200:
-         try:
-             error = r.json()
-         except Exception:
-             error = r.text
-         return False, f"API error ({r.status_code}): {error}"
-     try:
-         resp = r.json()
-         # handle several possible response schemas
-         if isinstance(resp, dict) and "error" in resp:
-             return False, resp["error"]
-         # often it's a list of dicts with 'generated_text'
-         if isinstance(resp, list):
-             out_texts = []
-             for item in resp:
-                 if isinstance(item, dict):
-                     # common key: 'generated_text'
-                     for k in ("generated_text", "text", "content"):
-                         if k in item:
-                             out_texts.append(item[k])
-                             break
-                     else:
-                         out_texts.append(json.dumps(item))
-                 else:
-                     out_texts.append(str(item))
-             return True, "\n\n".join(out_texts)
-         # fallback
-         return True, str(resp)
-     except Exception as e:
-         return False, f"Could not parse response: {e}"
-
- # Local model loader (only if method chosen)
- generator = None
- if method.startswith("Local"):
-     st.warning("Local model mode selected — this requires transformers + torch and lots of RAM/GPU. Only use if you know the model fits your hardware.")
-     try:
-         from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-         import torch
-         device = 0 if torch.cuda.is_available() else -1
-         st.info(f"torch.cuda.is_available={torch.cuda.is_available()} -- device set to {device}")
-         with st.spinner("Loading tokenizer & model (this can take a while)..."):
-             tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
-             model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=os.environ.get("TRANSFORMERS_CACHE"), low_cpu_mem_usage=True)
-             generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
-     except Exception as e:
-         st.error(f"Local model load failed: {e}")
          st.stop()

- if st.button("Generate"):
-     if not prompt or prompt.strip() == "":
-         st.error("Please enter a prompt.")
-         st.stop()
-
-     if method.startswith("Inference"):
-         if ("kirubel1738/biogpt-pubmedqa-finetuned" in model_id) and not hf_token:
-             st.info("If the model is private or rate-limited, set HUGGINGFACE_HUB_TOKEN as a secret in Spaces or as an env var locally.")
-         with st.spinner("Querying Hugging Face Inference API..."):
-             ok, out = call_inference_api(model_id, prompt, max_new_tokens, float(temperature))
-         if not ok:
-             st.error(out)
          else:
-             st.success("Done")
-             st.text_area("Model output", value=out, height=320)
-     else:
-         # local model generation
-         try:
-             with st.spinner("Running local generation..."):
-                 results = generator(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
-             if isinstance(results, list) and len(results) > 0 and "generated_text" in results[0]:
-                 out = results[0]["generated_text"]
-             else:
-                 out = str(results)
-             st.success("Done")
-             st.text_area("Model output", value=out, height=320)
-         except Exception as e:
-             st.error(f"Local generation failed: {e}")
-
- st.markdown("---")
- st.caption("If you run into permissions errors in Spaces, ensure the HF cache env vars above point to a writable directory (we already set them to /tmp/huggingface in this container).")
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
  import os
+ import shutil
+
+ # Define the custom cache directory for Hugging Face models
+ cache_dir = "/tmp/biogpt_app_cache"
+
+ # --- PROACTIVE CACHE CLEARING ---
+ # Set environment variables to point Hugging Face and Streamlit to our custom cache directory
+ # This is done to prevent PermissionErrors in read-only environments.
+ os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
+ os.environ["HF_HOME"] = cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
+ os.environ["XDG_CACHE_HOME"] = cache_dir
+ os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
+
+ # Clear the cache directory before attempting to download the model.
+ if os.path.exists(cache_dir):
      try:
+         st.info("Clearing old cache to ensure a fresh download...")
+         shutil.rmtree(cache_dir)
+     except Exception as e:
+         st.error(f"Failed to clear old cache. Please check directory permissions. Error: {e}")
+         st.stop()

+ # Ensure the new cache directory exists before model loading
+ try:
+     os.makedirs(cache_dir, exist_ok=True)
+ except Exception as e:
+     st.error(f"Failed to create cache directory at {cache_dir}. Error: {e}")
+     st.stop()

+ st.set_page_config(page_title="BioGPT-PubMedQA Chatbot", layout="centered")
+ st.title("🧬 BioGPT-PubMedQA Chatbot")
+ st.write("A fine-tuned BioGPT model for biomedical Q&A.")

+ # Load model once using Streamlit's resource caching
+ @st.cache_resource
+ def load_model(cache_directory):
      """
+     Loads the tokenizer and model from Hugging Face Hub,
+     explicitly using the specified cache directory.
      """
+     model_name = "kirubel1738/biogpt-pubmedqa-finetuned"
+
      try:
+         tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto",
+             cache_dir=cache_directory
+         )
+         return tokenizer, model
      except Exception as e:
+         st.error(f"Failed to load model. Please ensure the model name is correct and it is publicly accessible.")
+         st.exception(e)
          st.stop()

+ # Load the model, passing the cache directory
+ try:
+     tokenizer, model = load_model(cache_dir)
+ except Exception as e:
+     st.error(f"An unexpected error occurred during model loading: {e}")
+     st.stop()
+
+ # Maintain chat history
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = []
+
+ # Display chat history
+ for msg in st.session_state["messages"]:
+     with st.chat_message(msg["role"]):
+         st.markdown(msg["content"])
+
+ # Input box for user
+ if prompt := st.chat_input("Ask me a biomedical question..."):
+     st.session_state["messages"].append({"role": "user", "content": prompt})
+
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     formatted_prompt = f"""### Question:{prompt}### Answer:"""
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+     with st.spinner("Thinking..."):
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=200,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+         decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     if "### Answer:" in decoded:
+         answer = decoded.split("### Answer:")[-1].strip()
      else:
+         answer = decoded.strip()
+
+     st.session_state["messages"].append({"role": "assistant", "content": answer})
+
+     with st.chat_message("assistant"):
+         st.markdown(answer)
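
As a quick sanity check of the new prompt format outside Streamlit, a minimal sketch along the following lines should work, assuming transformers, torch, and enough memory are available and the repo kirubel1738/biogpt-pubmedqa-finetuned is publicly accessible; the question string is only a placeholder example, not part of the commit:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "kirubel1738/biogpt-pubmedqa-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Same template the app builds: "### Question:{prompt}### Answer:"
question = "Does vitamin D supplementation reduce the risk of respiratory infections?"  # placeholder example
prompt = f"### Question:{question}### Answer:"
inputs = tokenizer(prompt, return_tensors="pt")

# Mirror the app's generation settings (sampling, 200 new tokens).
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )

decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Keep only the text after the answer marker, as the app does.
answer = decoded.split("### Answer:")[-1].strip()
print(answer)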