Spaces:
Runtime error
Runtime error
File size: 4,975 Bytes
dba440b b61399b dba440b b61399b dba440b b61399b dba440b b61399b 7cce1f0 b61399b 7f82b58 b61399b dba440b 42a8ae5 b61399b f446994 b61399b dba440b b61399b dba440b b61399b dba440b b61399b dba440b b61399b dba440b b61399b dba440b 42a8ae5 a336494 b61399b dba440b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import base64
import io
import os
import pickle
import time
from io import BytesIO
from typing import List

import gdown
import numpy as np
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from PIL import Image
# Overlay palette, indexed by an object's position among the valid objects of
# the current image. Each entry is an (R, G, B) tuple with 0-255 components.
segmentationColors = [
    (255, 0, 0),      # red
    (0, 255, 0),      # lime
    (0, 0, 255),      # blue
    (255, 255, 0),    # yellow
    (255, 0, 255),    # magenta
    (0, 255, 255),    # cyan
    (255, 165, 0),    # orange
    (128, 0, 128),    # purple
    (255, 192, 203),  # pink
    (50, 205, 50),    # lime green
    (0, 128, 128),    # teal
    (139, 69, 19),    # saddle brown
]
# Load the pickled dataset from DATA_PATH, downloading it from the Google Drive
# share link in DATA_URL first if the file is not present yet.
data_path = os.getenv('DATA_PATH')
if not data_path:
    raise RuntimeError("DATA_PATH environment variable is not set.")
