mhamzaerol commited on
Commit
dba440b
·
1 Parent(s): 42a8ae5

router fix

Browse files
Files changed (1) hide show
  1. backend/app.py +24 -62
backend/app.py CHANGED
@@ -1,19 +1,18 @@
1
- from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.staticfiles import StaticFiles
4
- from PIL import Image
 
 
 
 
5
  import io
6
  import os
 
7
  import time
8
  import numpy as np
9
- import pickle
10
- import base64
11
  from io import BytesIO
12
- from fastapi.responses import JSONResponse
13
- from fastapi import Query
14
- from typing import List
15
- import gdown
16
-
17
 
18
  segmentationColors = [
19
  (255, 0, 0),
@@ -32,7 +31,6 @@ segmentationColors = [
32
 
33
  data_path = os.getenv('DATA_PATH')
34
  if not os.path.exists(data_path):
35
- # raise FileNotFoundError(f"The data file at {data_path} was not found.")
36
  url = os.getenv('DATA_URL')
37
  file_id = url.split('/')[-2]
38
  direct_link = f"https://drive.google.com/uc?id={file_id}"
@@ -54,59 +52,33 @@ app.add_middleware(
54
  allow_headers=["*"],
55
  )
56
 
57
- # app.mount("/", StaticFiles(directory="/app/frontend/build", html=True), name="frontend")
58
- # Serve the React frontend
59
  frontend_path = "/app/frontend/build"
60
  if os.path.exists(frontend_path):
61
  app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
62
  else:
63
  print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
64
 
65
- api = FastAPI()
 
66
 
67
  def overlay_mask(base_image, mask_image, color_idx):
68
- """
69
- Given the base_image and the 0/1 mask, overlay the mask with the color indexed by the color_idx.
70
- """
71
- # Convert inputs to NumPy arrays
72
  overlay = np.array(base_image, dtype=np.uint8)
73
  mask = np.array(mask_image).astype(bool)
74
-
75
  if overlay.shape[:2] != mask.shape:
76
  raise ValueError("Base image and mask must have the same dimensions.")
77
-
78
- # Ensure color index is valid
79
  if not (0 <= color_idx < len(segmentationColors)):
80
  raise ValueError(f"Color index {color_idx} is out of bounds.")
81
-
82
- # Retrieve color
83
  color = np.array(segmentationColors[color_idx], dtype=np.uint8)
84
-
85
- # Print debugging info (optional)
86
- print(f'Overlay shape: {overlay.shape}')
87
- print(f'Mask shape: {mask.shape}')
88
- print(f'Color idx: {color_idx}')
89
- print(f'Color: {color}')
90
-
91
- # Apply color blending
92
  overlay[mask] = (overlay[mask] * 0.4 + color * 0.6).astype(np.uint8)
93
-
94
- # Convert back to Image
95
  return Image.fromarray(overlay)
96
 
97
-
98
  def convert_to_pil(image):
99
- """
100
- Ensure the image is a PIL Image.
101
- """
102
  if isinstance(image, np.ndarray):
103
  return Image.fromarray(image)
104
  return image
105
 
106
  async def return_thumbnails():
107
- """
108
- Return a list of thumbnail images.
109
- """
110
  thumbnails = []
111
  for item in data:
112
  pil_image = convert_to_pil(item['image'])
@@ -114,67 +86,56 @@ async def return_thumbnails():
114
  return thumbnails
115
 
116
  def rgb_to_hex(rgb):
117
- """Convert an RGB tuple to a HEX string."""
118
  return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2])
119
 
120
  async def return_state_data(state):
121
- print(state)
122
- """
123
- Return state-specific data including overlays and object validity.
124
- """
125
  image_data = data[state['image_index']]
126
  base_image = convert_to_pil(image_data['image'])
127
-
128
  response = {
129
  'mask_overlayed_image': base_image,
130
  'valid_object_color_tuples': [],
131
  'invalid_objects': []
132
  }
133
-
134
  mask_data = image_data['mask_data'].get(state['detail_level'], {})
135
- # print(mask_data)
136
-
137
  for object_type, mask_info in mask_data.items():
138
  if mask_info['valid']:
139
  idx = len(response['valid_object_color_tuples'])
140
  if idx in state['object_list']:
141
- print(f"Overlaying mask for {object_type} with color {segmentationColors[idx]} as {idx} is in {state['object_list']}")
142
- response['mask_overlayed_image'] = overlay_mask(response['mask_overlayed_image'], mask_info['mask'], idx)
143
- # append the color at idx in hex
 
144
  color = segmentationColors[idx]
145
  response['valid_object_color_tuples'].append((object_type, rgb_to_hex(color)))
146
  else:
147
  response['invalid_objects'].append(object_type)
148
-
149
  buffer = BytesIO()
150
- response['mask_overlayed_image'].save(buffer, format="PNG") # Save image to a buffer in PNG format
151
- base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8") # Encode to Base64
152
  response['mask_overlayed_image'] = base64_str
153
  return response
154
 
155
- @api.get("/return_thumbnails")
156
  async def return_thumbnails_endpoint():
157
  thumbnails = await return_thumbnails()
158
  encoded_images = []
159
  for thumbnail in thumbnails:
160
  buffer = BytesIO()
161
- thumbnail.save(buffer, format="PNG") # Save image to a buffer in PNG format
162
- base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8") # Encode to Base64
163
  encoded_images.append(base64_str)
164
  return JSONResponse(content={"thumbnails": encoded_images})
