Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,736 Bytes
f460ce6 c57bc42 f460ce6 08a4792 f460ce6 05bce8d c57bc42 f460ce6 3847fca f460ce6 08a4792 f460ce6 54b1753 f460ce6 460ae4e f460ce6 08a4792 6fc568e f460ce6 0d43e95 08a4792 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from edit_space import KontextEditModel
from util import (
load_and_preprocess_image,
read_base64_image as read_base64_image_utils,
create_alpha_mask,
tensor_to_base64,
)
import random
# Initialize models once at import time so a single downloaded snapshot and a
# single model instance are shared by every request served by this process.
print("Downloading models...")
# Pull the pretrained MagicQuill/Kontext weights from the Hub into ./models.
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models")
print("Initializing models...")
# Global model handle consumed by generate() below.
kontext_model = KontextEditModel()
print("Models initialized.")
@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    """Run one Kontext edit pass and return the result as a base64 string.

    All image/mask arguments arrive as base64-encoded strings; empty strings
    (or None) mean "not provided".  The editing mode (``flag``) is inferred
    from which masks are non-empty, and the positive prompt may be rewritten
    to match that mode.

    Args:
        merged_image, total_mask, original_image: required base64 payloads.
        add_color_image, add_edge_mask, remove_edge_mask, fill_mask,
        add_prop_image: optional base64 payloads; missing ones default to the
            original image / all-zero masks.
        positive_prompt, negative_prompt: text prompts (negative_prompt is
            currently only logged, then ignored).
        fine_edge, fix_perspective: mode strings forwarded to the model.
        grow_size, edge_strength, color_strength, local_strength, seed,
        steps, cfg: numeric sampling parameters (arrive as floats from
            gr.Number).

    Returns:
        str: the generated image as a pure base64 string.

    Raises:
        RuntimeError: if the global model was not initialized.
    """
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
    # gr.Number delivers floats; the sampler needs an integer step count.
    steps = int(steps)
    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")
    # Decode the required base64 payloads into tensors.
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))
    # Fall back to the untouched original when no color-edited image is given.
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor
    # Optional masks default to all-zeros ("not used"), shaped like total_mask.
    add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
    remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
    add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
    fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)
    # seed == -1 means "randomize"; always pass an int seed downstream
    # (gr.Number would otherwise hand us a float).
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)
    # Determine the editing mode from the non-empty masks and adjust the
    # prompt accordingly.  Checks use .item() consistently to compare plain
    # Python numbers rather than 0-dim tensors.
    flag = "kontext"
    if torch.sum(add_prop_mask).item() > 0:
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0):
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
        flag = "precise_edit"
    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)
    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )
    # tensor_to_base64 returns pure base64 string
    res_base64 = tensor_to_base64(final_image)
    return res_base64
# Create Gradio Interface.
# All image/mask and mode inputs travel as base64 / plain strings (Textboxes);
# the numeric sampling parameters use Number fields.  The order here must
# match the positional signature of generate().
_TEXT_LABELS = (
    "merged_image",
    "total_mask",
    "original_image",
    "add_color_image",
    "add_edge_mask",
    "remove_edge_mask",
    "fill_mask",
    "add_prop_image",
    "positive_prompt",
    "negative_prompt",
    "fine_edge",
    "fix_perspective",
)
_NUMBER_LABELS = (
    "grow_size",
    "edge_strength",
    "color_strength",
    "local_strength",
    "seed",
    "steps",
    "cfg",
)
inputs = [gr.Textbox(label=label) for label in _TEXT_LABELS] + [
    gr.Number(label=label) for label in _NUMBER_LABELS
]
# Single output: the edited image as a pure base64 string.
outputs = gr.Textbox(label="generated_image_base64")
# Wire the endpoint; api_name exposes it as the /generate API route.
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    api_name="generate",
    concurrency_limit=20
)
if __name__ == "__main__":
    # Queue up to 20 pending requests and serve with up to 20 threads.
    demo.queue(max_size=20).launch(max_threads=20)
|