if not os.path.exists(data_path):
    url = os.getenv('DATA_URL')
    if not url:
        raise RuntimeError(
            f"Data file {data_path} does not exist and the DATA_URL "
            "environment variable is not set."
        )
    # Drive share links look like .../file/d/<id>/view?...; take the segment
    # right after "d" when present (this also handles links without the
    # trailing "/view"), otherwise fall back to the second-to-last segment.
    url_parts = url.split('/')
    if 'd' in url_parts[:-1]:
        file_id = url_parts[url_parts.index('d') + 1]
    else:
        file_id = url_parts[-2]
    direct_link = f"https://drive.google.com/uc?id={file_id}"
    # Make sure the destination directory exists before downloading into it.
    data_dir = os.path.dirname(data_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    gdown.download(direct_link, data_path, quiet=False)
try:
    with open(data_path, 'rb') as f:
        # SECURITY(review): pickle.load can execute arbitrary code; DATA_URL
        # must point only at a trusted file.
        data = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load data from {data_path}: {e}") from e
# CORS configuration for the application.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive (Starlette reflects the caller's origin in that case) — confirm
# this is intended, or list the frontend's domain explicitly.
_cors_settings = {
    "allow_origins": ["*"],  # Allow all origins or specify your frontend's domain
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_settings)

# Router that collects the API endpoints; it is included with the "/api" prefix.
router = APIRouter()
def overlay_mask(base_image, mask_image, color_idx, alpha=0.6):
    """Blend a palette color into ``base_image`` wherever ``mask_image`` is set.

    Args:
        base_image: PIL image or array-like image. PIL images in modes other
            than RGB/RGBA (e.g. palette "P" or grayscale "L") are converted to
            RGB first, so the blend works on real color values.
        mask_image: Array-like mask with the same height and width as
            ``base_image``; every truthy pixel is overlaid.
        color_idx: Index into ``segmentationColors``.
        alpha: Weight of the overlay color in the blend (the original pixel
            keeps ``1 - alpha``). Defaults to 0.6.

    Returns:
        A new PIL ``Image`` with the overlay applied.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1], the mask shape does not
            match the image, or ``color_idx`` is out of range.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}.")
    # Palette/grayscale PIL images become 2-D arrays of indices/intensities,
    # which cannot be blended with an RGB color; convert them to RGB first.
    if isinstance(base_image, Image.Image) and base_image.mode not in ("RGB", "RGBA"):
        base_image = base_image.convert("RGB")
    overlay = np.array(base_image, dtype=np.uint8)
    if overlay.ndim == 2:
        # Single-channel array input: replicate to three channels so the
        # 3-component color broadcasts against the masked pixels.
        overlay = np.stack([overlay] * 3, axis=-1)
    mask = np.array(mask_image).astype(bool)
    if overlay.shape[:2] != mask.shape:
        raise ValueError("Base image and mask must have the same dimensions.")
    if not (0 <= color_idx < len(segmentationColors)):
        raise ValueError(f"Color index {color_idx} is out of bounds.")
    color = np.array(segmentationColors[color_idx], dtype=np.uint8)
    # Blend only the RGB channels. `rgb` is a view into `overlay`, so writes
    # land in place and any alpha channel is left untouched.
    rgb = overlay[..., :3]
    rgb[mask] = (rgb[mask] * (1.0 - alpha) + color * alpha).astype(np.uint8)
    return Image.fromarray(overlay)
def convert_to_pil(image):
    """Return ``image`` as a PIL image.

    NumPy arrays are wrapped with ``Image.fromarray``; any other value (for
    example an image that is already a PIL image) is returned unchanged.
    """
    if not isinstance(image, np.ndarray):
        return image
    return Image.fromarray(image)
async def return_thumbnails():
    """Build one thumbnail (bounded to 256x256) per entry of ``data``, in order.

    Each source image is copied before shrinking, so the images stored in
    ``data`` are not modified.
    """
    def _shrunken_copy(entry):
        # Copy first: Image.thumbnail resizes in place.
        thumb = convert_to_pil(entry['image']).copy()
        thumb.thumbnail((256, 256))
        return thumb

    return [_shrunken_copy(entry) for entry in data]
def rgb_to_hex(rgb):
    """Format the first three components of ``rgb`` as a lowercase ``#rrggbb`` string."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
async def return_state_data(state):
    """Render one dataset image with the selected object masks overlaid.

    Args:
        state: dict with keys
            - 'image_index': index into ``data``;
            - 'detail_level': key into the image's 'mask_data' mapping
              (an unknown level yields no objects);
            - 'object_list': indices (positions among this image's valid
              objects) of the masks to draw.

    Returns:
        dict with
            - 'mask_overlayed_image': base64-encoded PNG of the rendered image;
            - 'valid_object_color_tuples': list of (object_type, "#rrggbb"),
              one per valid object, in mask_data order;
            - 'invalid_objects': object types whose mask is flagged invalid.
    """
    image_data = data[state['image_index']]
    rendered = convert_to_pil(image_data['image'])
    selected = set(state['object_list'])
    valid_object_color_tuples = []
    invalid_objects = []

    mask_data = image_data['mask_data'].get(state['detail_level'], {})
    for object_type, mask_info in mask_data.items():
        if not mask_info['valid']:
            invalid_objects.append(object_type)
            continue
        # Object indices count valid objects only; this is the index the
        # client sends back in object_list.
        idx = len(valid_object_color_tuples)
        # Cycle through the palette so images with more valid objects than
        # colors do not fail with an out-of-range color index.
        color_idx = idx % len(segmentationColors)
        if idx in selected:
            rendered = overlay_mask(rendered, mask_info['mask'], color_idx)
        valid_object_color_tuples.append(
            (object_type, rgb_to_hex(segmentationColors[color_idx]))
        )

    buffer = BytesIO()
    rendered.save(buffer, format="PNG")
    return {
        'mask_overlayed_image': base64.b64encode(buffer.getvalue()).decode("utf-8"),
        'valid_object_color_tuples': valid_object_color_tuples,
        'invalid_objects': invalid_objects,
    }
@router.get("/return_thumbnails")
async def return_thumbnails_endpoint():
    """Return every dataset thumbnail as a base64-encoded PNG string.

    Response body: ``{"thumbnails": [<base64 PNG>, ...]}``, in dataset order.
    """
    def _as_base64_png(img):
        # Serialize one PIL image to PNG in memory, then base64-encode it.
        out = BytesIO()
        img.save(out, format="PNG")
        return base64.b64encode(out.getvalue()).decode("utf-8")

    thumbs = await return_thumbnails()
    payload = {"thumbnails": [_as_base64_png(t) for t in thumbs]}
    return JSONResponse(content=payload)
@router.get("/return_state_data")
async def return_state_data_endpoint(
    image_index: int = Query(...),
    detail_level: int = Query(...),
    object_list: str = Query(...)
):
    """Render an image with selected object masks; see ``return_state_data``.

    Query parameters:
        image_index: index of the image in the loaded dataset.
        detail_level: key selecting which set of masks to use.
        object_list: comma-separated object indices to overlay, or the literal
            string 'None' (or an empty string) for no overlays.

    Raises:
        HTTPException(404): ``image_index`` is outside the dataset.
        HTTPException(400): ``object_list`` contains a non-integer entry.
    """
    if not 0 <= image_index < len(data):
        raise HTTPException(
            status_code=404,
            detail=f"image_index {image_index} is out of range (0..{len(data) - 1}).",
        )
    tokens = [] if object_list == 'None' else [tok.strip() for tok in object_list.split(",")]
    try:
        # Skip empty tokens so "" or "1,,2" do not crash int().
        selected_objects = [int(tok) for tok in tokens if tok]
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"object_list must be 'None' or comma-separated integers, got {object_list!r}.",
        ) from None
    state = {
        "image_index": image_index,
        "detail_level": detail_level,
        "object_list": selected_objects,
    }
    return await return_state_data(state)
# Register the API routes under /api before mounting the frontend at "/", so
# /api/... requests reach the router rather than the static-file mount.
app.include_router(router, prefix="/api")

# Serve the built React frontend from the site root when the build exists;
# otherwise only the API is available.
frontend_path = "/app/frontend/build"
if not os.path.exists(frontend_path):
    print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
else:
    app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")

if __name__ == "__main__":
    # Direct `python <this file>` launch: serve on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
|