import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
# --- Model Loading ---
def get_model_paths():
    """Returns a dictionary mapping model names to their file paths."""
    model_dir = "models"
    if not os.path.isdir(model_dir):
        return {}
    model_files = [f for f in os.listdir(model_dir) if f.endswith(".pth")]
    # You can create more descriptive names here if you want
    model_names = [os.path.splitext(f)[0] for f in model_files]
    return dict(zip(model_names, [os.path.join(model_dir, f) for f in model_files]))

MODEL_PATHS = get_model_paths()
# Add placeholder paths for the other two models
MODEL_PATHS["Future Model 1"] = "models/future_model_1.pth"
MODEL_PATHS["Future Model 2"] = "models/future_model_2.pth"
# This is a placeholder for your actual model-loading logic.
# Replace it with the code that builds and loads your specific model architecture.
def load_model(model_path):
    """Loads a model from the given path."""
    # Example:
    #     model = torch.load(model_path)
    #     model.eval()
    #     return model
    # For now, this returns None (a dummy model) if loading is not possible.
    print(f"Loading model from: {model_path}")
    if not os.path.exists(model_path):
        print("Warning: Model file does not exist. Using a dummy model.")
        return None
    try:
        # This is a guess; replace it with your actual model class.
        from baseline import convnext_v2_base
        model = convnext_v2_base(num_classes=10)  # Or however many classes your model predicts
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None
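# Optional (an assumption, not part of the original design): predict() below reloads the
# checkpoint from disk on every call. If that becomes slow, a cached wrapper such as this
# sketch could reuse already-loaded models:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=None)
#     def load_model_cached(model_path):
#         return load_model(model_path)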
# --- Image Preprocessing ---
# You'll need to adjust this to match the preprocessing your model expects.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
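# Note: the mean/std above are the standard ImageNet normalization statistics; keep them
# only if your models were trained with the same values.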
# --- Prediction ---
# Load the species list. This assumes each line of species_list.txt is
# "class_id;species_name" with no header row (the IDs and names below are illustrative).
#     e.g. 12345;Quercus robur L.
species_df = pd.read_csv('species_list.txt', sep=';', header=None, names=['class_id', 'species_name'])
idx_to_class = {i: row['class_id'] for i, row in species_df.iterrows()}
class_id_to_name = {row['class_id']: row['species_name'] for i, row in species_df.iterrows()}
def predict(model_name, image):
    """Makes a prediction on an image using the selected model."""
    if image is None:
        return ""
    model_path = MODEL_PATHS[model_name]
    if not os.path.exists(model_path):
        return f"Model '{model_name}' not found. Please upload the model file."
    model = load_model(model_path)
    if model is None:
        return f"Could not load model '{model_name}'."
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    processed_image = preprocess(pil_image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(processed_image)
        # Some model wrappers return an output object with a .logits attribute;
        # plain nn.Modules return the logits tensor directly.
        if hasattr(outputs, "logits"):
            outputs = outputs.logits
    _, predicted_idx = torch.max(outputs, 1)
    class_id = idx_to_class[predicted_idx.item()]
    class_name = class_id_to_name[class_id]
    return f"Prediction: {class_name}"
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Plant Classification")
    gr.Markdown("Select a model and upload an image to classify.")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PATHS.keys()),
            label="Select Model",
            value=list(MODEL_PATHS.keys())[0] if MODEL_PATHS else None
        )
        image_input = gr.Image(type="numpy")
    output_text = gr.Textbox(label="Prediction")
    image_input.change(
        fn=predict,
        inputs=[model_dropdown, image_input],
        outputs=output_text
    )
    model_dropdown.change(
        fn=predict,
        inputs=[model_dropdown, image_input],
        outputs=output_text
    )
if __name__ == "__main__":
    demo.launch()