# Commit f446994 by mhamzaerol: "thumbnail issue"
from fastapi import FastAPI, File, UploadFile, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi.routing import APIRouter
from typing import List
import base64
import gdown
import io
import os
import pickle
import time
import numpy as np
from PIL import Image
from io import BytesIO
# Palette used to tint object masks. An object's colour is looked up by its
# position among the valid objects of the current image; each entry is an
# (R, G, B) tuple with channels in 0-255.
segmentationColors = [
    (255, 0, 0),      # red
    (0, 255, 0),      # green
    (0, 0, 255),      # blue
    (255, 255, 0),    # yellow
    (255, 0, 255),    # magenta
    (0, 255, 255),    # cyan
    (255, 165, 0),    # orange
    (128, 0, 128),    # purple
    (255, 192, 203),  # pink
    (50, 205, 50),    # lime green
    (0, 128, 128),    # teal
    (139, 69, 19),    # saddle brown
]
# Load the pickled dataset at import time. If the file at DATA_PATH is missing
# it is first downloaded from the Google Drive share link in DATA_URL.
data_path = os.getenv('DATA_PATH')
if not data_path:
    raise RuntimeError("DATA_PATH environment variable is not set.")
if not os.path.exists(data_path):
    url = os.getenv('DATA_URL')
    if not url:
        raise RuntimeError(
            f"Data file '{data_path}' does not exist and DATA_URL is not set."
        )
    # Assumes a Drive share link shaped like .../d/<file_id>/view, i.e. the
    # file id is the second-to-last path segment -- TODO confirm link format.
    file_id = url.split('/')[-2]
    direct_link = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(direct_link, data_path, quiet=False)
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only point DATA_PATH / DATA_URL at a trusted source.
# A failed download leaves no file, which surfaces here as a load error.
try:
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load data from {data_path}: {e}") from e
app = FastAPI()

# Accept cross-origin requests from anywhere, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; narrow allow_origins to the frontend's
# domain if authentication is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # or list the frontend's domain(s) explicitly
    allow_headers=["*"],
    allow_methods=["*"],
)

# Router that collects the JSON endpoints; it is mounted under the /api prefix.
router = APIRouter()
def overlay_mask(base_image, mask_image, color_idx, alpha=0.6):
    """Tint the pixels selected by a mask with one palette colour.

    Args:
        base_image: image to draw on (PIL image or array-like). A 2-D
            (grayscale) input is promoted to 3 channels so it can be tinted.
            For inputs with more than 3 channels only the first three are
            blended, so an alpha channel is left untouched.
        mask_image: array-like mask with the same height and width as
            ``base_image``; every non-zero value marks a pixel to tint.
        color_idx: index into ``segmentationColors``.
        alpha: weight of the palette colour in the blend. The original pixel
            keeps weight ``1 - alpha``. Defaults to 0.6.

    Returns:
        A new PIL image. ``base_image`` itself is not modified.

    Raises:
        ValueError: if the mask shape differs from the image's height/width,
            or if ``color_idx`` is outside ``segmentationColors``.
    """
    overlay = np.array(base_image, dtype=np.uint8)
    if overlay.ndim == 2:
        # Grayscale: replicate into three channels so a colour can be applied
        # (a 2-D overlay cannot be blended with an (R, G, B) colour).
        overlay = np.stack([overlay] * 3, axis=-1)
    mask = np.array(mask_image).astype(bool)
    if overlay.shape[:2] != mask.shape:
        raise ValueError("Base image and mask must have the same dimensions.")
    if not (0 <= color_idx < len(segmentationColors)):
        raise ValueError(f"Color index {color_idx} is out of bounds.")
    color = np.array(segmentationColors[color_idx], dtype=np.float64)
    # Blend only the first three channels and round (rather than truncate) so
    # the result is not biased one step darker.
    blended = overlay[mask, :3].astype(np.float64) * (1.0 - alpha) + color * alpha
    overlay[mask, :3] = np.rint(blended).astype(np.uint8)
    return Image.fromarray(overlay)
def convert_to_pil(image):
    """Return ``image`` as a PIL image.

    NumPy arrays are converted with ``Image.fromarray``. Any other object is
    returned unchanged, on the assumption that it is already a PIL image.
    """
    if not isinstance(image, np.ndarray):
        return image
    return Image.fromarray(image)
async def return_thumbnails():
    """Return one thumbnail (at most 256x256) per entry in ``data``, in order.

    PIL's ``Image.thumbnail`` resizes in place, so each image is copied first
    to leave the dataset's images untouched.
    """
    def _shrink(entry):
        # Copy before resizing: thumbnail() mutates the image it is called on.
        thumb = convert_to_pil(entry['image']).copy()
        thumb.thumbnail((256, 256))
        return thumb

    return [_shrink(entry) for entry in data]
