LiuZichen committed
Commit 08a4792 · verified · 1 parent: f1cf9b9

Update app.py

Files changed (1): app.py +34 -358
app.py CHANGED
@@ -1,39 +1,14 @@
- import subprocess
- import shlex
- subprocess.run(
-     shlex.split(
-         "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
-     )
- )
- import sys
import os
import gradio as gr
import spaces
- import tempfile
- import numpy as np
- import io
- import base64
- from gradio_client import Client, handle_file
- from huggingface_hub import snapshot_download
- from gradio_magicquillv2 import MagicQuillV2
- from fastapi import FastAPI, Request
- from fastapi.middleware.cors import CORSMiddleware
- import uvicorn
- import requests
- from PIL import Image, ImageOps
- import random
- import time
import torch
- import json
-
- # Try importing as a package (recommended)
+ from huggingface_hub import snapshot_download
from edit_space import KontextEditModel
from util import (
    load_and_preprocess_image,
    read_base64_image as read_base64_image_utils,
    create_alpha_mask,
    tensor_to_base64,
-     get_mask_bbox
)

# Initialize models
@@ -43,21 +18,9 @@ snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", lo

print("Initializing models...")
kontext_model = KontextEditModel()
-
- # Initialize SAM Client
- # Replace with your actual SAM Space ID
- sam_client = Client("LiuZichen/MagicQuillHelper")
print("Models initialized.")

- css = """
- .ms {
-     width: 60%;
-     margin: auto
- }
- """
-
- url = "http://localhost:7860"
-
+ @spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
@@ -66,10 +29,6 @@ def generate(merged_image, total_mask, original_image, add_color_image, add_edge
        raise RuntimeError("KontextEditModel not initialized")

    # Preprocess inputs
-     # utils.read_base64_image returns BytesIO, which create_alpha_mask accepts (via Image.open)
-     # load_and_preprocess_image accepts path, so we might need to check if it accepts file-like object.
-     # utils.load_and_preprocess_image uses Image.open(image_path), so BytesIO works.
-
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))
@@ -126,322 +85,39 @@ def generate(merged_image, total_mask, original_image, add_color_image, add_edge
    res_base64 = tensor_to_base64(final_image)
    return res_base64

- @spaces.GPU
- def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
-     merged_image = x['from_frontend']['img']
-     total_mask = x['from_frontend']['total_mask']
-     original_image = x['from_frontend']['original_image']
-     add_color_image = x['from_frontend']['add_color_image']
-     add_edge_mask = x['from_frontend']['add_edge_mask']
-     remove_edge_mask = x['from_frontend']['remove_edge_mask']
-     fill_mask = x['from_frontend']['fill_mask']
-     add_prop_image = x['from_frontend']['add_prop_image']
-     positive_prompt = x['from_backend']['prompt']
-
-     try:
-         res_base64 = generate(
-             merged_image,
-             total_mask,
-             original_image,
-             add_color_image,
-             add_edge_mask,
-             remove_edge_mask,
-             fill_mask,
-             add_prop_image,
-             positive_prompt,
-             negative_prompt,
-             fine_edge,
-             fix_perspective,
-             grow_size,
-             edge_strength,
-             color_strength,
-             local_strength,
-             seed,
-             steps,
-             cfg
-         )
-         x["from_backend"]["generated_image"] = res_base64
-     except Exception as e:
-         print(f"Error in generation: {e}")
-         x["from_backend"]["generated_image"] = None
-
-     return x
-
-
- with gr.Blocks(title="MagicQuill V2") as demo:
-     with gr.Row():
-         ms = MagicQuillV2()
-
-     with gr.Row():
-         with gr.Column():
-             btn = gr.Button("Run", variant="primary")
-         with gr.Column():
-             with gr.Accordion("parameters", open=False):
-                 negative_prompt = gr.Textbox(
-                     label="Negative Prompt",
-                     value="",
-                     interactive=True
-                 )
-                 fine_edge = gr.Radio(
-                     label="Fine Edge",
-                     choices=['enable', 'disable'],
-                     value='disable',
-                     interactive=True
-                 )
-                 fix_perspective = gr.Radio(
-                     label="Fix Perspective",
-                     choices=['enable', 'disable'],
-                     value='disable',
-                     interactive=True
-                 )
-                 grow_size = gr.Slider(
-                     label="Grow Size",
-                     minimum=10,
-                     maximum=100,
-                     value=50,
-                     step=1,
-                     interactive=True
-                 )
-                 edge_strength = gr.Slider(
-                     label="Edge Strength",
-                     minimum=0.0,
-                     maximum=5.0,
-                     value=0.6,
-                     step=0.01,
-                     interactive=True
-                 )
-                 color_strength = gr.Slider(
-                     label="Color Strength",
-                     minimum=0.0,
-                     maximum=5.0,
-                     value=1.5,
-                     step=0.01,
-                     interactive=True
-                 )
-                 local_strength = gr.Slider(
-                     label="Local Strength",
-                     minimum=0.0,
-                     maximum=5.0,
-                     value=1.0,
-                     step=0.01,
-                     interactive=True
-                 )
-                 seed = gr.Number(
-                     label="Seed",
-                     value=-1,
-                     precision=0,
-                     interactive=True
-                 )
-                 steps = gr.Slider(
-                     label="Steps",
-                     minimum=0,
-                     maximum=50,
-                     value=20,
-                     interactive=True
-                 )
-                 cfg = gr.Slider(
-                     label="CFG",
-                     minimum=0.0,
-                     maximum=20.0,
-                     value=3.5,
-                     step=0.1,
-                     interactive=True
-                 )
-
-     btn.click(generate_image_handler, inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg], outputs=ms)
-
- app = FastAPI()
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=['*'],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
+ # Create Gradio Interface
+ # All image inputs are passed as base64 strings (Textboxes)
+ inputs = [
+     gr.Textbox(label="merged_image"),
+     gr.Textbox(label="total_mask"),
+     gr.Textbox(label="original_image"),
+     gr.Textbox(label="add_color_image"),
+     gr.Textbox(label="add_edge_mask"),
+     gr.Textbox(label="remove_edge_mask"),
+     gr.Textbox(label="fill_mask"),
+     gr.Textbox(label="add_prop_image"),
+     gr.Textbox(label="positive_prompt"),
+     gr.Textbox(label="negative_prompt"),
+     gr.Textbox(label="fine_edge"),
+     gr.Textbox(label="fix_perspective"),
+     gr.Number(label="grow_size"),
+     gr.Number(label="edge_strength"),
+     gr.Number(label="color_strength"),
+     gr.Number(label="local_strength"),
+     gr.Number(label="seed"),
+     gr.Number(label="steps"),
+     gr.Number(label="cfg"),
+ ]
+
+ outputs = gr.Textbox(label="generated_image_base64")
+
+ demo = gr.Interface(
+     fn=generate,
+     inputs=inputs,
+     outputs=outputs,
+     api_name="generate"
)

- def get_root_url(
-     request: Request, route_path: str, root_path: str | None
- ):
-     print(root_path)
-     return root_path
- import gradio.route_utils
- gr.route_utils.get_root_url = get_root_url
-
- # @app.post("/magic_quill/generate_image")
- # async def generate_image(request: Request):
- #     data = await request.json()
- #     res = generate(
- #         data["merged_image"],
- #         data["total_mask"],
- #         data["original_image"],
- #         data["add_color_image"],
- #         data["add_edge_mask"],
- #         data["remove_edge_mask"],
- #         data["fill_mask"],
- #         data["add_prop_image"],
- #         data["positive_prompt"],
- #         data["negative_prompt"],
- #         data["fine_edge"],
- #         data["fix_perspective"],
- #         data["grow_size"],
- #         data["edge_strength"],
- #         data["color_strength"],
- #         data["local_strength"],
- #         data["seed"],
- #         data["steps"],
- #         data["cfg"]
- #     )
- #     return {'res': res}
-
- @app.post("/magic_quill/process_background_img")
- async def process_background_img(request: Request):
-     img = await request.json()
-     from util import process_background
-     # process_background returns tensor [1, H, W, 3] in uint8 or float
-     resized_img_tensor = process_background(img)
-
-     # tensor_to_base64 from util expects tensor
-     resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
-         resized_img_tensor,
-         quality=80,
-         method=6
-     )
-     return resized_img_base64
-
- @app.post("/magic_quill/segmentation")
- async def segmentation(request: Request):
-     json_data = await request.json()
-     image_base64 = json_data.get("image", None)
-     coordinates_positive = json_data.get("coordinates_positive", None)
-     coordinates_negative = json_data.get("coordinates_negative", None)
-     bboxes = json_data.get("bboxes", None)
-
-     if sam_client is None:
-         return {"error": "sam client not initialized"}
-
-     # Process coordinates and bboxes
-     pos_coordinates = None
-     if coordinates_positive and len(coordinates_positive) > 0:
-         pos_coordinates = []
-         for coord in coordinates_positive:
-             coord['x'] = int(round(coord['x']))
-             coord['y'] = int(round(coord['y']))
-             pos_coordinates.append({'x': coord['x'], 'y': coord['y']})
-         pos_coordinates = json.dumps(pos_coordinates)
-
-     neg_coordinates = None
-     if coordinates_negative and len(coordinates_negative) > 0:
-         neg_coordinates = []
-         for coord in coordinates_negative:
-             coord['x'] = int(round(coord['x']))
-             coord['y'] = int(round(coord['y']))
-             neg_coordinates.append({'x': coord['x'], 'y': coord['y']})
-         neg_coordinates = json.dumps(neg_coordinates)
-
-     bboxes_xyxy = None
-     if bboxes and len(bboxes) > 0:
-         valid_bboxes = []
-         for bbox in bboxes:
-             if (bbox.get("startX") is None or
-                     bbox.get("startY") is None or
-                     bbox.get("endX") is None or
-                     bbox.get("endY") is None):
-                 continue
-             else:
-                 x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
-                 y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
-                 # Note: image_tensor not available here easily without loading image,
-                 # but usually we don't need to clip strictly if SAM handles it or we clip to large values
-                 # For now, we skip strict clipping against image dims or assume 10000
-                 x_max = int(bbox["startX"]) if int(bbox["startX"]) > int(bbox["endX"]) else int(bbox["endX"])
-                 y_max = int(bbox["startY"]) if int(bbox["startY"]) > int(bbox["endY"]) else int(bbox["endY"])
-                 valid_bboxes.append((x_min, y_min, x_max, y_max))
-
-         bboxes_xyxy = []
-         for bbox in valid_bboxes:
-             x_min, y_min, x_max, y_max = bbox
-             bboxes_xyxy.append((x_min, y_min, x_max, y_max))
-
-     # Convert to JSON string if that's what the client expects, or keep as list
-     # Assuming JSON string for consistency with coords
-     if bboxes_xyxy:
-         bboxes_xyxy = json.dumps(bboxes_xyxy)
-
-     print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
-
-     try:
-         # Save base64 image to temp file
-         image_bytes = read_base64_image_utils(image_base64)
-         # Image.open to verify and save as WebP (smaller size)
-         pil_image = Image.open(image_bytes)
-         with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
-             pil_image.save(temp_in.name, format="WEBP", quality=80)
-             temp_in_path = temp_in.name
-
-         # Execute segmentation via Client
-         # We assume the remote space returns a filepath to the segmented image (with alpha)
-         # NOW it returns mask_np image
-         result_path = sam_client.predict(
-             handle_file(temp_in_path),
-             pos_coordinates,
-             neg_coordinates,
-             bboxes_xyxy,
-             api_name="/segment"
-         )
-
-         # Clean up input temp
-         os.unlink(temp_in_path)
-
-         # Process result
-         # result_path should be a generic object, usually a tuple (image_path, mask_path) or just image_path
-         # Depending on how the remote space is implemented.
-         if isinstance(result_path, (list, tuple)):
-             result_path = result_path[0]  # Take the first return value if multiple
-
-         if not result_path or not os.path.exists(result_path):
-             raise RuntimeError("Client returned invalid result path")
-
-         # result_path is the Mask Image (White=Selected, Black=Background)
-         mask_pil = Image.open(result_path)
-         if mask_pil.mode != 'L':
-             mask_pil = mask_pil.convert('L')
-
-         pil_image = pil_image.convert("RGB")
-         if pil_image.size != mask_pil.size:
-             mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
-
-         r, g, b = pil_image.split()
-         res_pil = Image.merge("RGBA", (r, g, b, mask_pil))
-
-         # Extract bbox from mask (alpha)
-         mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
-         mask_bbox = get_mask_bbox(mask_tensor)
-         if mask_bbox:
-             x_min, y_min, x_max, y_max = mask_bbox
-             seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
-         else:
-             seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
-
-         print(seg_bbox)
-
-         # Convert result to base64
-         # We need to convert the PIL image to base64 string
-         buffered = io.BytesIO()
-         res_pil.save(buffered, format="PNG")
-         image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-         return {
-             "error": False,
-             "segmentation_image": "data:image/png;base64," + image_base64_res,
-             "segmentation_bbox": seg_bbox
-         }
-
-     except Exception as e:
-         print(f"Error in segmentation: {e}")
-         return {"error": str(e)}
-
- app = gr.mount_gradio_app(app, demo, "/")
-
if __name__ == "__main__":
-     # uvicorn.run(app, host="0.0.0.0", port=7860)
    demo.launch()
+
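
For reference, the new `gr.Interface` exposes `generate` as a named API endpoint, and `@spaces.GPU` now decorates `generate` itself, so on ZeroGPU hardware each API call requests a GPU for the duration of the generation. Below is a minimal sketch of calling the endpoint with `gradio_client`; it is not part of the commit. The Space ID, file names, and prompt are hypothetical placeholders, while the numeric values mirror the defaults of the removed Blocks UI (grow_size 50, edge_strength 0.6, color_strength 1.5, local_strength 1.0, seed -1, steps 20, cfg 3.5).

import base64
from gradio_client import Client

def to_data_uri(path):
    # The Interface takes every image input as a base64 string in a Textbox.
    with open(path, "rb") as f:
        return "data:image/png;base64," + base64.b64encode(f.read()).decode()

client = Client("LiuZichen/MagicQuillV2")  # hypothetical Space ID, not named in this diff

result_base64 = client.predict(
    to_data_uri("merged.png"),       # merged_image
    to_data_uri("total_mask.png"),   # total_mask
    to_data_uri("original.png"),     # original_image
    to_data_uri("color.png"),        # add_color_image
    to_data_uri("add_edge.png"),     # add_edge_mask
    to_data_uri("remove_edge.png"),  # remove_edge_mask
    to_data_uri("fill.png"),         # fill_mask
    to_data_uri("prop.png"),         # add_prop_image
    "a red apple",                   # positive_prompt (placeholder)
    "",                              # negative_prompt
    "disable",                       # fine_edge
    "disable",                       # fix_perspective
    50,                              # grow_size
    0.6,                             # edge_strength
    1.5,                             # color_strength
    1.0,                             # local_strength
    -1,                              # seed
    20,                              # steps
    3.5,                             # cfg
    api_name="/generate",
)
print(result_base64[:60], "...")  # base64 string of the generated image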