File size: 4,736 Bytes
f460ce6
c57bc42
f460ce6
 
08a4792
f460ce6
 
 
 
 
 
 
05bce8d
c57bc42
f460ce6
 
3847fca
f460ce6
 
 
 
 
08a4792
f460ce6
 
 
54b1753
 
f460ce6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460ae4e
 
 
f460ce6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08a4792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fc568e
 
f460ce6
 
 
0d43e95
08a4792
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from edit_space import KontextEditModel
from util import (
    load_and_preprocess_image,
    read_base64_image as read_base64_image_utils,
    create_alpha_mask,
    tensor_to_base64,
)
import random

# Initialize models at import time (module-level side effect) so the weights
# are already on disk and the model object exists before the first request.
print("Downloading models...")
# Pulls the full model snapshot into ./models; a no-op on warm restarts when
# the files are already cached locally.
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models")

print("Initializing models...")
# Single shared model instance reused by every generate() call.
# NOTE(review): presumably KontextEditModel reads weights from ./models by
# default — confirm against edit_space.KontextEditModel's constructor.
kontext_model = KontextEditModel()
print("Models initialized.")

@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    """Run one editing pass through the Kontext model and return the result.

    All image arguments arrive as base64-encoded strings (the Gradio inputs
    are Textboxes); mask arguments are base64 images whose alpha channel is
    the mask. Empty strings / falsy values mean "not provided" and fall back
    to zero masks or the original image.

    Args:
        merged_image: base64 composite of the user's edits over the original.
        total_mask: base64 image; alpha = union of all edited regions.
        original_image: base64 of the unedited source image.
        add_color_image: base64 color-guidance image, or falsy to reuse
            the original image.
        add_edge_mask / remove_edge_mask / fill_mask / add_prop_image:
            optional base64 masks selecting each edit mode's region.
        positive_prompt / negative_prompt: text prompts (negative_prompt is
            currently only logged, not forwarded to the model).
        fine_edge / fix_perspective: mode toggles passed through to the model.
        grow_size, edge_strength, color_strength, local_strength, cfg:
            numeric tuning knobs passed through unchanged.
        seed: RNG seed; -1 requests a random seed.
        steps: number of diffusion steps.

    Returns:
        The generated image as a pure base64 string.

    Raises:
        RuntimeError: if the module-level model failed to initialize.
    """
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)

    # gr.Number delivers floats; steps AND seed must be ints before reaching
    # the sampler (a float seed such as 42.0 breaks torch RNG seeding).
    steps = int(steps)
    seed = int(seed)

    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")

    # Preprocess inputs: decode base64 -> PIL -> tensors / alpha masks.
    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))

    # Missing color guidance falls back to the original image so the later
    # torch.equal() check below can detect "no color edit was made".
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor

    # Optional masks default to all-zeros with the same shape as total_mask.
    add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
    remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
    add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
    fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # Determine the edit mode from which masks are non-empty; precedence:
    # foreground insertion > local fill > removal > precise edit > kontext.
    flag = "kontext"
    if torch.sum(add_prop_mask) > 0:
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0):
        # Removal-only edit: the model expects this fixed instruction.
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
        flag = "precise_edit"

    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)

    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )

    # tensor_to_base64 returns pure base64 string
    res_base64 = tensor_to_base64(final_image)
    return res_base64

# Gradio interface wiring. Every image travels as a base64 string, so all
# image slots are plain Textboxes rather than gr.Image components; the
# numeric tuning knobs are gr.Number fields. Order must match generate()'s
# positional parameters exactly.
_TEXT_FIELDS = (
    "merged_image",
    "total_mask",
    "original_image",
    "add_color_image",
    "add_edge_mask",
    "remove_edge_mask",
    "fill_mask",
    "add_prop_image",
    "positive_prompt",
    "negative_prompt",
    "fine_edge",
    "fix_perspective",
)
_NUMBER_FIELDS = (
    "grow_size",
    "edge_strength",
    "color_strength",
    "local_strength",
    "seed",
    "steps",
    "cfg",
)

inputs = [gr.Textbox(label=name) for name in _TEXT_FIELDS]
inputs += [gr.Number(label=name) for name in _NUMBER_FIELDS]

# The result comes back as a pure base64 string, mirroring the inputs.
outputs = gr.Textbox(label="generated_image_base64")

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    api_name="generate",
    concurrency_limit=20,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(max_threads=20)