165
 
166
- @api.get("/return_state_data")
167
  async def return_state_data_endpoint(
168
  image_index: int = Query(...),
169
  detail_level: int = Query(...),
170
  object_list: str = Query(...)
171
  ):
172
- print(object_list)
173
  if object_list == 'None':
174
  object_list = []
175
  else:
176
  object_list = [int(x) for x in object_list.split(",")]
177
- print(object_list)
178
  state = {
179
  "image_index": image_index,
180
  "detail_level": detail_level,
@@ -183,8 +144,9 @@ async def return_state_data_endpoint(
183
  response = await return_state_data(state)
184
  return response
185
 
186
- app.mount("/api", api)
 
187
 
188
  if __name__ == "__main__":
189
  import uvicorn
190
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, File, UploadFile, Query
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.routing import APIRouter
6
+ from typing import List
7
+ import base64
8
+ import gdown
9
  import io
10
  import os
11
+ import pickle
12
  import time
13
  import numpy as np
14
+ from PIL import Image
 
15
  from io import BytesIO
 
 
 
 
 
16
 
17
  segmentationColors = [
18
  (255, 0, 0),
 
31
 
32
  data_path = os.getenv('DATA_PATH')
33
  if not os.path.exists(data_path):
 
34
  url = os.getenv('DATA_URL')
35
  file_id = url.split('/')[-2]
36
  direct_link = f"https://drive.google.com/uc?id={file_id}"
 
52
  allow_headers=["*"],
53
  )
54
 
55
+ # Serve the React frontend if available
 
56
  frontend_path = "/app/frontend/build"
57
  if os.path.exists(frontend_path):
58
  app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
59
  else:
60
  print(f"Warning: Frontend build directory '{frontend_path}' does not exist.")
61
 
62
+ # Create a router for the API endpoints
63
+ router = APIRouter()
64
 
65
  def overlay_mask(base_image, mask_image, color_idx):
 
 
 
 
66
  overlay = np.array(base_image, dtype=np.uint8)
67
  mask = np.array(mask_image).astype(bool)
 
68
  if overlay.shape[:2] != mask.shape:
69
  raise ValueError("Base image and mask must have the same dimensions.")
 
 
70
  if not (0 <= color_idx < len(segmentationColors)):
71
  raise ValueError(f"Color index {color_idx} is out of bounds.")
 
 
72
  color = np.array(segmentationColors[color_idx], dtype=np.uint8)
 
 
 
 
 
 
 
 
73
  overlay[mask] = (overlay[mask] * 0.4 + color * 0.6).astype(np.uint8)
 
 
74
  return Image.fromarray(overlay)
75
 
 
76
  def convert_to_pil(image):
 
 
 
77
  if isinstance(image, np.ndarray):
78
  return Image.fromarray(image)
79
  return image
80
 
81
  async def return_thumbnails():
 
 
 
82
  thumbnails = []
83
  for item in data:
84
  pil_image = convert_to_pil(item['image'])
 
86
  return thumbnails
87
 
88
  def rgb_to_hex(rgb):
 
89
  return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2])
90
 
91
  async def return_state_data(state):
 
 
 
 
92
  image_data = data[state['image_index']]
93
  base_image = convert_to_pil(image_data['image'])
 
94
  response = {
95
  'mask_overlayed_image': base_image,
96
  'valid_object_color_tuples': [],
97
  'invalid_objects': []
98
  }
 
99
  mask_data = image_data['mask_data'].get(state['detail_level'], {})
 
 
100
  for object_type, mask_info in mask_data.items():
101
  if mask_info['valid']:
102
  idx = len(response['valid_object_color_tuples'])
103
  if idx in state['object_list']:
104
+ response['mask_overlayed_image'] = overlay_mask(
105
+ response['mask_overlayed_image'],
106
+ mask_info['mask'], idx
107
+ )
108
  color = segmentationColors[idx]
109
  response['valid_object_color_tuples'].append((object_type, rgb_to_hex(color)))
110
  else:
111
  response['invalid_objects'].append(object_type)
 
112
  buffer = BytesIO()
113
+ response['mask_overlayed_image'].save(buffer, format="PNG")
114
+ base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
115
  response['mask_overlayed_image'] = base64_str
116
  return response
117
 
118
+ @router.get("/return_thumbnails")
119
  async def return_thumbnails_endpoint():
120
  thumbnails = await return_thumbnails()
121
  encoded_images = []
122
  for thumbnail in thumbnails:
123
  buffer = BytesIO()
124
+ thumbnail.save(buffer, format="PNG")
125
+ base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
126
  encoded_images.append(base64_str)
127
  return JSONResponse(content={"thumbnails": encoded_images})
128
 
129
+ @router.get("/return_state_data")
130
  async def return_state_data_endpoint(
131
  image_index: int = Query(...),
132
  detail_level: int = Query(...),
133
  object_list: str = Query(...)
134
  ):
 
135
  if object_list == 'None':
136
  object_list = []
137
  else:
138
  object_list = [int(x) for x in object_list.split(",")]
 
139
  state = {
140
  "image_index": image_index,
141
  "detail_level": detail_level,
 
144
  response = await return_state_data(state)
145
  return response
146
 
147
+ # Include the router with a prefix, making endpoints accessible under /api
148
+ app.include_router(router, prefix="/api")
149
 
150
  if __name__ == "__main__":
151
  import uvicorn
152
+ uvicorn.run(app, host="0.0.0.0", port=7860)