# src/streamlit_app.py
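"""Streamlit chat interface for a BioGPT model fine-tuned on PubMedQA (biomedical Q&A)."""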
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import shutil
# Define the custom cache directory for Hugging Face models
cache_dir = "/tmp/biogpt_app_cache"
# --- PROACTIVE CACHE CLEARING ---
# Set environment variables to point Hugging Face and Streamlit to our custom cache directory
# This is done to prevent PermissionErrors in read-only environments.
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
# Configure the page first: st.set_page_config should be the first Streamlit
# command, so it must run before the st.info/st.error calls below.
st.set_page_config(page_title="BioGPT-PubMedQA Chatbot", layout="centered")
st.title("🧬 BioGPT-PubMedQA Chatbot")
st.write("A fine-tuned BioGPT model for biomedical Q&A.")

# Clear the cache directory before attempting to download the model.
if os.path.exists(cache_dir):
    try:
        st.info("Clearing old cache to ensure a fresh download...")
        shutil.rmtree(cache_dir)
    except Exception as e:
        st.error(f"Failed to clear old cache. Please check directory permissions. Error: {e}")
        st.stop()

# Ensure the cache directory exists before model loading.
try:
    os.makedirs(cache_dir, exist_ok=True)
except Exception as e:
    st.error(f"Failed to create cache directory at {cache_dir}. Error: {e}")
    st.stop()
# Load model once using Streamlit's resource caching
@st.cache_resource
def load_model(cache_directory):
    """
    Load the tokenizer and model from the Hugging Face Hub,
    explicitly using the specified cache directory.
    """
    model_name = "kirubel1738/biogpt-pubmedqa-finetuned"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            cache_dir=cache_directory,
        )
        return tokenizer, model
    except Exception as e:
        st.error("Failed to load model. Please ensure the model name is correct and it is publicly accessible.")
        st.exception(e)
        st.stop()
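
# Note: device_map="auto" requires the `accelerate` package at load time
# (assumed to be installed in this Space); float16 is only used when a CUDA
# GPU is available, otherwise the model stays in float32 on CPU.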
# Load the model, passing the cache directory
try:
    tokenizer, model = load_model(cache_dir)
except Exception as e:
    st.error(f"An unexpected error occurred during model loading: {e}")
    st.stop()
# Maintain chat history
if "messages" not in st.session_state:
st.session_state["messages"] = []
# Display chat history
for msg in st.session_state["messages"]:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
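
# st.session_state survives Streamlit reruns, so the full conversation is
# re-rendered on every interaction; only new turns are appended below.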
# Input box for user
if prompt := st.chat_input("Ask me a biomedical question..."):
    st.session_state["messages"].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Wrap the user question in the question/answer markers the model expects.
    formatted_prompt = f"### Question:{prompt}### Answer:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    with st.spinner("Thinking..."):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
            )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the answer marker; otherwise fall back to the full output.
    if "### Answer:" in decoded:
        answer = decoded.split("### Answer:")[-1].strip()
    else:
        answer = decoded.strip()

    st.session_state["messages"].append({"role": "assistant", "content": answer})
    with st.chat_message("assistant"):
        st.markdown(answer)
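
# To run locally (assuming streamlit, transformers, and torch are installed):
#   streamlit run src/streamlit_app.py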