classify_2 / app.py
hudaakram's picture
Create app.py
97039ed verified
import time
from PIL import Image
import gradio as gr
from transformers import pipeline
# Dropdown display name -> model id handed to transformers' `pipeline(model=...)`.
# Insertion order is the order the choices are listed in the UI dropdown.
MODEL_MAP = {
    "ViT (Base/16, 224)": "google/vit-base-patch16-224",
    "ResNet-50": "microsoft/resnet-50",
    "EfficientNet-B0": "google/efficientnet-b0",
}

# Cache of already-built pipelines, keyed by model id. Pipelines are created
# on first use (in get_pipe) instead of at import time, keeping startup fast.
_pipes = dict()
def get_pipe(model_id: str):
    """Return the image-classification pipeline for ``model_id``, building it once.

    The first call for a given id constructs a transformers
    ``image-classification`` pipeline configured with ``top_k=5``. It then
    stores the pipeline in the module-level ``_pipes`` dict. Later calls
    return that same cached object.

    Note: the check-then-store is not guarded by a lock. Concurrent first
    calls for the same id could therefore each build a pipeline.
    """
    if model_id in _pipes:
        return _pipes[model_id]
    built = pipeline("image-classification", model=model_id, top_k=5)
    _pipes[model_id] = built
    return built
def predict(img: Image.Image, model_name: str):
    """Classify ``img`` with the backbone chosen in the UI.

    Args:
        img: The uploaded PIL image, or ``None`` when the input is empty.
        model_name: Display name of the backbone; must be a key of ``MODEL_MAP``.

    Returns:
        A ``(scores, info)`` pair feeding the ``gr.Label`` and ``gr.Markdown``
        outputs.

        - ``scores`` maps each predicted label to its score, rounded to
          3 decimals.
        - ``info`` is a ``"<model name> • ~<latency> ms"`` summary.

        When ``img`` is ``None``, returns ``("Upload an image.", None)``
        instead.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_MAP``.
    """
    if img is None:
        return "Upload an image.", None
    model_id = MODEL_MAP[model_name]
    # Fetch (and on first use, build) the pipeline before starting the timer,
    # so the reported latency covers inference only.
    pipe = get_pipe(model_id)
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump on clock adjustments, giving bogus (even
    # negative) latencies.
    t0 = time.perf_counter()
    preds = pipe(img)
    latency_ms = int((time.perf_counter() - t0) * 1000)
    # gr.Label expects a plain {label: confidence} dict.
    scores = {p["label"]: round(float(p["score"]), 3) for p in preds}
    # Separator was mojibake ("β€’") in the original; restored to a bullet.
    return scores, f"{model_name} • ~{latency_ms} ms"
with gr.Blocks(title="Image Classifier – Multi-Model") as demo:
    # Page header rendered above the two-column layout.
    gr.Markdown(
        value=(
            "# 🐢🐱 Image Classifier (Multi-Model)\n"
            "Upload an image, choose a backbone, see top-5 predictions."
        )
    )
    with gr.Row():
        # Left column: image input, backbone selector and the trigger button.
        with gr.Column():
            img = gr.Image(label="Image", type="pil")
            model = gr.Dropdown(
                choices=list(MODEL_MAP),
                value="ViT (Base/16, 224)",
                label="Backbone",
            )
            btn = gr.Button(value="Predict")
        # Right column: the label scores and a one-line info panel.
        with gr.Column():
            out = gr.Label(label="Top-5")
            info = gr.Markdown()
    # predict(img, model_name) returns a 2-tuple, routed to (out, info) in order.
    btn.click(predict, [img, model], [out, info])

# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()