File size: 4,975 Bytes
dba440b
b61399b
 
dba440b
 
 
 
 
b61399b
 
dba440b
b61399b
 
dba440b
b61399b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cce1f0
b61399b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f82b58
b61399b
 
 
 
 
dba440b
 
42a8ae5
b61399b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f446994
 
 
 
b61399b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba440b
 
 
 
b61399b
 
 
 
 
dba440b
 
b61399b
 
 
dba440b
b61399b
 
 
 
 
dba440b
 
b61399b
 
 
dba440b
b61399b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba440b
 
42a8ae5
a336494
 
 
 
 
 
 
b61399b
 
dba440b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from fastapi import FastAPI, File, UploadFile, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi.routing import APIRouter
from typing import List
import base64
import gdown
import io
import os
import pickle
import time
import numpy as np
from PIL import Image
from io import BytesIO

# RGB palette used to tint segmentation masks. Entries are looked up by
# position (see overlay_mask and return_state_data), so the order of this
# list determines which color each object index receives.
segmentationColors = [
    (255, 0, 0),      # red
    (0, 255, 0),      # green
    (0, 0, 255),      # blue
    (255, 255, 0),    # yellow
    (255, 0, 255),    # magenta
    (0, 255, 255),    # cyan
    (255, 165, 0),    # orange
    (128, 0, 128),    # purple
    (255, 192, 203),  # pink
    (50, 205, 50),    # lime green
    (0, 128, 128),    # teal
    (139, 69, 19)     # saddle brown
]

# Location of the pickled dataset on disk. When the file is missing it is
# downloaded once from the Google Drive link given in DATA_URL.
data_path = os.getenv('DATA_PATH')
if not data_path:
    # os.path.exists(None) would otherwise raise an opaque TypeError
    raise RuntimeError("The DATA_PATH environment variable is not set.")

if not os.path.exists(data_path):
    url = os.getenv('DATA_URL')
    if not url:
        raise RuntimeError(
            f"{data_path} does not exist and the DATA_URL environment "
            "variable is not set, so the data cannot be downloaded."
        )
    # The Drive file id is read from the second-to-last path segment of the
    # URL (presumably a share link shaped like .../file/d/<id>/view — confirm).
    file_id = url.split('/')[-2]
    direct_link = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(direct_link, data_path, quiet=False)

# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# DATA_URL / DATA_PATH must only ever point at a trusted source.
try:
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback
    raise RuntimeError(f"Failed to load data from {data_path}: {e}") from e

# ASGI application object; routes and the static frontend are attached to it
# further down in this module.
app = FastAPI()

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True is
# discouraged by the CORS specification and the FastAPI docs — if cookies or
# auth headers are ever used, list the frontend origin explicitly. Confirm
# whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins or specify your frontend's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create a router for the API endpoints
router = APIRouter()

def overlay_mask(base_image, mask_image, color_idx):
    """Tint the masked pixels of an image with one of the palette colors.

    Args:
        base_image: Image (or array-like) to draw on; it is copied into a
            uint8 array, so the argument itself is not modified.
        mask_image: Image (or array-like) whose truthy entries mark the
            pixels to tint. Its shape must equal the first two dimensions
            of the base image.
        color_idx: Position of the tint color in ``segmentationColors``.

    Returns:
        A new PIL image in which every masked pixel is a 40% / 60% blend of
        the original pixel and the palette color.

    Raises:
        ValueError: If the shapes disagree or ``color_idx`` is outside the
            palette.
    """
    canvas = np.array(base_image, dtype=np.uint8)
    selected = np.array(mask_image).astype(bool)

    canvas_hw = canvas.shape[:2]
    if canvas_hw != selected.shape:
        raise ValueError("Base image and mask must have the same dimensions.")

    in_palette = 0 <= color_idx < len(segmentationColors)
    if not in_palette:
        raise ValueError(f"Color index {color_idx} is out of bounds.")

    tint = np.array(segmentationColors[color_idx], dtype=np.uint8)
    # Blend in floating point, then truncate back to uint8
    blended = canvas[selected] * 0.4 + tint * 0.6
    canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)

def convert_to_pil(image):
    """Wrap a NumPy array in a PIL image; pass anything else through.

    Args:
        image: Either an ``np.ndarray`` or an object that is already usable
            as an image.

    Returns:
        ``Image.fromarray(image)`` for arrays, otherwise ``image`` itself.
    """
    if not isinstance(image, np.ndarray):
        return image
    return Image.fromarray(image)

async def return_thumbnails():
    """Build a thumbnail for the ``'image'`` of every entry in ``data``.

    Returns:
        A list of images, one per dataset entry and in dataset order, each
        shrunk to fit within 256x256 pixels.
    """
    def _shrink(source):
        # thumbnail() resizes in place, so work on a copy and leave the
        # dataset image at its original size
        reduced = convert_to_pil(source).copy()
        reduced.thumbnail((256, 256))
        return reduced

    return [_shrink(entry['image']) for entry in data]

