import os
import shutil

# Hugging Face cache location. The cache-related environment variables are set
# before importing transformers so they are picked up when the library initializes.
cache_dir = "/tmp/biogpt_app_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir

# Streamlit settings (Streamlit resolves these at process startup, so exporting
# them before `streamlit run` is the most reliable option).
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp/streamlit_cache"
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
st.set_page_config(page_title="BioGPT-PubMedQA Chatbot", layout="centered")

# Remove any stale model cache so a cold start always downloads a fresh copy.
if os.path.exists(cache_dir):
    try:
        st.info("Clearing old cache to ensure a fresh download...")
        shutil.rmtree(cache_dir)
    except Exception as e:
        st.error(f"Failed to clear old cache. Please check directory permissions. Error: {e}")
        st.stop()

# Recreate the cache directory before downloading the model.
try:
    os.makedirs(cache_dir, exist_ok=True)
except Exception as e:
    st.error(f"Failed to create cache directory at {cache_dir}. Error: {e}")
    st.stop()

st.title("🧬 BioGPT-PubMedQA Chatbot")
st.write("A fine-tuned BioGPT model for biomedical Q&A.")
@st.cache_resource
def load_model(cache_directory):
    """
    Load the tokenizer and model from the Hugging Face Hub,
    explicitly using the specified cache directory.
    """
    model_name = "kirubel1738/biogpt-pubmedqa-finetuned"

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory)
        # Use half precision on GPU to reduce memory; fall back to float32 on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            cache_dir=cache_directory,
        )
        return tokenizer, model
    except Exception as e:
        st.error("Failed to load model. Please ensure the model name is correct and it is publicly accessible.")
        st.exception(e)
        st.stop()
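
# Fetch the cached tokenizer and model; abort with an error message on failure.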
try:
    tokenizer, model = load_model(cache_dir)
except Exception as e:
    st.error(f"An unexpected error occurred during model loading: {e}")
    st.stop()
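
# Conversation history is kept in session state so it survives Streamlit reruns.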
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay the conversation so far.
for msg in st.session_state["messages"]:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
if prompt := st.chat_input("Ask me a biomedical question..."):
    st.session_state["messages"].append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    # Wrap the question in the "### Question:" / "### Answer:" marker format
    # that the answer-extraction step below also relies on.
    formatted_prompt = f"### Question:{prompt}### Answer:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
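
    # Sample a completion from the model; no gradients are needed at inference time.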
    with st.spinner("Thinking..."):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
            )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
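
    # Drop the echoed prompt and keep only the text after the answer marker.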
    if "### Answer:" in decoded:
        answer = decoded.split("### Answer:")[-1].strip()
    else:
        answer = decoded.strip()

    st.session_state["messages"].append({"role": "assistant", "content": answer})

    with st.chat_message("assistant"):
        st.markdown(answer)