Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, File, UploadFile, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import JSONResponse | |
| from fastapi.routing import APIRouter | |
| from typing import List | |
| import base64 | |
| import gdown | |
| import io | |
| import os | |
| import pickle | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from io import BytesIO | |
# Palette of RGB triples used to tint object masks; an object's color is chosen
# by its position among the valid objects of an image.
segmentationColors = [
    (255, 0, 0),      # red
    (0, 255, 0),      # lime
    (0, 0, 255),      # blue
    (255, 255, 0),    # yellow
    (255, 0, 255),    # magenta
    (0, 255, 255),    # cyan
    (255, 165, 0),    # orange
    (128, 0, 128),    # purple
    (255, 192, 203),  # pink
    (50, 205, 50),    # lime green
    (0, 128, 128),    # teal
    (139, 69, 19),    # saddle brown
]
# Load the pickled dataset from DATA_PATH, downloading it from the Google Drive
# share link in DATA_URL first if the file is not present yet.
data_path = os.getenv('DATA_PATH')
if not data_path:
    # Without this check os.path.exists(None) fails with an opaque TypeError.
    raise RuntimeError("DATA_PATH environment variable is not set.")
if not os.path.exists(data_path):
    url = os.getenv('DATA_URL')
    if not url:
        raise RuntimeError(
            f"Data file {data_path!r} does not exist and DATA_URL is not set."
        )
    # Share links look like https://drive.google.com/file/d/<id>/view?...;
    # take the segment after "/d/" (ignoring any query string). Fall back to
    # the second-to-last path segment for other link shapes.
    url_parts = url.split('?', 1)[0].split('/')
    if 'd' in url_parts[:-1]:
        file_id = url_parts[url_parts.index('d') + 1]
    else:
        file_id = url.split('/')[-2]
    direct_link = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(direct_link, data_path, quiet=False)
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# point DATA_PATH / DATA_URL at trusted data.
try:
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
except Exception as e:
    # Also covers a failed download (file still missing -> FileNotFoundError).
    raise RuntimeError(f"Failed to load data from {data_path}: {e}") from e
# The ASGI application that every route and mount in this file attaches to.
app = FastAPI()

# CORS: accept cross-origin requests with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# effectively lets any site make credentialed requests; narrow the origin list
# if the API ever handles anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins or specify your frontend's domain
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Router holding the API endpoints; it is included under the /api prefix.
router = APIRouter()
def overlay_mask(base_image, mask_image, color_idx):
    """Blend one palette color into the pixels of *base_image* selected by *mask_image*.

    Selected pixels become ``0.4 * original + 0.6 * color`` (truncated to
    uint8). Unselected pixels are unchanged.

    Args:
        base_image: PIL image or array-like image. PIL images whose mode is
            not RGB/RGBA are converted (to RGBA if they carry an alpha band,
            otherwise to RGB). 2-D arrays are replicated to three channels.
            An alpha channel, if present, is kept and is not blended.
        mask_image: array-like mask with the same height and width as
            *base_image*; any non-zero value selects the pixel. A trailing
            singleton channel axis is tolerated.
        color_idx: index into ``segmentationColors``.

    Returns:
        A new PIL image; the inputs are not modified.

    Raises:
        ValueError: if the mask size differs from the image size, or if
            *color_idx* is outside ``segmentationColors``.
    """
    # Modes such as L, P or CMYK cannot be blended with an RGB color directly
    # (numpy broadcasting fails), so normalize them first.
    if isinstance(base_image, Image.Image) and base_image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "A" in base_image.getbands() else "RGB"
        base_image = base_image.convert(target_mode)
    overlay = np.array(base_image, dtype=np.uint8)
    if overlay.ndim == 2:
        # Single-channel array: replicate it to RGB so a color can be blended in.
        overlay = np.stack([overlay] * 3, axis=-1)

    mask = np.array(mask_image).astype(bool)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    if overlay.shape[:2] != mask.shape:
        raise ValueError("Base image and mask must have the same dimensions.")
    if not (0 <= color_idx < len(segmentationColors)):
        raise ValueError(f"Color index {color_idx} is out of bounds.")

    color = np.array(segmentationColors[color_idx], dtype=np.uint8)
    # Blend only the three color channels so an alpha channel survives intact.
    overlay[mask, :3] = (overlay[mask, :3] * 0.4 + color * 0.6).astype(np.uint8)
    return Image.fromarray(overlay)
def convert_to_pil(image):
    """Return *image* as a PIL image if it is a numpy array, otherwise unchanged."""
    if not isinstance(image, np.ndarray):
        # Already a PIL image (or something else): pass it through untouched.
        return image
    return Image.fromarray(image)
