Spaces:
Sleeping
Sleeping
| from typing import Dict | |
| import gradio as gr | |
| import json | |
| import PIL.Image, PIL.ImageOps | |
| import torch | |
| import torchvision.transforms.functional as F | |
| from src.models.resnet50 import ResNet | |
| from src.models.mobilenet_v2 import MobileNetV2 | |
| num_classes = 30 | |
| model1 = ResNet(weights_path="weights/checkpoint-best-resnet.pth", num_classes=num_classes) | |
| model1.eval() | |
| model2 = MobileNetV2(weights_path="weights/checkpoint-best-mobilenet.pth", num_classes=num_classes) | |
| model2.eval() | |
| with open("labels.json", "r") as f: | |
| class_labels = json.load(f) | |
| label_mapping = {v: k for k, v in class_labels.items()} | |
| def predict(img, model_choice) -> Dict[str, float]: | |
| model = model1 if model_choice == "ResNet" else model2 | |
| width, height = img.size | |
| max_dim = max(width, height) | |
| padding = (max_dim - width, max_dim - height) | |
| img = PIL.ImageOps.expand(img, padding, (255, 255, 255)) | |
| img = img.resize((224, 224)) | |
| img = F.to_tensor(img) | |
| img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| img = img.unsqueeze(0) | |
| with torch.inference_mode(): | |
| logits = model.forward(img) | |
| probs = torch.nn.functional.softmax(logits, dim=1) | |
| top_probs, top_indices = probs[0].topk(3) | |
| top_classes = {label_mapping[idx.item()]: prob.item() for idx, prob in zip(top_indices, top_probs)} | |
| return top_classes | |
| examples = [ | |
| ["assets/banana.jpg"], | |
| ["assets/pineapple.jpg"], | |
| ["assets/mango.jpg"], | |
| ["assets/melon.jpg"], | |
| ["assets/orange.jpg"], | |
| ["assets/eggplant.jpg"], | |
| ["assets/black.jpg"], | |
| ["assets/white.jpg"] | |
| ] | |
| with gr.Blocks() as demo: | |
| gr.Markdown("## Plant Classification") | |
| with gr.Row(): | |
| with gr.Column(): | |
| pic = gr.Image(label="Upload Plant Image", type="pil", height=300, width=300) | |
| model_choice = gr.Dropdown(choices=["ResNet", "MobileNetV2"], label="Select Model", value="ResNet") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| predict_btn = gr.Button("Predict") | |
| with gr.Column(scale=1): | |
| clear_btn = gr.Button("Clear") | |
| with gr.Column(): | |
| output = gr.Label(label="Top 3 Predicted Classes") | |
| predict_btn.click(fn=predict, inputs=[pic, model_choice], outputs=output, api_name="predict") | |
| clear_btn.click(lambda: (None, None), outputs=[pic, output]) | |
| gr.Examples(examples=examples, inputs=[pic]) | |
| demo.launch() | |