Spaces:
Sleeping
Sleeping
| """ | |
| Gradio app for ConvNeXtV2 Image Regression Model | |
| Predicts price of installing Celebright permanent holiday lights | |
| given an image of a residential home. | |
| """ | |
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification
from peft import PeftModel
import albumentations as A
# Configuration
BASE_MODEL = "facebook/convnextv2-base-22k-384"
# Adapter files live in the same directory as this file. Resolve that
# directory from __file__ instead of using "." so the app still finds the
# adapter when it is launched from a different working directory
# (e.g. `python some/dir/app.py`).
ADAPTER_PATH = str(Path(__file__).resolve().parent)

# Target standardization constants from training: the model's raw output is
# mapped back to a price with `output * STD_PRICE + MEAN_PRICE`.
MEAN_PRICE = 2928.0898333333334
STD_PRICE = 883.9606849497703

# Image preprocessing constants: square input size, and per-channel
# normalization statistics (the standard ImageNet mean/std).
IMAGE_SIZE = 384
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
print("Loading model...")

# Create config for regression: a single output unit with one named target.
# trust_remote_code is deliberately not passed: ConvNeXtV2 is a built-in
# transformers architecture, so enabling it would only allow a (changed)
# remote repository to execute arbitrary code at load time.
model_config = AutoConfig.from_pretrained(
    BASE_MODEL,
    num_labels=1,
)
# Kept from the original setup; presumably redundant with num_labels=1 above
# on current transformers versions — NOTE(review): confirm before removing.
model_config._num_labels = 1
model_config.label2id = {"target": 0}
model_config.id2label = {0: "target"}

# Load base model. ignore_mismatched_sizes=True lets the 1-output head replace
# the checkpoint's classifier, whose shape differs. Presumably the trained
# head weights are restored by the LoRA adapter below (e.g. via PEFT
# modules_to_save) — TODO confirm against the adapter config.
base_model = AutoModelForImageClassification.from_pretrained(
    BASE_MODEL,
    config=model_config,
    ignore_mismatched_sizes=True,
)

# Load the trained LoRA adapter from ADAPTER_PATH and switch to eval mode so
# training-only layers (e.g. dropout) are disabled at inference time.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()

# Use a GPU when one is available; otherwise (e.g. free-tier hardware) run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")
# Preprocessing pipeline; it must stay identical to the one used in training.
# Step 1 resizes to the model's square input size. Step 2 divides pixel
# values by max_pixel_value, then standardizes each channel with
# IMAGE_MEAN / IMAGE_STD.
transform = A.Compose(
    transforms=[
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD, max_pixel_value=255.0),
    ]
)
def predict(image: Image.Image) -> dict:
    """
    Run inference on an image and return the predicted price.

    The image is converted to RGB and preprocessed with the module-level
    ``transform`` (resize + normalize, matching training). It is then passed
    through the LoRA-adapted ConvNeXtV2 regressor. The model's single output
    is in standardized units and is mapped back to a price with
    ``normalized * STD_PRICE + MEAN_PRICE``.

    Args:
        image: PIL Image, or None when the Gradio input is empty.

    Returns:
        ``{"error": "No image provided"}`` when ``image`` is None. Otherwise a
        dict with ``predicted_price`` (rounded to 2 decimals), ``currency``
        ("CAD") and ``normalized_output`` (raw model output, rounded to 4
        decimals).
    """
    if image is None:
        return {"error": "No image provided"}

    # Force 3 channels so grayscale / RGBA / palette images match the model input.
    img_array = np.array(image.convert("RGB"))

    # Apply the training-time preprocessing; the result is laid out (H, W, C).
    normalized_hwc = transform(image=img_array)["image"]

    # The model expects (N, C, H, W) float32. ascontiguousarray performs the
    # layout change and dtype cast in a single copy, and from_numpy then shares
    # that buffer instead of copying it again as torch.tensor would.
    chw = np.ascontiguousarray(np.transpose(normalized_hwc, (2, 0, 1)), dtype=np.float32)
    img_tensor = torch.from_numpy(chw).unsqueeze(0).to(device)

    # inference_mode is the inference-only counterpart of no_grad: no autograd
    # graph is recorded and view/version-counter tracking is skipped.
    with torch.inference_mode():
        outputs = model(pixel_values=img_tensor)
    # The head has a single output (num_labels=1), so .item() yields the scalar.
    normalized_pred = outputs.logits.item()

    # Undo the target standardization applied during training.
    predicted_price = normalized_pred * STD_PRICE + MEAN_PRICE

    return {
        "predicted_price": round(predicted_price, 2),
        "currency": "CAD",
        "normalized_output": round(normalized_pred, 4),
    }
def predict_simple(image: Image.Image) -> str:
    """
    Produce the human-readable result shown in the "Simple" tab.

    Delegates to predict(). If its result carries an "error" key, that message
    is returned as-is. Otherwise the predicted price is rendered as a dollar
    amount with thousands separators and two decimals (e.g. "$2,928.09").
    """
    outcome = predict(image)
    if "error" not in outcome:
        price = outcome["predicted_price"]
        return "$" + format(price, ",.2f")
    return outcome["error"]
# Create Gradio interface: the "Simple" tab shows a formatted price string.
demo = gr.Interface(
    fn=predict_simple,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Price"),
    description="Upload an image of a house to get a Celebright installation price prediction.",
    examples=[],  # Add example images if desired
    flagging_mode="never",
)

# Also expose the raw prediction function for API users who want JSON.
# flagging_mode="never" matches the Simple tab. Without it this tab would
# show a Flag button and could persist user-uploaded images when clicked.
demo_api = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.JSON(label="Prediction"),
    api_name="predict_json",
    flagging_mode="never",
)

# Combine both interfaces into a single tabbed app.
app = gr.TabbedInterface(
    [demo, demo_api],
    ["Simple", "JSON API"],
    title="Celebright AI Quote Tool",
)

if __name__ == "__main__":
    app.launch()