"""Gradio demo: classify a plant image with a user-selected checkpoint."""
import os

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

# --- Class labels ---
# species_list.txt: one "<class_id>;<species_name>" entry per line, no header.
# NOTE(review): assumes the model was trained with classes in this row order.
species_df = pd.read_csv(
    "species_list.txt", sep=";", header=None, names=["class_id", "species_name"]
)
# Model output index (row position) -> class id.
idx_to_class = dict(enumerate(species_df["class_id"]))
# Class id -> human-readable species name.
class_id_to_name = dict(zip(species_df["class_id"], species_df["species_name"]))
# The classifier head needs one output per listed species so every predicted
# index can be looked up in idx_to_class.
NUM_CLASSES = len(idx_to_class)


# --- Model Loading ---
def get_model_paths(model_dir="models"):
    """Return a mapping of model name -> checkpoint path for each ``.pth`` file.

    The model name is the file name without its extension. Files are listed in
    sorted order so the dropdown (and its default selection) is deterministic.
    A missing ``model_dir`` yields an empty mapping instead of failing at import.
    """
    if not os.path.isdir(model_dir):
        return {}
    return {
        os.path.splitext(file_name)[0]: os.path.join(model_dir, file_name)
        for file_name in sorted(os.listdir(model_dir))
        if file_name.endswith(".pth")
    }


MODEL_PATHS = get_model_paths()
# Placeholders for models that do not exist yet; predict() reports them as
# "not found" until the files are added.
MODEL_PATHS["Future Model 1"] = "models/future_model_1.pth"
MODEL_PATHS["Future Model 2"] = "models/future_model_2.pth"


def load_model(model_path):
    """Build the classifier and load its weights from ``model_path``.

    Returns the model in eval mode on CPU, or ``None`` if the file does not
    exist or loading fails for any reason (the error is printed).

    This is placeholder logic: replace the architecture import/constructor
    with the one your checkpoint was actually trained with.
    """
    print(f"Loading model from: {model_path}")
    if not os.path.exists(model_path):
        print("Warning: Model file does not exist. Using a dummy model.")
        return None
    try:
        # Placeholder architecture; replace with your actual model class.
        from baseline import convnext_v2_base

        # One output per species so predictions index into idx_to_class.
        model = convnext_v2_base(num_classes=NUM_CLASSES)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch versions with it).
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        # Deliberate best-effort: the UI reports a load failure instead of
        # the app crashing.
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None


# Loaded models keyed by checkpoint path, so each model is built once instead
# of on every image/dropdown change. Failed loads are not cached, so a fixed
# file is retried on the next request. Replacing an already-loaded checkpoint
# file requires restarting the app.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the cached model for ``model_path``, loading it on first use."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = load_model(model_path)
        if model is not None:
            _MODEL_CACHE[model_path] = model
    return model


# --- Image Preprocessing ---
# Adjust to match the preprocessing your model was trained with. These are the
# standard ImageNet resize/crop sizes and normalization statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# --- Prediction ---
def predict(model_name, image):
    """Classify ``image`` with the model selected as ``model_name``.

    Args:
        model_name: Key into MODEL_PATHS (the dropdown value); may be None.
        image: Numpy image array from ``gr.Image(type="numpy")``, or None when
            the image input is empty or was cleared.

    Returns:
        A message for the output textbox: the prediction, or the reason none
        could be made.
    """
    if image is None:
        # Both change events can fire with no image (e.g. switching model
        # before uploading, or clearing the image), so there is nothing to do.
        return "Please upload an image to classify."

    model_path = MODEL_PATHS.get(model_name)
    if model_path is None or not os.path.exists(model_path):
        return f"Model '{model_name}' not found. Please upload the model file."

    model = _get_model(model_path)
    if model is None:
        return f"Could not load model '{model_name}'."

    # convert("RGB") also accepts grayscale (2-D) and RGBA arrays.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        outputs = model(batch)
    # Hugging Face-style models return an object with `.logits`; plain
    # nn.Modules return the logits tensor directly.
    logits = getattr(outputs, "logits", outputs)
    predicted_idx = int(logits.argmax(dim=1).item())

    class_id = idx_to_class.get(predicted_idx)
    if class_id is None:
        return f"Prediction index {predicted_idx} is not in species_list.txt."
    return f"Prediction: {class_id_to_name[class_id]}"


# --- Gradio Interface ---
MODEL_CHOICES = list(MODEL_PATHS)

with gr.Blocks() as demo:
    gr.Markdown("# Plant Classification")
    gr.Markdown("Select a model and upload an image to classify.")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_CHOICES,
            label="Select Model",
            value=MODEL_CHOICES[0] if MODEL_CHOICES else None,
        )
        image_input = gr.Image(type="numpy")
    output_text = gr.Textbox(label="Prediction")

    # Re-run the prediction whenever either input changes.
    for trigger in (image_input, model_dropdown):
        trigger.change(
            fn=predict,
            inputs=[model_dropdown, image_input],
            outputs=output_text,
        )

if __name__ == "__main__":
    demo.launch()