# eksemyashkina — "Added files" (commit f096e52)
from typing import Dict
import gradio as gr
import json
import PIL.Image, PIL.ImageOps
import torch
import torchvision.transforms.functional as F
from src.models.resnet50 import ResNet
from src.models.mobilenet_v2 import MobileNetV2
# Number of output classes both classifiers are constructed with; must match
# the checkpoints loaded below.
num_classes = 30

# Load both classifiers once at import time and switch them to eval mode so
# every request reuses the same instances.
model1 = ResNet(weights_path="weights/checkpoint-best-resnet.pth", num_classes=num_classes)
model1.eval()
model2 = MobileNetV2(weights_path="weights/checkpoint-best-mobilenet.pth", num_classes=num_classes)
model2.eval()

# labels.json maps class name -> class index (predict() looks names up by the
# integer index returned from topk). Read it as UTF-8 explicitly so label names
# decode the same way regardless of the platform's default encoding.
with open("labels.json", "r", encoding="utf-8") as f:
    class_labels = json.load(f)

# Inverted view: class index -> class name.
label_mapping = {v: k for k, v in class_labels.items()}
def _preprocess(img) -> torch.Tensor:
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` batch tensor.

    The image is converted to RGB and centred on a white square canvas whose
    side equals its longer edge, so the subsequent resize to 224x224 keeps the
    aspect ratio. It is then normalized with the commonly used ImageNet
    channel mean/std.
    """
    img = img.convert("RGB")  # the white fill below is an RGB triple
    width, height = img.size
    max_dim = max(width, height)
    # Split the missing extent between opposite sides. ImageOps.expand takes a
    # 4-tuple border as (left, top, right, bottom); a 2-tuple would add the
    # full difference to *both* sides and yield a non-square canvas.
    pad_w = max_dim - width
    pad_h = max_dim - height
    left, top = pad_w // 2, pad_h // 2
    border = (left, top, pad_w - left, pad_h - top)
    # NOTE(review): confirm this matches the preprocessing used when the
    # checkpoints were trained.
    img = PIL.ImageOps.expand(img, border=border, fill=(255, 255, 255))
    img = img.resize((224, 224))
    tensor = F.to_tensor(img)
    tensor = F.normalize(
        tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    return tensor.unsqueeze(0)  # add the batch dimension


def predict(img, model_choice) -> Dict[str, float]:
    """Classify ``img`` with the selected model and return its top-3 classes.

    Args:
        img: The uploaded image as a ``PIL.Image.Image`` (the input component
            uses ``type="pil"``), or ``None`` when nothing was uploaded.
        model_choice: ``"ResNet"`` selects the ResNet model; any other value
            selects MobileNetV2.

    Returns:
        Mapping of class label to softmax probability for the (up to) three
        most likely classes, in descending order of probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    model = model1 if model_choice == "ResNet" else model2
    batch = _preprocess(img)

    with torch.inference_mode():
        logits = model.forward(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]

    k = min(3, probs.numel())
    top_probs, top_indices = probs.topk(k)
    return {
        label_mapping[idx.item()]: prob.item()
        for idx, prob in zip(top_indices, top_probs)
    }
# Rows for gr.Examples: each row holds a single value, the path of a sample
# image fed to the image input.
examples = [
    [sample_path]
    for sample_path in (
        "assets/banana.jpg",
        "assets/pineapple.jpg",
        "assets/mango.jpg",
        "assets/melon.jpg",
        "assets/orange.jpg",
        "assets/eggplant.jpg",
        "assets/black.jpg",
        "assets/white.jpg",
    )
]
def _clear_outputs():
    """Reset the image input and the prediction label (in that order)."""
    return None, None


with gr.Blocks() as demo:
    gr.Markdown("## Plant Classification")

    with gr.Row():
        # Left column: image upload, model selector and action buttons.
        with gr.Column():
            pic = gr.Image(
                label="Upload Plant Image",
                type="pil",
                height=300,
                width=300,
            )
            model_choice = gr.Dropdown(
                label="Select Model",
                choices=["ResNet", "MobileNetV2"],
                value="ResNet",
            )
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")

        # Right column: the top-3 prediction label.
        with gr.Column():
            output = gr.Label(label="Top 3 Predicted Classes")

    # Event wiring: Predict runs the classifier; Clear empties input and output.
    predict_btn.click(
        fn=predict,
        inputs=[pic, model_choice],
        outputs=output,
        api_name="predict",
    )
    clear_btn.click(_clear_outputs, outputs=[pic, output])

    gr.Examples(examples=examples, inputs=[pic])

demo.launch()