Spaces: Running on Zero
| import os | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from edit_space import KontextEditModel | |
| from util import ( | |
| load_and_preprocess_image, | |
| read_base64_image as read_base64_image_utils, | |
| create_alpha_mask, | |
| tensor_to_base64, | |
| ) | |
| import random | |
# Initialize models.
# Module-level side effects: download the model weights once at import time
# and construct the editing model so the Gradio handler can reuse it across
# requests.
print("Downloading models...")
# HF_TOKEN is optional; snapshot_download accepts token=None for public repos.
hf_token = os.environ.get("HF_TOKEN")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)
print("Initializing models...")
# NOTE(review): KontextEditModel presumably loads the weights just downloaded
# into "models/" — confirm against edit_space.KontextEditModel.
kontext_model = KontextEditModel()
print("Models initialized.")
def _optional_alpha_mask(image_b64, like):
    # Decode an optional base64-encoded mask into an alpha-mask tensor; when
    # the input is empty/None, fall back to an all-zero mask shaped like
    # `like` so downstream torch.sum() checks treat it as "not used".
    if image_b64:
        return create_alpha_mask(read_base64_image_utils(image_b64))
    return torch.zeros_like(like)


def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    """Run one editing pass through the Kontext model.

    All image arguments arrive as base64-encoded strings. Optional inputs
    (`add_color_image`, `add_edge_mask`, `remove_edge_mask`, `fill_mask`,
    `add_prop_image`) may be empty, in which case the original image or an
    all-zero mask is substituted.

    Returns:
        The edited image as a pure base64 string (no data-URI prefix).

    Raises:
        RuntimeError: if the module-level model was not initialized.
    """
    print("prompt is:", positive_prompt)
    # NOTE(review): negative_prompt is only logged, never forwarded to the
    # model; it is kept in the signature for client/API compatibility.
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")

    # Required inputs.
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))

    # Optional color-hint image defaults to the untouched original, which
    # makes the torch.equal() check below read as "no color edit was drawn".
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor

    # Optional masks default to all-zeros (see helper above).
    add_mask = _optional_alpha_mask(add_edge_mask, total_mask_tensor)
    remove_mask = _optional_alpha_mask(remove_edge_mask, total_mask_tensor)
    add_prop_mask = _optional_alpha_mask(add_prop_image, total_mask_tensor)
    fill_mask_tensor = _optional_alpha_mask(fill_mask, total_mask_tensor)

    # -1 is the client's sentinel for "pick a random seed".
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # Choose the editing mode from whichever masks are non-empty. Branch
    # order encodes priority: foreground > local fill > removal > precise
    # edit; default is plain "kontext" prompt-driven editing.
    flag = "kontext"
    if torch.sum(add_prop_mask).item() > 0:
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0:
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0
          or torch.sum(remove_mask).item() > 0
          or not torch.equal(original_image_tensor, add_color_image_tensor)):
        flag = "precise_edit"
    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)

    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )
    # tensor_to_base64 returns a pure base64 string.
    res_base64 = tensor_to_base64(final_image)
    return res_base64
# Create Gradio Interface.
# Every image travels as a base64 string inside a Textbox; numeric tuning
# knobs use Number fields. The concatenation order below must match the
# parameter order of generate().
_TEXT_INPUT_LABELS = [
    "merged_image",
    "total_mask",
    "original_image",
    "add_color_image",
    "add_edge_mask",
    "remove_edge_mask",
    "fill_mask",
    "add_prop_image",
    "positive_prompt",
    "negative_prompt",
    "fine_edge",
    "fix_perspective",
]
_NUMBER_INPUT_LABELS = [
    "grow_size",
    "edge_strength",
    "color_strength",
    "local_strength",
    "seed",
    "steps",
    "cfg",
]
inputs = [gr.Textbox(label=label) for label in _TEXT_INPUT_LABELS]
inputs += [gr.Number(label=label) for label in _NUMBER_INPUT_LABELS]
# Single output: the edited image, base64-encoded.
outputs = gr.Textbox(label="generated_image_base64")

# Expose generate() under the "generate" API name and allow up to 20
# concurrent invocations of the handler.
_interface_config = {
    "fn": generate,
    "inputs": inputs,
    "outputs": outputs,
    "api_name": "generate",
    "concurrency_limit": 20,
}
demo = gr.Interface(**_interface_config)
if __name__ == "__main__":
    # Queue at most 20 pending requests and serve with up to 20 worker
    # threads when launched as a script.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(max_threads=20)