# wmhds-medical / app.py
# LeduBaSK — "Update app.py" — commit 2c01e65 (verified), 837 bytes
# (Hugging Face page chrome — raw / history / blame / contribute / delete —
#  captured by the scrape; kept here as comments so the file stays valid Python.)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMultipleChoice
import torch
# Load the fine-tuned multiple-choice model and its tokenizer from the
# Hugging Face Hub. Runs once at module import (app startup); downloads
# the weights on first run.
tokenizer = AutoTokenizer.from_pretrained("LeduBaSK/wmhds")
model = AutoModelForMultipleChoice.from_pretrained("LeduBaSK/wmhds")
def predict(question, a, b, c, d):
    """Score four answer options against a question and return the best one.

    Args:
        question: The question text, paired with every option.
        a, b, c, d: The four candidate answer strings.

    Returns:
        The letter ("A", "B", "C", or "D") of the highest-scoring option.
    """
    # Tokenize each (question, option) pair; tensors come back as (4, seq_len).
    inputs = tokenizer(
        [question] * 4,
        [a, b, c, d],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # BUG FIX: multiple-choice heads expect (batch, num_choices, seq_len).
    # The flat (4, seq_len) tensors make the model treat seq_len as the
    # number of choices, so the logit reshape fails/misbehaves. Add a
    # batch dimension of 1.
    inputs = {name: tensor.unsqueeze(0) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits are (1, 4): one score per choice. Pick the argmax choice.
    pred = torch.argmax(outputs.logits, dim=-1).item()
    return ["A", "B", "C", "D"][pred]
# Web UI: a question plus four option fields feed the classifier; the
# predicted option letter is shown as plain text.
option_boxes = [gr.Textbox(label=f"Option {letter}") for letter in "ABCD"]
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="Question"), *option_boxes],
    outputs="text",
    title="Medical QA with WMHDS",
)
demo.launch()