async def return_thumbnails():
    """Return one PIL thumbnail (fitting within 256x256) per item in ``data``.

    Thumbnails are made from copies because ``Image.thumbnail`` resizes in
    place, so the images stored in ``data`` are left untouched.
    """
    def _shrink(entry):
        # Copy first so the dataset image itself is not resized.
        small = convert_to_pil(entry['image']).copy()
        small.thumbnail((256, 256))
        return small

    return [_shrink(entry) for entry in data]
def rgb_to_hex(rgb):
    """Format the first three components of *rgb* as a lowercase ``#rrggbb`` string."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
async def return_state_data(state):
    """Build the response payload for one image at one detail level.

    Args:
        state: dict with keys ``'image_index'`` (index into ``data``),
            ``'detail_level'`` (key into the item's ``'mask_data'``), and
            ``'object_list'`` (indices, among the valid objects, of the masks
            to overlay).

    Returns:
        dict with:
            ``'mask_overlayed_image'``: base64-encoded PNG of the image with
                the selected masks blended on top;
            ``'valid_object_color_tuples'``: ``(object_type, '#rrggbb')`` for
                each valid object, in ``mask_data`` iteration order;
            ``'invalid_objects'``: object types whose mask is flagged invalid.
    """
    image_data = data[state['image_index']]
    overlayed = convert_to_pil(image_data['image'])
    selected = set(state['object_list'])
    valid_object_color_tuples = []
    invalid_objects = []

    mask_data = image_data['mask_data'].get(state['detail_level'], {})
    for object_type, mask_info in mask_data.items():
        if not mask_info['valid']:
            invalid_objects.append(object_type)
            continue
        # An object's index is its position among the valid objects seen so far.
        idx = len(valid_object_color_tuples)
        # Reuse the palette cyclically; previously more than
        # len(segmentationColors) valid objects raised an IndexError.
        color_idx = idx % len(segmentationColors)
        if idx in selected:
            overlayed = overlay_mask(overlayed, mask_info['mask'], color_idx)
        valid_object_color_tuples.append(
            (object_type, rgb_to_hex(segmentationColors[color_idx]))
        )

    buffer = BytesIO()
    overlayed.save(buffer, format="PNG")
    return {
        'mask_overlayed_image': base64.b64encode(buffer.getvalue()).decode("utf-8"),
        'valid_object_color_tuples': valid_object_color_tuples,
        'invalid_objects': invalid_objects,
    }
async def return_thumbnails_endpoint():
    """Return every dataset thumbnail as a base64-encoded PNG.

    Response body: ``{"thumbnails": [<base64 str>, ...]}``, in dataset order.

    NOTE(review): no route decorator (e.g. ``@router.get(...)``) is visible on
    this function, so as shown it is not registered on ``router``. Confirm.
    """
    def _png_base64(img):
        with BytesIO() as buf:
            img.save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")

    thumbs = await return_thumbnails()
    return JSONResponse(content={"thumbnails": [_png_base64(t) for t in thumbs]})
async def return_state_data_endpoint(
    image_index: int = Query(...),
    detail_level: int = Query(...),
    object_list: str = Query(...)
):
    """Return the overlay payload for one image (see ``return_state_data``).

    Query parameters:
        image_index: index into the dataset. Negative values index from the
            end; out-of-range values give HTTP 404.
        detail_level: key selecting which mask set to use.
        object_list: comma-separated indices of valid objects to overlay, or
            the literal ``'None'`` for no overlays. Empty items are ignored
            (``""`` and ``"1,2,"`` are accepted). Non-integer items give
            HTTP 422.

    NOTE(review): no route decorator (e.g. ``@router.get(...)``) is visible on
    this function, so as shown it is not registered on ``router``. Confirm.
    """
    from fastapi import HTTPException  # local import keeps this fix self-contained

    # Mirror Python indexing (negatives allowed) but report a bad index as a
    # 404 instead of letting an IndexError surface as a 500.
    if not -len(data) <= image_index < len(data):
        raise HTTPException(
            status_code=404, detail=f"image_index {image_index} is out of range."
        )

    if object_list == 'None':
        selected = []
    else:
        tokens = [tok for tok in object_list.split(",") if tok.strip()]
        try:
            selected = [int(tok) for tok in tokens]
        except ValueError:
            raise HTTPException(
                status_code=422,
                detail=(
                    "object_list must be 'None' or comma-separated integers, "
                    f"got {object_list!r}."
                ),
            ) from None

    state = {
        "image_index": image_index,
        "detail_level": detail_level,
        "object_list": selected,
    }
    return await return_state_data(state)
# Expose the API endpoints under /api. Registration order matters: this must
# come before the catch-all "/" mount below, which would otherwise shadow it.
app.include_router(router, prefix="/api")

# Serve the built React frontend from the site root, if the build exists.
frontend_path = "/app/frontend/build"
if not os.path.exists(frontend_path):
    print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
else:
    app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")