# MagicQuillV2 / app.py — Hugging Face Space entry point
# (upstream: Space "LiuZichen/MagicQuillV2", commit 0d43e95, verified)
import os
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from edit_space import KontextEditModel
from util import (
load_and_preprocess_image,
read_base64_image as read_base64_image_utils,
create_alpha_mask,
tensor_to_base64,
)
import random
# Initialize models at import time so the Space is ready before the first request.
print("Downloading models...")
# HF_TOKEN is optional; token=None works for public repos — TODO confirm the
# model repo is public, otherwise the download fails without a token.
hf_token = os.environ.get("HF_TOKEN")
# Mirror the model repository into ./models (snapshot_download is incremental:
# already-cached files are not re-downloaded).
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)
print("Initializing models...")
# Single module-level model instance shared by all requests.
kontext_model = KontextEditModel()
print("Models initialized.")
@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    """Run one MagicQuill edit pass and return the result as a base64 string.

    Args:
        merged_image, total_mask, original_image: required base64-encoded
            images (strings).
        add_color_image, add_edge_mask, remove_edge_mask, fill_mask,
        add_prop_image: optional base64-encoded images; empty/None means
            "not provided" and falls back to the original image / a zero mask.
        positive_prompt, negative_prompt: text prompts. negative_prompt is
            only logged here — it is not forwarded to kontext_model.process.
        fine_edge, fix_perspective: flag-like values forwarded unchanged.
        grow_size, edge_strength, color_strength, local_strength, cfg:
            tuning knobs forwarded unchanged.
        seed, steps: arrive as floats from gr.Number and are cast to int;
            seed == -1 requests a random seed.

    Returns:
        The edited image as a pure base64 string (no data-URL prefix).

    Raises:
        RuntimeError: if the global model failed to initialize.
    """
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")

    # Decode the required inputs.
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))

    # Optional color layer: default to the untouched original image.
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor
    # Optional masks: default to an all-zero (empty) mask of the right shape.
    add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
    remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
    add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
    fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)

    # gr.Number delivers floats; RNG seeding and step counts need ints.
    seed = int(seed)
    steps = int(steps)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # Pick the editing mode from whichever masks are non-empty, most specific first.
    flag = "kontext"
    if torch.sum(add_prop_mask).item() > 0:
        # A pasted-in foreground object: steer the model to blend it in.
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0:
        # Only erase strokes are present: removal mode with a fixed prompt.
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
        flag = "precise_edit"
    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)

    # condition and mask are returned by the model but unused by this endpoint.
    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )
    # tensor_to_base64 returns a pure base64 string (no data-URL prefix).
    return tensor_to_base64(final_image)
# Build the Gradio interface. Image payloads travel as base64 strings, so every
# image input is a plain Textbox; the numeric knobs are Number widgets. The
# order below must match the parameter order of generate().
_TEXT_FIELDS = (
    "merged_image",
    "total_mask",
    "original_image",
    "add_color_image",
    "add_edge_mask",
    "remove_edge_mask",
    "fill_mask",
    "add_prop_image",
    "positive_prompt",
    "negative_prompt",
    "fine_edge",
    "fix_perspective",
)
_NUMBER_FIELDS = (
    "grow_size",
    "edge_strength",
    "color_strength",
    "local_strength",
    "seed",
    "steps",
    "cfg",
)

inputs = [gr.Textbox(label=field) for field in _TEXT_FIELDS]
inputs += [gr.Number(label=field) for field in _NUMBER_FIELDS]
outputs = gr.Textbox(label="generated_image_base64")

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    api_name="generate",
    concurrency_limit=20,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(max_threads=20)