# Author: PsychicFireSong
# Commit b404ed5: "Initial upload of Gradio app and baseline model"
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
# --- Model Loading ---
def get_model_paths(model_dir="models"):
    """Return a mapping of model names to checkpoint paths found in *model_dir*.

    Every entry ending in ``.pth`` directly inside *model_dir* is included; the
    model name is the file name without its extension. Entries are sorted by
    file name so the resulting order (which callers may use for display or to
    pick a default) is the same on every run and platform.

    Args:
        model_dir: Directory to scan. Defaults to ``"models"``, resolved
            relative to the current working directory.

    Returns:
        dict mapping model name -> ``os.path.join(model_dir, file_name)``.
        Empty when *model_dir* does not exist (or is not a directory), so the
        app can still start without any checkpoints present.
    """
    if not os.path.isdir(model_dir):
        return {}
    # os.listdir returns entries in arbitrary order; sort for determinism.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".pth"))
    # You can create more descriptive names here if you want.
    return {os.path.splitext(f)[0]: os.path.join(model_dir, f) for f in model_files}
MODEL_PATHS = get_model_paths()
# Placeholder entries for models that have not been uploaded yet. They are
# appended after the discovered checkpoints; predict() reports them as not
# found until the corresponding files exist on disk.
MODEL_PATHS.update({
    "Future Model 1": "models/future_model_1.pth",
    "Future Model 2": "models/future_model_2.pth",
})
# This is a placeholder for your actual model loading logic
# You will need to replace this with the code to load your specific model architecture

# Successfully loaded models: checkpoint path -> (modification time, model).
# predict() calls load_model() on every request, so without this cache the
# network would be rebuilt and the checkpoint re-read for every image.
_MODEL_CACHE = {}


def load_model(model_path):
    """Load the classifier stored at *model_path*, or return ``None``.

    The architecture comes from ``baseline.convnext_v2_base`` (a project-local
    module) built with ``num_classes=10``; the file at *model_path* is loaded
    into it as a state dict mapped to the CPU, and the model is put in eval
    mode.

    Successful loads are cached per path and reused until the file's
    modification time changes, so repeated calls are cheap and a replaced
    checkpoint is reloaded automatically. Failures are never cached.

    Args:
        model_path: Path to a ``.pth`` file holding a state dict.

    Returns:
        The model in eval mode, or ``None`` if the file does not exist or
        anything goes wrong while loading (the reason is printed).
    """
    try:
        mtime = os.path.getmtime(model_path)
    except OSError:
        mtime = None  # treated as "file does not exist", like os.path.exists
    cached = _MODEL_CACHE.get(model_path)
    if mtime is not None and cached is not None and cached[0] == mtime:
        return cached[1]

    print(f"Loading model from: {model_path}")
    if mtime is None:
        print("Warning: Model file does not exist. Using a dummy model.")
        return None
    try:
        # NOTE: the architecture below is a guess; replace it with the class the
        # checkpoint was trained with. Imported lazily so a missing baseline.py
        # degrades to "no model" instead of breaking app start-up.
        from baseline import convnext_v2_base
        # NOTE(review): num_classes is hard-coded and must match the
        # checkpoint's classifier head (presumably also the number of rows in
        # species_list.txt) — confirm.
        model = convnext_v2_base(num_classes=10)
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code — only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:  # best-effort: the UI shows a message instead of crashing
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None
    _MODEL_CACHE[model_path] = (mtime, model)
    return model
# --- Image Preprocessing ---
# NOTE(review): this must mirror the transforms used when the checkpoint was
# trained — adjust it if the training pipeline differed.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),      # smaller edge -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(        # per-channel standardisation
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# --- Prediction ---
# Load species list: each line of species_list.txt is "<class_id>;<species_name>"
# with no header row.
species_df = pd.read_csv('species_list.txt', sep=';', header=None, names=['class_id', 'species_name'])
# Model output index -> class_id, following the row order of the file
# (read_csv assigns a 0-based RangeIndex, so enumerate yields the same keys).
idx_to_class = dict(enumerate(species_df['class_id']))
# class_id -> human-readable species name.
class_id_to_name = dict(zip(species_df['class_id'], species_df['species_name']))
def predict(model_name, image):
    """Classify *image* with the model selected as *model_name*.

    Args:
        model_name: A key of ``MODEL_PATHS`` (the dropdown value). ``None`` or
            an unknown name is reported as a missing model.
        image: The uploaded image as a numpy array (``gr.Image(type="numpy")``),
            or ``None`` when nothing is present — e.g. the model dropdown was
            changed before any upload, or the image was cleared.

    Returns:
        A status message, or ``"Prediction: <species name>"``, for the output
        textbox.
    """
    model_path = MODEL_PATHS.get(model_name)
    if model_path is None or not os.path.exists(model_path):
        return f"Model '{model_name}' not found. Please upload the model file."
    # Both change events pass the current image, which may be None; return
    # before loading a model that would have nothing to classify.
    if image is None:
        return "Please upload an image."
    model = load_model(model_path)
    if model is None:
        return f"Could not load model '{model_name}'."
    # Let Pillow infer the mode from the array shape, then convert, so
    # grayscale (H, W) and RGBA (H, W, 4) inputs also become 3-channel RGB.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(batch)
    # Hugging Face-style models return an object with .logits; plain
    # nn.Modules return the logits tensor directly. Accept either.
    logits = getattr(outputs, 'logits', outputs)
    predicted_idx = int(logits.argmax(dim=1)[0])
    class_id = idx_to_class.get(predicted_idx)
    if class_id is None:
        # The classifier head can be larger than the species list.
        return f"Prediction: class index {predicted_idx} (not in species list)"
    return f"Prediction: {class_id_to_name[class_id]}"
# --- Gradio Interface ---
# NOTE(review): the original indentation was lost, so which components sit
# inside gr.Row() is a reconstruction — confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Plant Classification")
    gr.Markdown("Select a model and upload an image to classify.")
    _model_names = list(MODEL_PATHS.keys())
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=_model_names,
            label="Select Model",
            # Preselect the first registered model, if there is one.
            value=_model_names[0] if _model_names else None,
        )
        image_input = gr.Image(type="numpy")
        output_text = gr.Textbox(label="Prediction")
    # Re-run the prediction whenever the image or the selected model changes
    # (listeners registered in the same order: image first, then dropdown).
    for _trigger in (image_input, model_dropdown):
        _trigger.change(
            fn=predict,
            inputs=[model_dropdown, image_input],
            outputs=output_text,
        )

if __name__ == "__main__":
    demo.launch()