File size: 3,978 Bytes
b404ed5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
import pandas as pd

# --- Model Loading ---
def get_model_paths(model_dir="models"):
    """Return a mapping of model display name -> checkpoint path.

    Every file directly inside *model_dir* whose name ends in ``.pth``
    contributes one entry, keyed by the file name without its extension.

    Args:
        model_dir: Directory to scan. Defaults to ``"models"`` (relative to
            the current working directory).

    Returns:
        A dict mapping model name to ``os.path.join(model_dir, filename)``,
        in sorted file-name order so the UI's default selection is stable
        (``os.listdir`` order is arbitrary). Empty if *model_dir* does not
        exist, so the app can still start with only placeholder entries.
    """
    if not os.path.isdir(model_dir):
        # Tolerate a missing directory instead of raising FileNotFoundError
        # while the module is being imported.
        return {}
    # You can create more descriptive names here if you want
    return {
        os.path.splitext(filename)[0]: os.path.join(model_dir, filename)
        for filename in sorted(os.listdir(model_dir))
        if filename.endswith(".pth")
    }

# Discovered checkpoints first, then fixed placeholder entries for models that
# have not been added yet (their files need not exist).
MODEL_PATHS = {
    **get_model_paths(),
    "Future Model 1": "models/future_model_1.pth",
    "Future Model 2": "models/future_model_2.pth",
}


# This is a placeholder for your actual model loading logic
# You will need to replace this with the code to load your specific model architecture

# Successfully loaded models, keyed by (path, modification time, num_classes),
# so a checkpoint that is replaced on disk gets reloaded instead of served stale.
_MODEL_CACHE = {}


def load_model(model_path, num_classes=10):
    """Load a classifier from *model_path*, returning ``None`` on failure.

    Placeholder implementation: builds ``baseline.convnext_v2_base`` and
    loads a state dict from *model_path* onto the CPU. Replace it with the
    loading code for your actual architecture.

    Successful loads are cached in ``_MODEL_CACHE``, so repeated predictions
    with the same, unchanged checkpoint do not rebuild the model each time.

    Args:
        model_path: Path to a ``.pth`` file holding a state dict.
        num_classes: Output size passed to the model constructor; must match
            the checkpoint. Defaults to 10.

    Returns:
        The model in eval mode, or ``None`` if the file does not exist or
        loading fails for any reason (the error is printed).
    """
    if not os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        print("Warning: Model file does not exist. Using a dummy model.")
        return None

    try:
        cache_key = (model_path, os.path.getmtime(model_path), num_classes)
    except OSError as e:
        # The file vanished (or became unreadable) after the existence check.
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None

    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    print(f"Loading model from: {model_path}")
    # Replace with your actual model loading
    try:
        # This is a guess, you'll need to replace with your actual model class
        from baseline import convnext_v2_base

        model = convnext_v2_base(num_classes=num_classes)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints; on torch >= 1.13
        # consider passing weights_only=True for plain state dicts.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        # Deliberate best-effort boundary: any failure falls back to "no model"
        # so the UI can show a message instead of crashing.
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None

    # Drop entries for older versions of this same file so the cache cannot
    # grow without bound when a checkpoint is repeatedly replaced.
    for stale_key in [key for key in _MODEL_CACHE if key[0] == model_path]:
        del _MODEL_CACHE[stale_key]
    _MODEL_CACHE[cache_key] = model
    return model


# --- Image Preprocessing ---
# Adjust these steps to match the preprocessing your model was trained with.
# The normalization values below are the standard ImageNet channel mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize, center-crop to 224x224, convert to a tensor, then normalize.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# --- Prediction ---
# Load the species list: headerless, ';'-separated "class_id;species_name" rows.
species_df = pd.read_csv('species_list.txt', sep=';', header=None, names=['class_id', 'species_name'])

# Row position -> class_id. read_csv assigns a 0..n-1 index, so the position
# equals the row's order in the file.
idx_to_class = dict(enumerate(species_df['class_id']))

# class_id -> human-readable species name (a later duplicate id wins).
class_id_to_name = dict(zip(species_df['class_id'], species_df['species_name']))


def predict(model_name, image):
    """Classify *image* with the model selected in the dropdown.

    This is wired to both the image and the dropdown ``change`` events, so it
    must cope with an empty model selection and with no image being present.

    Args:
        model_name: Key into ``MODEL_PATHS`` (the dropdown value; may be
            ``None`` when nothing is selected).
        image: The uploaded image as a NumPy array (from
            ``gr.Image(type="numpy")``), or ``None`` when there is no image.

    Returns:
        A human-readable status or prediction string for the output textbox.
    """
    if not model_name:
        return "Please select a model."

    model_path = MODEL_PATHS.get(model_name)
    if model_path is None or not os.path.exists(model_path):
        return f"Model '{model_name}' not found. Please upload the model file."

    if image is None:
        # Guard before .astype(); also avoids loading a model for nothing.
        return "Please upload an image to classify."

    model = load_model(model_path)
    if model is None:
        return f"Could not load model '{model_name}'."

    # Let Pillow infer the mode from the array shape, then normalize to RGB so
    # grayscale (H, W) and RGBA (H, W, 4) inputs are handled too.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    processed_image = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(processed_image)
        # Some models return an object exposing ``.logits``; plain torch
        # modules return the score tensor directly. Accept both.
        logits = getattr(outputs, "logits", outputs)
        predicted_idx = int(logits.argmax(dim=1).item())

    class_id = idx_to_class.get(predicted_idx)
    if class_id is None or class_id not in class_id_to_name:
        # The model's output size does not match species_list.txt.
        return f"Predicted class index {predicted_idx} is not in the species list."
    return f"Prediction: {class_id_to_name[class_id]}"


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Plant Classification")
    gr.Markdown("Select a model and upload an image to classify.")

    # Dropdown entries follow MODEL_PATHS order; the first one is preselected.
    model_choices = list(MODEL_PATHS.keys())
    default_model = model_choices[0] if model_choices else None

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=model_choices,
            label="Select Model",
            value=default_model,
        )
        image_input = gr.Image(type="numpy")

    output_text = gr.Textbox(label="Prediction")

    # Re-run the prediction whenever either the image or the model changes.
    for trigger in (image_input, model_dropdown):
        trigger.change(
            fn=predict,
            inputs=[model_dropdown, image_input],
            outputs=output_text,
        )


def _main():
    """Launch the Gradio app defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()