def rgb_to_hex(rgb):
    """Format the first three entries of ``rgb`` as a ``#rrggbb`` hex string."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
async def return_state_data(state):
    """Render one image with the selected object masks overlaid.

    Args:
        state: dict with keys ``image_index`` (index into ``data``),
            ``detail_level`` (key into the image's ``mask_data``) and
            ``object_list`` (numbers of the valid objects to highlight).

    Returns:
        dict with:
            ``mask_overlayed_image``: base64-encoded PNG of the image with
                the selected masks tinted.
            ``valid_object_color_tuples``: ``(object_type, "#rrggbb")`` for
                every valid object, in ``mask_data`` iteration order.
            ``invalid_objects``: object types whose mask is flagged invalid.
    """
    image_data = data[state['image_index']]
    overlaid = convert_to_pil(image_data['image'])
    selected = set(state['object_list'])
    valid_object_color_tuples = []
    invalid_objects = []

    mask_data = image_data['mask_data'].get(state['detail_level'], {})
    for object_type, mask_info in mask_data.items():
        if not mask_info['valid']:
            invalid_objects.append(object_type)
            continue
        # Valid objects are numbered 0, 1, 2, ... in iteration order, and
        # object_list is matched against this number.
        idx = len(valid_object_color_tuples)
        # Cycle through the palette. Otherwise an image with more valid
        # objects than colours raises IndexError/ValueError.
        color_idx = idx % len(segmentationColors)
        if idx in selected:
            overlaid = overlay_mask(overlaid, mask_info['mask'], color_idx)
        valid_object_color_tuples.append(
            (object_type, rgb_to_hex(segmentationColors[color_idx]))
        )

    buffer = BytesIO()
    overlaid.save(buffer, format="PNG")
    return {
        'mask_overlayed_image': base64.b64encode(buffer.getvalue()).decode("utf-8"),
        'valid_object_color_tuples': valid_object_color_tuples,
        'invalid_objects': invalid_objects,
    }
@router.get("/return_thumbnails")
async def return_thumbnails_endpoint():
    """GET /api/return_thumbnails: every thumbnail as a base64-encoded PNG.

    Responds with JSON ``{"thumbnails": [<base64 PNG>, ...]}`` in dataset order.
    """
    def _png_base64(img):
        out = BytesIO()
        img.save(out, format="PNG")
        return base64.b64encode(out.getvalue()).decode("utf-8")

    thumbs = await return_thumbnails()
    payload = {"thumbnails": [_png_base64(t) for t in thumbs]}
    return JSONResponse(content=payload)
@router.get("/return_state_data")
async def return_state_data_endpoint(
    image_index: int = Query(...),
    detail_level: int = Query(...),
    object_list: str = Query(...)
):
    """GET /api/return_state_data: overlaid image and object legend.

    Query parameters:
        image_index: 0-based position of the image in the dataset.
        detail_level: key selecting which set of masks to use.
        object_list: comma-separated numbers of the valid objects to
            highlight. ``None`` or an empty string means no highlights;
            empty items (e.g. a trailing comma) are ignored.

    Raises:
        HTTPException: 404 if ``image_index`` is outside the dataset,
            400 if ``object_list`` contains a non-integer item.
    """
    from fastapi import HTTPException

    # Reject out-of-range indices up front. Otherwise an IndexError becomes
    # an HTTP 500, and negative values silently wrap around to the last images.
    if not 0 <= image_index < len(data):
        raise HTTPException(
            status_code=404,
            detail=f"image_index {image_index} is out of range.",
        )

    if object_list.strip() in ('', 'None'):
        selected = []
    else:
        try:
            selected = [int(tok) for tok in object_list.split(",") if tok.strip()]
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"object_list must be comma-separated integers, got {object_list!r}.",
            ) from None

    state = {
        "image_index": image_index,
        "detail_level": detail_level,
        "object_list": selected,
    }
    return await return_state_data(state)
# Expose the JSON endpoints under /api (e.g. /api/return_thumbnails).
app.include_router(router, prefix="/api")

# Serve the built React frontend at "/" when the build output exists.
# The catch-all "/" mount is registered after the /api router; keep that
# order so it does not shadow the API routes.
frontend_path = "/app/frontend/build"
if not os.path.exists(frontend_path):
    print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
else:
    app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
if __name__ == "__main__":
    # Running the file directly starts uvicorn on all interfaces, port 7860.
    from uvicorn import run
    run(app, host="0.0.0.0", port=7860)