"""Gradio Space entry point for MagicQuillV2 image editing.

All image inputs and the output are exchanged as base64-encoded strings
(plain ``gr.Textbox`` components) so the Space can be driven
programmatically through the ``generate`` API endpoint.
"""

import os
import random

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from edit_space import KontextEditModel
from util import (
    load_and_preprocess_image,
    read_base64_image as read_base64_image_utils,
    create_alpha_mask,
    tensor_to_base64,
)

# Download model weights once at startup. HF_TOKEN is optional; it is only
# needed if the weights repo requires authenticated access.
print("Downloading models...")
hf_token = os.environ.get("HF_TOKEN")
snapshot_download(
    repo_id="LiuZichen/MagicQuillV2-models",
    repo_type="model",
    local_dir="models",
    token=hf_token,
)

print("Initializing models...")
kontext_model = KontextEditModel()
print("Models initialized.")


@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image,
             add_edge_mask, remove_edge_mask, fill_mask, add_prop_image,
             positive_prompt, negative_prompt, fine_edge, fix_perspective,
             grow_size, edge_strength, color_strength, local_strength,
             seed, steps, cfg):
    """Run one editing pass of the Kontext model and return the result.

    Parameters
    ----------
    merged_image, total_mask, original_image : str
        Base64-encoded images; ``merged_image`` is the user's composite,
        ``total_mask`` the union of all edited regions.
    add_color_image, add_edge_mask, remove_edge_mask, fill_mask,
    add_prop_image : str
        Optional base64-encoded images/masks; empty string means "absent".
    positive_prompt, negative_prompt : str
        Text conditioning. ``negative_prompt`` is currently only logged
        and forwarded decisions are made from ``positive_prompt``.
    fine_edge, fix_perspective : str
        Flag-like options passed through to the model (arrive as text
        from the Gradio client — presumably "true"/"false"; the model
        interprets them).
    grow_size, edge_strength, color_strength, local_strength : float
        Numeric conditioning strengths.
    seed : float
        Random seed; ``-1`` requests a fresh random seed.
    steps, cfg : float
        Sampler step count and classifier-free-guidance scale.

    Returns
    -------
    str
        The generated image as a pure base64 string.

    Raises
    ------
    RuntimeError
        If the global model failed to initialize.
    """
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective,
          grow_size, edge_strength, color_strength, local_strength,
          seed, steps, cfg)
    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")

    # --- Preprocess inputs -------------------------------------------------
    merged_image_tensor = load_and_preprocess_image(
        read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(
        read_base64_image_utils(original_image))

    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(
            read_base64_image_utils(add_color_image))
    else:
        # No color guidance supplied: fall back to the original image so the
        # "did the user paint colors?" comparison below is a clean no-op.
        add_color_image_tensor = original_image_tensor

    def _mask_or_zeros(b64):
        """Decode an optional base64 mask; absent inputs become a zero mask."""
        if b64:
            return create_alpha_mask(read_base64_image_utils(b64))
        return torch.zeros_like(total_mask_tensor)

    add_mask = _mask_or_zeros(add_edge_mask)
    remove_mask = _mask_or_zeros(remove_edge_mask)
    add_prop_mask = _mask_or_zeros(add_prop_image)
    fill_mask_tensor = _mask_or_zeros(fill_mask)

    # Gradio Number inputs arrive as floats; the sampler needs real ints.
    seed = int(seed)
    steps = int(steps)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # --- Determine editing mode and adjust the prompt ----------------------
    # Priority order: foreground insertion > local fill > removal >
    # precise edit > plain kontext edit.
    flag = "kontext"
    if torch.sum(add_prop_mask).item() > 0:
        flag = "foreground"
        positive_prompt = (
            "Fill in the white region naturally and adapt the foreground "
            "into the background. Fix the perspective of the foreground "
            "object if necessary. " + positive_prompt
        )
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0:
        # Only erase strokes present: object removal.
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0
          or torch.sum(remove_mask).item() > 0
          or not torch.equal(original_image_tensor, add_color_image_tensor)):
        flag = "precise_edit"

    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)

    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )
    # tensor_to_base64 returns a pure base64 string (no data-URI prefix).
    return tensor_to_base64(final_image)


# --- Gradio interface ------------------------------------------------------
# All image inputs are passed as base64 strings (Textboxes) so that remote
# clients can call the endpoint without multipart uploads.
inputs = [
    gr.Textbox(label="merged_image"),
    gr.Textbox(label="total_mask"),
    gr.Textbox(label="original_image"),
    gr.Textbox(label="add_color_image"),
    gr.Textbox(label="add_edge_mask"),
    gr.Textbox(label="remove_edge_mask"),
    gr.Textbox(label="fill_mask"),
    gr.Textbox(label="add_prop_image"),
    gr.Textbox(label="positive_prompt"),
    gr.Textbox(label="negative_prompt"),
    gr.Textbox(label="fine_edge"),
    gr.Textbox(label="fix_perspective"),
    gr.Number(label="grow_size"),
    gr.Number(label="edge_strength"),
    gr.Number(label="color_strength"),
    gr.Number(label="local_strength"),
    gr.Number(label="seed"),
    gr.Number(label="steps"),
    gr.Number(label="cfg"),
]

outputs = gr.Textbox(label="generated_image_base64")

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    api_name="generate",
    concurrency_limit=20,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(max_threads=20)