File size: 1,559 Bytes
97039ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import time
from PIL import Image
import gradio as gr
from transformers import pipeline

# Dropdown label -> Hugging Face model id handed to the pipeline factory.
MODEL_MAP = {
    "ViT (Base/16, 224)": "google/vit-base-patch16-224",
    "ResNet-50": "microsoft/resnet-50",
    "EfficientNet-B0": "google/efficientnet-b0",
}

# Pipelines are created on first request and then reused, so nothing is
# loaded at import time.
_pipes = {}


def get_pipe(model_id: str):
    """Return the image-classification pipeline for ``model_id``.

    The pipeline is built (with ``top_k=5``) the first time a model id is
    requested and stored in ``_pipes``; later calls return the stored object.
    """
    if model_id in _pipes:
        return _pipes[model_id]

    classifier = pipeline("image-classification", model=model_id, top_k=5)
    _pipes[model_id] = classifier
    return classifier

def predict(img: "Image.Image", model_name: str):
    """Classify ``img`` with the selected backbone and report the latency.

    Args:
        img: Image to classify, or ``None`` when nothing has been uploaded.
        model_name: A key of ``MODEL_MAP`` naming the backbone to use.

    Returns:
        A ``(scores, info)`` pair, in the same order as the click handler's
        outputs. ``scores`` maps each predicted label to its score rounded to
        three decimals, or is ``None`` when there is no image. ``info`` is a
        one-line status string.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_MAP``.
    """
    if img is None:
        # Nothing to classify: keep the label output empty and put the hint
        # in the status slot, so it is not displayed as a predicted class.
        return None, "Upload an image."

    # Fetch (or build) the pipeline before starting the timer, so the
    # reported figure covers inference only, not model loading.
    pipe = get_pipe(MODEL_MAP[model_name])

    # perf_counter is monotonic, unlike time.time(), so the interval cannot
    # be skewed by system clock adjustments.
    started = time.perf_counter()
    preds = pipe(img)
    latency_ms = int((time.perf_counter() - started) * 1000)

    # Flatten the list of {"label", "score"} dicts into a label -> score map.
    scores = {p["label"]: round(float(p["score"]), 3) for p in preds}
    return scores, f"{model_name} • ~{latency_ms} ms"

# UI definition: inputs go in the first column, results in the second.
with gr.Blocks(title="Image Classifier – Multi-Model") as demo:
    gr.Markdown("# 🐢🐱 Image Classifier (Multi-Model)\nUpload an image, choose a backbone, see top-5 predictions.")
    with gr.Row():
        with gr.Column():
            # type="pil" matches the PIL image that predict() is annotated to take.
            img = gr.Image(type="pil", label="Image")
            # Choices are the MODEL_MAP keys; the default must be one of them.
            model = gr.Dropdown(list(MODEL_MAP.keys()), value="ViT (Base/16, 224)", label="Backbone")
            btn = gr.Button("Predict")
        with gr.Column():
            out = gr.Label(label="Top-5")
            info = gr.Markdown()
    # predict() returns a pair; its items are routed to `out` and `info` in that order.
    btn.click(fn=predict, inputs=[img, model], outputs=[out, info])

# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()