def rgb_to_hex(rgb):
    """Format the first three entries of *rgb* as a ``#rrggbb`` string.

    Args:
        rgb: Indexable holding integer channel values at positions 0, 1, 2.

    Returns:
        Lowercase hex color string with two digits per channel.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"

async def return_state_data(state):
    """Assemble the overlay image and object legend for one viewer state.

    Args:
        state: Mapping with the keys ``image_index`` (position in ``data``),
            ``detail_level`` (key into that image's ``mask_data``) and
            ``object_list`` (indices, counted over the valid objects only,
            of the masks that should be drawn).

    Returns:
        dict with three keys:
        ``mask_overlayed_image`` — base64-encoded PNG of the image with the
        selected masks blended in;
        ``valid_object_color_tuples`` — ``(object_type, "#rrggbb")`` pairs for
        every valid object, in iteration order;
        ``invalid_objects`` — object types whose mask is flagged not valid.

    Raises:
        IndexError: If ``state['image_index']`` is outside ``data``.
    """
    image_data = data[state['image_index']]
    composed = convert_to_pil(image_data['image'])
    # Build the lookup once instead of scanning the list for every object
    selected = set(state['object_list'])
    palette_size = len(segmentationColors)
    legend = []
    invalid = []

    # An unknown detail level yields no masks rather than an error
    mask_data = image_data['mask_data'].get(state['detail_level'], {})
    for object_type, mask_info in mask_data.items():
        if not mask_info['valid']:
            invalid.append(object_type)
            continue
        # Valid objects are numbered in iteration order; that number is what
        # callers put into object_list.
        idx = len(legend)
        # Wrap around the palette so images with more valid objects than
        # colors reuse colors instead of raising IndexError.
        color_idx = idx % palette_size
        if idx in selected:
            composed = overlay_mask(composed, mask_info['mask'], color_idx)
        legend.append((object_type, rgb_to_hex(segmentationColors[color_idx])))

    buffer = BytesIO()
    composed.save(buffer, format="PNG")
    return {
        'mask_overlayed_image': base64.b64encode(buffer.getvalue()).decode("utf-8"),
        'valid_object_color_tuples': legend,
        'invalid_objects': invalid,
    }

@router.get("/return_thumbnails")
async def return_thumbnails_endpoint():
    """Serve every dataset thumbnail as a base64-encoded PNG string.

    Returns:
        JSON object ``{"thumbnails": [...]}`` with one string per thumbnail,
        in the order produced by ``return_thumbnails``.
    """
    def _as_base64_png(picture):
        # Serialize to PNG in memory, then make it JSON-safe via base64
        stream = BytesIO()
        picture.save(stream, format="PNG")
        return base64.b64encode(stream.getvalue()).decode("utf-8")

    pictures = await return_thumbnails()
    payload = {"thumbnails": [_as_base64_png(picture) for picture in pictures]}
    return JSONResponse(content=payload)

def _parse_object_list(raw):
    """Turn the ``object_list`` query string into a list of ints.

    ``'None'`` and blank strings mean "no objects selected"; blank tokens
    (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: If a non-blank token is not an integer literal.
    """
    if raw is None or raw.strip() in ('', 'None'):
        return []
    return [int(token) for token in raw.split(",") if token.strip()]


@router.get("/return_state_data")
async def return_state_data_endpoint(
    image_index: int = Query(...),
    detail_level: int = Query(...),
    object_list: str = Query(...)
):
    """Return the overlay image and legend for the requested viewer state.

    Query parameters:
        image_index: Position of the image in ``data``; must be within
            ``0 .. len(data) - 1``.
        detail_level: Key selecting which set of masks to use.
        object_list: ``'None'`` or a comma-separated list of object indices
            whose masks should be drawn.

    Returns:
        The dict produced by ``return_state_data``.

    Raises:
        HTTPException: 404 when ``image_index`` is out of range, 422 when
            ``object_list`` cannot be parsed.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import HTTPException

    # Reject bad indices up front: an IndexError would surface as a 500 and
    # a negative index would silently wrap around to the end of the data.
    if not 0 <= image_index < len(data):
        raise HTTPException(
            status_code=404,
            detail=f"image_index {image_index} is out of range.",
        )

    try:
        selected = _parse_object_list(object_list)
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail="object_list must be 'None' or a comma-separated list of integers.",
        ) from None

    state = {
        "image_index": image_index,
        "detail_level": detail_level,
        "object_list": selected,
    }
    return await return_state_data(state)

# Expose every route registered on the router beneath the /api prefix
app.include_router(router, prefix="/api")

# Mount the built React frontend at the site root when the build directory
# exists; otherwise only the API is served and a warning is printed
frontend_path = "/app/frontend/build"
if not os.path.exists(frontend_path):
    print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
else:
    app.mount(
        "/",
        StaticFiles(directory=frontend_path, html=True),
        name="frontend",
    )

# Running this file directly starts a uvicorn server on all interfaces
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, port=7860, host="0.0.0.0")