MagicQuillV2 Space (Running on Zero)

Commit: update

Files changed:
- README.md +2 -2
- app.py +439 -4
- edit_space.py +461 -0
- requirements.txt +28 -0
- src/__init__.py +0 -0
- src/layers_cache.py +406 -0
- src/lora_helper.py +194 -0
- src/pipeline_flux_kontext_control.py +1230 -0
- src/transformer_flux.py +608 -0
- train/default_config.yaml +16 -0
- train/src/__init__.py +0 -0
- train/src/condition/edge_extraction.py +356 -0
- train/src/condition/hed.py +56 -0
- train/src/condition/informative_drawing.py +279 -0
- train/src/condition/lineart.py +86 -0
- train/src/condition/pidi.py +681 -0
- train/src/condition/ted.py +296 -0
- train/src/condition/util.py +202 -0
- train/src/generate_diff_mask.py +301 -0
- train/src/jsonl_datasets_kontext_color.py +166 -0
- train/src/jsonl_datasets_kontext_complete_lora.py +363 -0
- train/src/jsonl_datasets_kontext_edge.py +225 -0
- train/src/jsonl_datasets_kontext_interactive_lora.py +1332 -0
- train/src/jsonl_datasets_kontext_local.py +312 -0
- train/src/layers.py +279 -0
- train/src/lora_helper.py +196 -0
- train/src/masks_integrated.py +322 -0
- train/src/pipeline_flux_kontext_control.py +1009 -0
- train/src/prompt_helper.py +205 -0
- train/src/transformer_flux.py +625 -0
- train/train_kontext_color.py +858 -0
- train/train_kontext_color.sh +25 -0
- train/train_kontext_complete_lora.sh +20 -0
- train/train_kontext_edge.py +814 -0
- train/train_kontext_edge.sh +25 -0
- train/train_kontext_interactive_lora.sh +18 -0
- train/train_kontext_local.py +876 -0
- train/train_kontext_local.sh +26 -0
- train/train_kontext_lora.py +871 -0
- util.py +188 -0
- utils_node.py +199 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: MagicQuillV2
-emoji:
+emoji: 🪶
 colorFrom: blue
 colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 5.4.0
 app_file: app.py
 pinned: false
 ---
app.py
CHANGED
@@ -1,7 +1,442 @@
import sys
import os
import gradio as gr
import spaces
import tempfile
import numpy as np
import io
import base64
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
from gradio_magicquillv2 import MagicQuillV2
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import requests
from PIL import Image, ImageOps
import random
import time
import torch
import json

# Try importing as a package (recommended)
from edit_space import KontextEditModel
from util import (
    load_and_preprocess_image,
    read_base64_image as read_base64_image_utils,
    create_alpha_mask,
    tensor_to_base64,
    get_mask_bbox
)

# Initialize models
print("Downloading models...")
hf_token = os.environ.get("hf_token")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)

print("Initializing models...")
kontext_model = KontextEditModel()

# Initialize SAM Client
# Replace with your actual SAM Space ID
sam_client = Client("LiuZichen/MagicQuillHelper")
print("Models initialized.")

css = """
.ms {
    width: 60%;
    margin: auto
}
"""

url = "http://localhost:7860"

@spaces.GPU
def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    print("prompt is:", positive_prompt)
    print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)

    if kontext_model is None:
        raise RuntimeError("KontextEditModel not initialized")

    # Preprocess inputs
    # util.read_base64_image returns BytesIO, which create_alpha_mask accepts (via Image.open)
    # util.load_and_preprocess_image uses Image.open(image_path), so a BytesIO object works as well.

    merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
    total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
    original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))

    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
    else:
        add_color_image_tensor = original_image_tensor

    add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
    remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
    add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
    fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)

    # Determine flag and modify prompt
    flag = "kontext"
    if torch.sum(add_prop_mask) > 0:
        flag = "foreground"
        positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
    elif torch.sum(fill_mask_tensor).item() > 0:
        flag = "local"
    elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0):
        positive_prompt = "remove the instance"
        flag = "removal"
    elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
        flag = "precise_edit"

    print("positive prompt: ", positive_prompt)
    print("current flag: ", flag)

    final_image, condition, mask = kontext_model.process(
        original_image_tensor,
        add_color_image_tensor,
        merged_image_tensor,
        positive_prompt,
        total_mask_tensor,
        add_mask,
        remove_mask,
        add_prop_mask,
        fill_mask_tensor,
        fine_edge,
        fix_perspective,
        edge_strength,
        color_strength,
        local_strength,
        grow_size,
        seed,
        steps,
        cfg,
        flag,
    )

    # tensor_to_base64 returns a pure base64 string
    res_base64 = tensor_to_base64(final_image)
    return res_base64

def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    merged_image = x['from_frontend']['img']
    total_mask = x['from_frontend']['total_mask']
    original_image = x['from_frontend']['original_image']
    add_color_image = x['from_frontend']['add_color_image']
    add_edge_mask = x['from_frontend']['add_edge_mask']
    remove_edge_mask = x['from_frontend']['remove_edge_mask']
    fill_mask = x['from_frontend']['fill_mask']
    add_prop_image = x['from_frontend']['add_prop_image']
    positive_prompt = x['from_backend']['prompt']

    try:
        res_base64 = generate(
            merged_image,
            total_mask,
            original_image,
            add_color_image,
            add_edge_mask,
            remove_edge_mask,
            fill_mask,
            add_prop_image,
            positive_prompt,
            negative_prompt,
            fine_edge,
            fix_perspective,
            grow_size,
            edge_strength,
            color_strength,
            local_strength,
            seed,
            steps,
            cfg
        )
        x["from_backend"]["generated_image"] = res_base64
    except Exception as e:
        print(f"Error in generation: {e}")
        x["from_backend"]["generated_image"] = None

    return x


with gr.Blocks(title="MagicQuill V2") as demo:
    with gr.Row():
        ms = MagicQuillV2()

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("parameters", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    interactive=True
                )
                fine_edge = gr.Radio(
                    label="Fine Edge",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                fix_perspective = gr.Radio(
                    label="Fix Perspective",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                grow_size = gr.Slider(
                    label="Grow Size",
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    interactive=True
                )
                edge_strength = gr.Slider(
                    label="Edge Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                color_strength = gr.Slider(
                    label="Color Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.5,
                    step=0.01,
                    interactive=True
                )
                local_strength = gr.Slider(
                    label="Local Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.0,
                    step=0.01,
                    interactive=True
                )
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    interactive=True
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=0,
                    maximum=50,
                    value=20,
                    interactive=True
                )
                cfg = gr.Slider(
                    label="CFG",
                    minimum=0.0,
                    maximum=20.0,
                    value=3.5,
                    step=0.1,
                    interactive=True
                )

    btn.click(generate_image_handler, inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg], outputs=ms)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_root_url(
    request: Request, route_path: str, root_path: str | None
):
    print(root_path)
    return root_path
import gradio.route_utils
gr.route_utils.get_root_url = get_root_url

gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")

@app.post("/magic_quill/generate_image")
async def generate_image(request: Request):
    data = await request.json()
    res = generate(
        data["merged_image"],
        data["total_mask"],
        data["original_image"],
        data["add_color_image"],
        data["add_edge_mask"],
        data["remove_edge_mask"],
        data["fill_mask"],
        data["add_prop_image"],
        data["positive_prompt"],
        data["negative_prompt"],
        data["fine_edge"],
        data["fix_perspective"],
        data["grow_size"],
        data["edge_strength"],
        data["color_strength"],
        data["local_strength"],
        data["seed"],
        data["steps"],
        data["cfg"]
    )
    return {'res': res}

@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    img = await request.json()
    from util import process_background
    # process_background returns tensor [1, H, W, 3] in uint8 or float
    resized_img_tensor = process_background(img)

    # tensor_to_base64 from util expects a tensor
    resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
        resized_img_tensor,
        quality=80,
        method=6
    )
    return resized_img_base64

@app.post("/magic_quill/segmentation")
async def segmentation(request: Request):
    json_data = await request.json()
    image_base64 = json_data.get("image", None)
    coordinates_positive = json_data.get("coordinates_positive", None)
    coordinates_negative = json_data.get("coordinates_negative", None)
    bboxes = json_data.get("bboxes", None)

    if sam_client is None:
        return {"error": "sam client not initialized"}

    # Process coordinates and bboxes
    pos_coordinates = None
    if coordinates_positive and len(coordinates_positive) > 0:
        pos_coordinates = []
        for coord in coordinates_positive:
            coord['x'] = int(round(coord['x']))
            coord['y'] = int(round(coord['y']))
            pos_coordinates.append({'x': coord['x'], 'y': coord['y']})
        pos_coordinates = json.dumps(pos_coordinates)

    neg_coordinates = None
    if coordinates_negative and len(coordinates_negative) > 0:
        neg_coordinates = []
        for coord in coordinates_negative:
            coord['x'] = int(round(coord['x']))
            coord['y'] = int(round(coord['y']))
            neg_coordinates.append({'x': coord['x'], 'y': coord['y']})
        neg_coordinates = json.dumps(neg_coordinates)

    bboxes_xyxy = None
    if bboxes and len(bboxes) > 0:
        valid_bboxes = []
        for bbox in bboxes:
            if (bbox.get("startX") is None or
                bbox.get("startY") is None or
                bbox.get("endX") is None or
                bbox.get("endY") is None):
                continue
            else:
                x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
                y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
                # Note: image_tensor not available here easily without loading image,
                # but usually we don't need to clip strictly if SAM handles it or we clip to large values
                # For now, we skip strict clipping against image dims or assume 10000
                x_max = int(bbox["startX"]) if int(bbox["startX"]) > int(bbox["endX"]) else int(bbox["endX"])
                y_max = int(bbox["startY"]) if int(bbox["startY"]) > int(bbox["endY"]) else int(bbox["endY"])
                valid_bboxes.append((x_min, y_min, x_max, y_max))

        bboxes_xyxy = []
        for bbox in valid_bboxes:
            x_min, y_min, x_max, y_max = bbox
            bboxes_xyxy.append((x_min, y_min, x_max, y_max))

        # Convert to JSON string if that's what the client expects, or keep as list
        # Assuming JSON string for consistency with coords
        if bboxes_xyxy:
            bboxes_xyxy = json.dumps(bboxes_xyxy)

    print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")

    try:
        # Save base64 image to temp file
        image_bytes = read_base64_image_utils(image_base64)
        # Image.open to verify and save as WebP (smaller size)
        pil_image = Image.open(image_bytes)
        with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
            pil_image.save(temp_in.name, format="WEBP", quality=80)
            temp_in_path = temp_in.name

        # Execute segmentation via Client
        # We assume the remote space returns a filepath to the segmented image (with alpha)
        # NOW it returns mask_np image
        result_path = sam_client.predict(
            handle_file(temp_in_path),
            pos_coordinates,
            neg_coordinates,
            bboxes_xyxy,
            api_name="/segment"
        )

        # Clean up input temp
        os.unlink(temp_in_path)

        # Process result
        # result_path is usually a tuple (image_path, mask_path) or just image_path,
        # depending on how the remote space is implemented.
        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]  # Take the first return value if multiple

        if not result_path or not os.path.exists(result_path):
            raise RuntimeError("Client returned invalid result path")

        # result_path is the mask image (white = selected, black = background)
        mask_pil = Image.open(result_path)
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')

        pil_image = pil_image.convert("RGB")
        if pil_image.size != mask_pil.size:
            mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)

        r, g, b = pil_image.split()
        res_pil = Image.merge("RGBA", (r, g, b, mask_pil))

        # Extract bbox from mask (alpha)
        mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
        mask_bbox = get_mask_bbox(mask_tensor)
        if mask_bbox:
            x_min, y_min, x_max, y_max = mask_bbox
            seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
        else:
            seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}

        print(seg_bbox)

        # Convert the PIL result to a base64 string
        buffered = io.BytesIO()
        res_pil.save(buffered, format="PNG")
        image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "error": False,
            "segmentation_image": "data:image/png;base64," + image_base64_res,
            "segmentation_bbox": seg_bbox
        }

    except Exception as e:
        print(f"Error in segmentation: {e}")
        return {"error": str(e)}

app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
    # demo.launch()
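Because app.py exposes the generator both through the Gradio UI and through plain FastAPI routes, a lightweight client can call the JSON endpoint directly. The sketch below is illustrative only: the field names mirror the keys read in generate_image above, but the host, the local file names, the parameter values, and the exact base64 framing (with or without a data: prefix, which depends on util.read_base64_image) are placeholder assumptions.

import base64
import requests

def to_base64(path):
    # Encode a local image file as a plain base64 string (framing assumption, see note above).
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Hypothetical local files standing in for the frontend's canvases.
payload = {
    "merged_image": to_base64("merged.png"),
    "total_mask": to_base64("total_mask.png"),
    "original_image": to_base64("original.png"),
    "add_color_image": None,
    "add_edge_mask": None,
    "remove_edge_mask": None,
    "fill_mask": to_base64("fill_mask.png"),   # non-empty fill mask selects the "local" editing flag
    "add_prop_image": None,
    "positive_prompt": "a red wooden door",
    "negative_prompt": "",
    "fine_edge": "disable",
    "fix_perspective": "disable",
    "grow_size": 50,
    "edge_strength": 0.6,
    "color_strength": 1.5,
    "local_strength": 1.0,
    "seed": 42,
    "steps": 20,
    "cfg": 3.5,
}

resp = requests.post("http://localhost:7860/magic_quill/generate_image", json=payload)
resp.raise_for_status()
result_base64 = resp.json()["res"]  # pure base64 string, as returned by tensor_to_base64
with open("result_image", "wb") as f:
    # The underlying image encoding depends on util.tensor_to_base64.
    f.write(base64.b64decode(result_base64))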
edit_space.py
ADDED
@@ -0,0 +1,461 @@
import os
import torch.nn.functional as F
import torch
import sys
import cv2
import numpy as np
from PIL import Image
import json


# New imports for the diffuser pipeline
from src.pipeline_flux_kontext_control import FluxKontextControlPipeline
from src.transformer_flux import FluxTransformer2DModel

import tempfile
from safetensors.torch import load_file, save_file

_original_load_lora_weights = FluxKontextControlPipeline.load_lora_weights

def _patched_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
    """Automatically convert mixed-format LoRA weights and add the 'transformer.' prefix."""
    weight_name = kwargs.get("weight_name", "pytorch_lora_weights.safetensors")

    if isinstance(pretrained_model_name_or_path_or_dict, str):
        if os.path.isdir(pretrained_model_name_or_path_or_dict):
            lora_file = os.path.join(pretrained_model_name_or_path_or_dict, weight_name)
        else:
            lora_file = pretrained_model_name_or_path_or_dict

        if os.path.exists(lora_file):
            state_dict = load_file(lora_file)

            # Check whether the format needs conversion or the prefix needs to be added
            needs_format_conversion = any('lora_A.weight' in k or 'lora_B.weight' in k for k in state_dict.keys())
            needs_prefix = not any(k.startswith('transformer.') for k in state_dict.keys())

            if needs_format_conversion or needs_prefix:
                print(f"🔄 Processing LoRA: {lora_file}")
                if needs_format_conversion:
                    print(f" - Converting PEFT format to diffusers format")
                if needs_prefix:
                    print(f" - Adding 'transformer.' prefix to keys")

                converted_state = {}
                converted_count = 0

                for key, value in state_dict.items():
                    new_key = key

                    # Step 1: convert PEFT format to diffusers format
                    if 'lora_A.weight' in new_key:
                        new_key = new_key.replace('lora_A.weight', 'lora.down.weight')
                        converted_count += 1
                    elif 'lora_B.weight' in new_key:
                        new_key = new_key.replace('lora_B.weight', 'lora.up.weight')
                        converted_count += 1

                    # Step 2: add the 'transformer.' prefix (if not already present)
                    if not new_key.startswith('transformer.'):
                        new_key = f'transformer.{new_key}'

                    converted_state[new_key] = value

                if needs_format_conversion:
                    print(f" ✅ Converted {converted_count} PEFT keys")
                print(f" ✅ Total keys: {len(converted_state)}")

                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_file = os.path.join(temp_dir, weight_name)
                    save_file(converted_state, temp_file)
                    return _original_load_lora_weights(self, temp_dir, **kwargs)
            else:
                print(f"✅ LoRA already in correct format: {lora_file}")

    # No conversion needed, use the original method
    return _original_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs)

# Apply the monkey patch
FluxKontextControlPipeline.load_lora_weights = _patched_load_lora_weights
print("✅ Monkey patch applied to FluxKontextPipeline.load_lora_weights")

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
sys.path.append(os.path.abspath(os.path.join(current_dir, '..', '..', 'comfy_extras')))

from train.src.condition.edge_extraction import InformativeDetector, HEDDetector
from utils_node import BlendInpaint, JoinImageWithAlpha, GrowMask, InvertMask, ColorDetector

TEST_MODE = False

class KontextEditModel():
    def __init__(self, base_model_path="/data0/lzc/FLUX.1-Kontext-dev", device="cuda",
                 aux_lora_dir="models/v2_ckpt", easycontrol_base_dir="models/v2_ckpt",
                 aux_lora_weight_name="puzzle_lora.safetensors",
                 aux_lora_weight=1.0):
        # Keep necessary preprocessors
        self.mask_processor = GrowMask()
        self.scribble_processor = HEDDetector.from_pretrained()
        self.lineart_processor = InformativeDetector.from_pretrained()
        self.color_processor = ColorDetector()
        self.blender = BlendInpaint()

        # Initialize the new pipeline (Kontext version)
        self.device = device
        self.pipe = FluxKontextControlPipeline.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
        transformer = FluxTransformer2DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
            device=self.device
        )
        self.pipe.transformer = transformer
        self.pipe.to(self.device, dtype=torch.bfloat16)

        control_lora_config = {
            "local": {
                "path": os.path.join(easycontrol_base_dir, "local_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "removal": {
                "path": os.path.join(easycontrol_base_dir, "removal_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "edge": {
                "path": os.path.join(easycontrol_base_dir, "edge_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "color": {
                "path": os.path.join(easycontrol_base_dir, "color_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
        }
        self.pipe.load_control_loras(control_lora_config)

        # Aux LoRA for foreground mode
        self.aux_lora_weight_name = aux_lora_weight_name
        self.aux_lora_dir = aux_lora_dir
        self.aux_lora_weight = aux_lora_weight
        self.aux_adapter_name = "aux"

        from safetensors.torch import load_file as _sft_load
        aux_path = os.path.join(self.aux_lora_dir, self.aux_lora_weight_name)
        if os.path.isfile(aux_path):
            self.pipe.load_lora_weights(aux_path, adapter_name=self.aux_adapter_name)
            print(f"Loaded aux LoRA: {aux_path}")
            # Ensure aux LoRA is disabled by default; it will be enabled only in foreground_edit
            self._disable_aux_lora()
        else:
            print(f"Aux LoRA not found at {aux_path}, foreground mode will run without it.")


    # gamma is now applied inside the pipeline based on control_dict

    def _tensor_to_pil(self, tensor_image):
        # Converts a ComfyUI-style tensor [1, H, W, 3] to a PIL Image
        return Image.fromarray(np.clip(255. * tensor_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))

    def _pil_to_tensor(self, pil_image):
        # Converts a PIL image to a ComfyUI-style tensor [1, H, W, 3]
        return torch.from_numpy(np.array(pil_image).astype(np.float32) / 255.0).unsqueeze(0)

    def clear_cache(self):
        for name, attn_processor in self.pipe.transformer.attn_processors.items():
            if hasattr(attn_processor, 'bank_kv'):
                attn_processor.bank_kv.clear()
            if hasattr(attn_processor, 'bank_attn'):
                attn_processor.bank_attn = None

    def _enable_aux_lora(self):
        self.pipe.enable_lora()
        self.pipe.set_adapters([self.aux_adapter_name], adapter_weights=[self.aux_lora_weight])
        print(f"Enabled aux LoRA '{self.aux_adapter_name}' with weight {self.aux_lora_weight}")

    def _disable_aux_lora(self):
        self.pipe.disable_lora()
        print("Disabled aux LoRA")

    def _expand_mask(self, mask_tensor: torch.Tensor, expand: int = 0) -> torch.Tensor:
        if expand <= 0:
            return mask_tensor
        expanded = self.mask_processor.expand_mask(mask_tensor, expand=expand, tapered_corners=True)[0]
        return expanded

    def _tensor_mask_to_pil3(self, mask_tensor: torch.Tensor) -> Image.Image:
        mask_01 = torch.clamp(mask_tensor, 0.0, 1.0)
        if mask_01.ndim == 3 and mask_01.shape[-1] == 3:
            mask_01 = mask_01[..., 0]
        if mask_01.ndim == 3 and mask_01.shape[0] == 1:
            mask_01 = mask_01[0]
        pil = self._tensor_to_pil(mask_01.unsqueeze(-1).repeat(1, 1, 3))
        return pil

    def _apply_black_mask(self, image_tensor: torch.Tensor, binary_mask: torch.Tensor) -> Image.Image:
        # image_tensor: [1, H, W, 3] in [0,1]
        # binary_mask: [H, W] or [1, H, W], 1 = mask area (white)
        if binary_mask.ndim == 3:
            binary_mask = binary_mask[0]
        mask_bool = (binary_mask > 0.5)
        img = image_tensor.clone()
        img[0][mask_bool] = 0.0
        return self._tensor_to_pil(img)

    def edge_edit(self,
                  image, colored_image, positive_prompt,
                  base_mask, add_mask, remove_mask,
                  fine_edge,
                  edge_strength, color_strength,
                  seed, steps, cfg):

        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Prepare mask and original image
        original_image_tensor = image.clone()
        original_mask = base_mask
        original_mask = self._expand_mask(original_mask, expand=25)

        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")
        control_dict = {}
        lineart_output = None

        # Determine control type: color or edge
        if not torch.equal(image, colored_image):
            print("Apply color control")
            colored_image_pil = self._tensor_to_pil(colored_image)
            # Create color block condition
            color_image_np = np.array(colored_image_pil)
            downsampled = cv2.resize(color_image_np, (32, 32), interpolation=cv2.INTER_AREA)
            upsampled = cv2.resize(downsampled, (256, 256), interpolation=cv2.INTER_NEAREST)
            color_block = Image.fromarray(upsampled)

            control_dict = {
                "type": "color",
                "spatial_images": [color_block],
                "gammas": [color_strength]
            }
        else:
            print("Apply edge control")
            if fine_edge == "enable":
                lineart_image = self.lineart_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), detect_resolution=1024, style="contour", output_type="pil")
                lineart_output = self._pil_to_tensor(lineart_image)
            else:
                scribble_image = self.scribble_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), safe=True, resolution=512, output_type="pil")
                lineart_output = self._pil_to_tensor(scribble_image)

            if lineart_output is None:
                raise ValueError("Preprocessor failed to generate lineart.")

            # Apply user sketches to the lineart
            add_mask_resized = F.interpolate(add_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)
            remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)

            bool_add_mask_resized = (add_mask_resized > 0.5)
            bool_remove_mask_resized = (remove_mask_resized > 0.5)

            lineart_output[bool_remove_mask_resized] = 0.0
            lineart_output[bool_add_mask_resized] = 1.0

            control_dict = {
                "type": "edge",
                "spatial_images": [self._tensor_to_pil(lineart_output)],
                "gammas": [edge_strength]
            }

        # Prepare debug/output images
        debug_image = lineart_output if lineart_output is not None else self.color_processor.execute(colored_image, resolution=1024)[0]

        # Run inference
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=control_dict,
        ).images[0]

        self.clear_cache()

        # result_pil.save("result_pil.png")
        result_tensor = self._pil_to_tensor(result_pil)
        # final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        final_image = result_tensor
        return (final_image, debug_image, original_mask)

    def object_removal(self,
                       image, positive_prompt,
                       remove_mask,
                       local_strength,
                       seed, steps, cfg):

        generator = torch.Generator(device=self.device).manual_seed(seed)

        original_image_tensor = image.clone()
        original_mask = remove_mask
        original_mask = self._expand_mask(remove_mask, expand=25)

        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")
        # Prepare spatial image: original masked to black in the remove area
        spatial_pil = self._apply_black_mask(image, original_mask)
        # spatial_pil.save("spatial_pil.png")
        # Note: mask is not passed to pipeline; we use it only for blending
        control_dict = {
            "type": "removal",
            "spatial_images": [spatial_pil],
            "gammas": [local_strength]
        }

        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            control_dict=control_dict,
        ).images[0]

        self.clear_cache()

        result_tensor = self._pil_to_tensor(result_pil)
        final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        # final_image = result_tensor
        return (final_image, self._pil_to_tensor(spatial_pil), original_mask)

    def local_edit(self,
                   image, positive_prompt, fill_mask, local_strength,
                   seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        original_image_tensor = image.clone()
        original_mask = self._expand_mask(fill_mask, expand=25)
        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")

        spatial_pil = self._apply_black_mask(image, original_mask)
        # spatial_pil.save("spatial_pil.png")
        control_dict = {
            "type": "local",
            "spatial_images": [spatial_pil],
            "gammas": [local_strength]
        }

        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=control_dict,
        ).images[0]

        self.clear_cache()
        result_tensor = self._pil_to_tensor(result_pil)
        final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        # final_image = result_tensor
        return (final_image, self._pil_to_tensor(spatial_pil), original_mask)

    def foreground_edit(self,
                        merged_image, positive_prompt,
                        add_prop_mask, fill_mask, fix_perspective, grow_size,
                        seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)

        edit_mask = torch.clamp(self._expand_mask(add_prop_mask, expand=grow_size) + fill_mask, 0.0, 1.0)
        final_mask = self._expand_mask(edit_mask, expand=25)
        if fix_perspective == "enable":
            positive_prompt = positive_prompt + " Fix the perspective if necessary."
        # Prepare edited input image: inside edit_mask but outside add_prop_mask set to white
        img = merged_image.clone()
        base_mask = (edit_mask > 0.5)
        add_only = (add_prop_mask <= 0.5) & base_mask  # [1, H, W] bool
        add_only_3 = add_only.squeeze(0).unsqueeze(-1).expand(-1, -1, img.shape[-1])  # [H, W, 3]
        img[0] = torch.where(add_only_3, torch.ones_like(img[0]), img[0])

        image_pil = self._tensor_to_pil(img)
        # image_pil.save("image_pil.png")

        # Enable aux LoRA only for foreground
        self._enable_aux_lora()

        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=None,
        ).images[0]

        # Disable aux LoRA afterwards
        self._disable_aux_lora()

        self.clear_cache()
        final_image = self._pil_to_tensor(result_pil)
        # final_image = self.blender.blend_inpaint(final_image, img, final_mask, kernel=10, sigma=10)[0]
        return (final_image, self._pil_to_tensor(image_pil), edit_mask)

    def kontext_edit(self,
                     image, positive_prompt,
                     seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        image_pil = self._tensor_to_pil(image)

        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=None,
        ).images[0]

        final_image = self._pil_to_tensor(result_pil)
        mask = torch.zeros((1, final_image.shape[1], final_image.shape[2]), dtype=torch.float32, device=final_image.device)
        return (final_image, image, mask)

    def process(self, image, colored_image,
                merged_image, positive_prompt,
                total_mask, add_mask, remove_mask, add_prop_mask, fill_mask,
                fine_edge, fix_perspective, edge_strength, color_strength, local_strength, grow_size,
                seed, steps, cfg, flag="precise_edit"):
        if flag == "foreground":
            return self.foreground_edit(merged_image, positive_prompt, add_prop_mask, fill_mask, fix_perspective, grow_size, seed, steps, cfg)
        elif flag == "local":
            return self.local_edit(image, positive_prompt, fill_mask, local_strength, seed, steps, cfg)
        elif flag == "removal":
            return self.object_removal(image, positive_prompt, remove_mask, local_strength, seed, steps, cfg)
        elif flag == "precise_edit":
            return self.edge_edit(
                image, colored_image, positive_prompt,
                total_mask, add_mask, remove_mask,
                fine_edge,
                edge_strength, color_strength,
                seed, steps, cfg
            )
        elif flag == "kontext":
            return self.kontext_edit(image, positive_prompt, seed, steps, cfg)
        else:
            raise ValueError("Invalid Editing Type: {}".format(flag))
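The monkey patch at the top of edit_space.py rewrites PEFT-style LoRA keys (lora_A/lora_B) into diffusers-style keys (lora.down/lora.up) and prefixes everything with "transformer." before handing the weights to the pipeline. A minimal standalone sketch of that renaming rule is shown below; the module path and tensor shapes in the toy state dict are made up purely for illustration and are not taken from the actual checkpoints.

import torch

# Toy PEFT-style LoRA state dict (hypothetical key path and shapes).
peft_state = {
    "single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}

converted = {}
for key, value in peft_state.items():
    # Step 1: PEFT naming -> diffusers naming.
    new_key = key.replace("lora_A.weight", "lora.down.weight").replace("lora_B.weight", "lora.up.weight")
    # Step 2: add the "transformer." prefix if missing.
    if not new_key.startswith("transformer."):
        new_key = f"transformer.{new_key}"
    converted[new_key] = value

for k in converted:
    print(k)
# transformer.single_transformer_blocks.0.attn.to_q.lora.down.weight
# transformer.single_transformer_blocks.0.attn.to_q.lora.up.weight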
requirements.txt
ADDED
@@ -0,0 +1,28 @@
accelerate
datasets
diffusers
easydict
einops
fastapi
gradio==5.4.0
gradio_client
huggingface_hub
numpy
opencv-python
peft
pillow
protobuf
requests
safetensors
scikit-image
scipy
git+https://github.com/facebookresearch/segment-anything.git
sentencepiece
spaces
torch
torchaudio
torchvision
tqdm
transformers
uvicorn
./gradio_magicquillv2-0.0.1-py3-none-any.whl
src/__init__.py
ADDED
File without changes
src/layers_cache.py
ADDED
@@ -0,0 +1,406 @@
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union, Any, Dict
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from diffusers.models.attention_processor import Attention

TXTLEN = 128
KONTEXT = False

class LoRALinearLayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        cond_widths: Optional[List[int]] = None,
        cond_heights: Optional[List[int]] = None,
        lora_index: int = 0,
        n_loras: int = 1,
    ):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

        self.cond_heights = cond_heights if cond_heights is not None else [512]
        self.cond_widths = cond_widths if cond_widths is not None else [512]
        self.lora_index = lora_index
        self.n_loras = n_loras

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        batch_size = hidden_states.shape[0]

        cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
        total_cond_size = sum(cond_sizes)
        block_size = hidden_states.shape[1] - total_cond_size

        offset = sum(cond_sizes[:self.lora_index])
        current_cond_size = cond_sizes[self.lora_index]

        shape = (batch_size, hidden_states.shape[1], 3072)
        mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)

        mask[:, :block_size + offset, :] = 0
        mask[:, block_size + offset + current_cond_size:, :] = 0

        hidden_states = mask * hidden_states

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class MultiSingleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, ranks: List[int], lora_weights: List[float], network_alphas: List[float], device=None, dtype=None, cond_widths: Optional[List[int]] = None, cond_heights: Optional[List[int]] = None, n_loras=1):
        super().__init__()
        self.n_loras = n_loras
        self.cond_widths = cond_widths if cond_widths is not None else [512]
        self.cond_heights = cond_heights if cond_heights is not None else [512]

        self.q_loras = nn.ModuleList([
            LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
            for i in range(n_loras)
        ])
        self.lora_weights = lora_weights
        self.bank_attn = None
        self.bank_kv: List[torch.Tensor] = []


    def __call__(self,
                 attn: Attention,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: Optional[torch.Tensor] = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 image_rotary_emb: Optional[torch.Tensor] = None,
                 use_cond = False
                 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        scaled_seq_len = hidden_states.shape[1]

        cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
        total_cond_size = sum(cond_sizes)
        block_size = scaled_seq_len - total_cond_size

        scaled_cond_sizes = cond_sizes
        scaled_block_size = block_size

        global TXTLEN
        global KONTEXT
        if KONTEXT:
            img_start, img_end = TXTLEN, (TXTLEN + block_size) // 2
        else:
            img_start, img_end = TXTLEN, block_size
        cond_start, cond_end = block_size, scaled_seq_len

        cache = len(self.bank_kv) == 0

        if cache:
            query = attn.to_q(hidden_states)
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            for i in range(self.n_loras):
                query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
                key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
                value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            self.bank_kv.extend([key[:, :, scaled_block_size:, :], value[:, :, scaled_block_size:, :]])

            if attn.norm_q is not None: query = attn.norm_q(query)
            if attn.norm_k is not None: key = attn.norm_k(key)

            if image_rotary_emb is not None:
                from diffusers.models.embeddings import apply_rotary_emb
                query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)

            mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
            mask[:scaled_block_size, :] = 0

            current_offset = 0
            for i in range(self.n_loras):
                start, end = scaled_block_size + current_offset, scaled_block_size + current_offset + scaled_cond_sizes[i]
                mask[start:end, start:end] = 0
                current_offset += scaled_cond_sizes[i]

            mask *= -1e20

            c_factor = getattr(self, "c_factor", None)
            if c_factor is not None:
                # print(f"Using c_factor: {c_factor}")
                current_offset = 0
                for i in range(self.n_loras):
                    bias = torch.log(c_factor[i])
                    cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
                    mask[img_start:img_end, cond_i_start:cond_i_end] = bias
                    current_offset += scaled_cond_sizes[i]

            # c_factor_kontext = getattr(self, "c_factor_kontext", None)
            # if c_factor_kontext is not None:
            #     bias = torch.log(c_factor_kontext)
            #     kontext_start, kontext_end = img_end, block_size
            #     mask[img_start:img_end, kontext_start:kontext_end] = bias
            #     mask[kontext_start:kontext_end, img_start:img_end] = bias

            #     mask[kontext_start:kontext_end, kontext_end:] = -1e20

            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask.to(query.dtype))
            self.bank_attn = hidden_states[:, :, scaled_block_size:, :]

        else:
            query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)

            inner_dim = query.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            key = torch.cat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2)
            value = torch.cat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2)

            if attn.norm_q is not None: query = attn.norm_q(query)
            if attn.norm_k is not None: key = attn.norm_k(key)

            if image_rotary_emb is not None:
                from diffusers.models.embeddings import apply_rotary_emb
                query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)

            query = query[:, :, :scaled_block_size, :]

            attn_mask = None
            c_factor = getattr(self, "c_factor", None)
            if c_factor is not None:
                # print(f"Using c_factor: {c_factor}")
                attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
                current_offset = 0
                for i in range(self.n_loras):
                    bias = torch.log(c_factor[i])
                    cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
                    attn_mask[img_start:img_end, cond_i_start:cond_i_end] = bias
                    current_offset += scaled_cond_sizes[i]

            # c_factor_kontext = getattr(self, "c_factor_kontext", None)
            # if c_factor_kontext is not None:
            #     if attn_mask is None:
            #         attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
            #     bias = torch.log(c_factor_kontext)
            #     kontext_start, kontext_end = img_end, block_size
            #     attn_mask[img_start:img_end, kontext_start:kontext_end] = bias
|
| 226 |
+
# attn_mask[kontext_start:kontext_end, img_start:img_end] = bias
|
| 227 |
+
|
| 228 |
+
# attn_mask[kontext_start:kontext_end, kontext_end:] = -1e20
|
| 229 |
+
|
| 230 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attn_mask)
|
| 231 |
+
if self.bank_attn is not None: hidden_states = torch.cat([hidden_states, self.bank_attn], dim=-2)
|
| 232 |
+
|
| 233 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 234 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 235 |
+
|
| 236 |
+
cond_hidden_states = hidden_states[:, block_size:,:]
|
| 237 |
+
hidden_states = hidden_states[:, : block_size,:]
|
| 238 |
+
|
| 239 |
+
return (hidden_states, cond_hidden_states) if use_cond else hidden_states
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
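# How the `c_factor` bias above acts on attention, as a minimal sketch. The shapes and
# token counts below are made-up illustration values, not anything computed in this repo:
# the image-token rows of the additive attention mask receive log(c_factor[i]) in the
# columns of the i-th condition, so softmax(QK^T / sqrt(d) + mask) scales the attention
# paid to that condition by roughly c_factor[i].
#
#     import torch
#     c_factor = torch.tensor([1.5])           # hypothetical gamma for one condition
#     mask = torch.zeros(8, 8)                 # toy sequence of 8 tokens
#     img_rows, cond_cols = slice(0, 4), slice(4, 8)
#     mask[img_rows, cond_cols] = torch.log(c_factor)
#     # softmax(scores + mask) now weights the condition tokens ~1.5x for image rows.
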
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, ranks: List[int], lora_weights: List[float], network_alphas: List[float], device=None, dtype=None, cond_widths: Optional[List[int]] = None, cond_heights: Optional[List[int]] = None, n_loras=1):
        super().__init__()

        self.n_loras = n_loras
        self.cond_widths = cond_widths if cond_widths is not None else [512]
        self.cond_heights = cond_heights if cond_heights is not None else [512]
        self.q_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
        self.k_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
        self.v_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
        self.proj_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
        self.lora_weights = lora_weights
        self.bank_attn = None
        self.bank_kv: List[torch.Tensor] = []

    def __call__(self,
                 attn: Attention,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: Optional[torch.Tensor] = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 image_rotary_emb: Optional[torch.Tensor] = None,
                 use_cond=False,
                 ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:

        global TXTLEN
        global KONTEXT
        TXTLEN = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 128

        batch_size, _, _ = hidden_states.shape

        cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
        block_size = hidden_states.shape[1] - sum(cond_sizes)

        scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1]
        scaled_cond_sizes = cond_sizes
        scaled_block_size = scaled_seq_len - sum(scaled_cond_sizes)

        if KONTEXT:
            img_start, img_end = TXTLEN, (TXTLEN + block_size) // 2
        else:
            img_start, img_end = TXTLEN, block_size
        cond_start, cond_end = scaled_block_size, scaled_seq_len

        inner_dim, head_dim = 3072, 3072 // attn.heads

        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        cache = len(self.bank_kv) == 0

        if cache:
            query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
            for i in range(self.n_loras):
                query, key, value = query + self.lora_weights[i] * self.q_loras[i](hidden_states), key + self.lora_weights[i] * self.k_loras[i](hidden_states), value + self.lora_weights[i] * self.v_loras[i](hidden_states)

            query, key, value = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            self.bank_kv.extend([key[:, :, block_size:, :], value[:, :, block_size:, :]])

            if attn.norm_q is not None: query = attn.norm_q(query)
            if attn.norm_k is not None: key = attn.norm_k(key)

            query, key, value = torch.cat([encoder_hidden_states_query_proj, query], dim=2), torch.cat([encoder_hidden_states_key_proj, key], dim=2), torch.cat([encoder_hidden_states_value_proj, value], dim=2)

            if image_rotary_emb is not None:
                from diffusers.models.embeddings import apply_rotary_emb
                query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)

            mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
            mask[:scaled_block_size, :] = 0

            current_offset = 0
            for i in range(self.n_loras):
                start, end = scaled_block_size + current_offset, scaled_block_size + current_offset + scaled_cond_sizes[i]
                mask[start:end, start:end] = 0
                current_offset += scaled_cond_sizes[i]

            mask *= -1e20

            c_factor = getattr(self, "c_factor", None)
            if c_factor is not None:
                # print(f"Using c_factor: {c_factor}")
                current_offset = 0
                for i in range(self.n_loras):
                    bias = torch.log(c_factor[i])
                    cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
                    mask[img_start:img_end, cond_i_start:cond_i_end] = bias
                    current_offset += scaled_cond_sizes[i]

            # c_factor_kontext = getattr(self, "c_factor_kontext", None)
            # if c_factor_kontext is not None:
            #     bias = torch.log(c_factor_kontext)
            #     kontext_start, kontext_end = img_end, block_size
            #     mask[img_start:img_end, kontext_start:kontext_end] = bias
            #     mask[kontext_start:kontext_end, img_start:img_end] = bias
            #     mask[kontext_start:kontext_end, kontext_end:] = -1e20

            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask.to(query.dtype))
            self.bank_attn = hidden_states[:, :, scaled_block_size:, :]

        else:
            query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)

            query, key, value = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            key, value = torch.cat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2), torch.cat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)

            if attn.norm_q is not None: query = attn.norm_q(query)
            if attn.norm_k is not None: key = attn.norm_k(key)

            query, key, value = torch.cat([encoder_hidden_states_query_proj, query], dim=2), torch.cat([encoder_hidden_states_key_proj, key], dim=2), torch.cat([encoder_hidden_states_value_proj, value], dim=2)

            if image_rotary_emb is not None:
                from diffusers.models.embeddings import apply_rotary_emb
                query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)

            query = query[:, :, :scaled_block_size, :]

            attn_mask = None
            c_factor = getattr(self, "c_factor", None)
            if c_factor is not None:
                # print(f"Using c_factor: {c_factor}")
                attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
                current_offset = 0
                for i in range(self.n_loras):
                    bias = torch.log(c_factor[i])
                    cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
                    attn_mask[img_start:img_end, cond_i_start:cond_i_end] = bias
                    current_offset += scaled_cond_sizes[i]

            # c_factor_kontext = getattr(self, "c_factor_kontext", None)
            # if c_factor_kontext is not None:
            #     if attn_mask is None:
            #         attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
            #     bias = torch.log(c_factor_kontext)
            #     kontext_start, kontext_end = img_end, block_size
            #     attn_mask[img_start:img_end, kontext_start:kontext_end] = bias
            #     attn_mask[kontext_start:kontext_end, img_start:img_end] = bias
            #     attn_mask[kontext_start:kontext_end, kontext_end:] = -1e20

            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attn_mask)
            if self.bank_attn is not None: hidden_states = torch.cat([hidden_states, self.bank_attn], dim=-2)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        encoder_hidden_states, hidden_states = hidden_states[:, :encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1]:]

        hidden_states = attn.to_out[0](hidden_states)
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        cond_hidden_states = hidden_states[:, block_size:, :]
        hidden_states = hidden_states[:, :block_size, :]

        return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
src/lora_helper.py
ADDED
@@ -0,0 +1,194 @@
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors.torch import load_file
import re
import torch
from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

device = "cuda"

def load_safetensors(path):
    """Safely loads tensors from a file and maps them to the CPU."""
    return load_file(path, device="cpu")

def get_lora_count_from_checkpoint(checkpoint):
    """
    Infers the number of LoRA modules stored in a checkpoint by inspecting its keys.
    Also prints a sample of keys for debugging.
    """
    lora_indices = set()
    # Regex to find '..._loras.X.' where X is a number.
    indexed_pattern = re.compile(r'._loras\.(\d+)\.')
    found_keys = []

    for key in checkpoint.keys():
        match = indexed_pattern.search(key)
        if match:
            lora_indices.add(int(match.group(1)))
            if len(found_keys) < 5 and key not in found_keys:
                found_keys.append(key)

    if lora_indices:
        lora_count = max(lora_indices) + 1
        print("INFO: Auto-detected indexed LoRA keys in checkpoint.")
        print(f" Found {lora_count} LoRA module(s).")
        print(" Sample keys:", found_keys)
        return lora_count

    # Fallback for legacy, non-indexed checkpoints.
    legacy_found = False
    legacy_key_sample = ""
    for key in checkpoint.keys():
        if '.q_lora.' in key:
            legacy_found = True
            legacy_key_sample = key
            break

    if legacy_found:
        print("INFO: Auto-detected legacy (non-indexed) LoRA keys in checkpoint.")
        print(" Assuming 1 LoRA module.")
        print(" Sample key:", legacy_key_sample)
        return 1

    print("WARNING: No LoRA keys found in the checkpoint.")
    return 0

def get_lora_ranks(checkpoint, num_loras):
    """
    Determines the rank for each LoRA module from the checkpoint.
    It supports both indexed (e.g., 'loras.0') and legacy non-indexed formats.
    """
    ranks = {}

    # First, try to find ranks for all indexed LoRA modules.
    for i in range(num_loras):
        # Find a key that uniquely identifies the i-th LoRA's down projection.
        rank_pattern = re.compile(rf'._loras\.({i})\.down\.weight')
        for k, v in checkpoint.items():
            if rank_pattern.search(k):
                ranks[i] = v.shape[0]
                break

    # If not all ranks were found, there might be legacy keys or a mismatch.
    if len(ranks) != num_loras:
        # Fallback for single, non-indexed LoRA checkpoints.
        if num_loras == 1:
            for k, v in checkpoint.items():
                if ".q_lora.down.weight" in k:
                    return [v.shape[0]]

        # If still unresolved, use the rank of the very first LoRA found as a default for all.
        first_found_rank = next((v.shape[0] for k, v in checkpoint.items() if k.endswith(".down.weight")), None)

        if first_found_rank is None:
            raise ValueError("Could not determine any LoRA rank from the provided checkpoint.")

        # Return a list where missing ranks are filled with the first one found.
        return [ranks.get(i, first_found_rank) for i in range(num_loras)]

    # Return the list of ranks sorted by LoRA index.
    return [ranks[i] for i in range(num_loras)]


def load_checkpoint(local_path):
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    return checkpoint


def prepare_lora_processors(checkpoint, lora_weights, transformer, cond_size, number=None):
    # Ensure processors match the transformer's device and dtype
    try:
        first_param = next(transformer.parameters())
        target_device = first_param.device
        target_dtype = first_param.dtype
    except StopIteration:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        target_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    if number is None:
        number = get_lora_count_from_checkpoint(checkpoint)
        if number == 0:
            return {}

    if lora_weights and len(lora_weights) != number:
        print(f"WARNING: Provided `lora_weights` length ({len(lora_weights)}) differs from detected LoRA count ({number}).")
        final_weights = (lora_weights + [1.0] * number)[:number]
        print(f" Adjusting weights to: {final_weights}")
        lora_weights = final_weights
    elif not lora_weights:
        print(f"INFO: No `lora_weights` provided. Defaulting to weights of 1.0 for all {number} LoRAs.")
        lora_weights = [1.0] * number

    ranks = get_lora_ranks(checkpoint, number)
    print("INFO: Determined ranks for LoRA modules:", ranks)

    cond_widths = cond_size if isinstance(cond_size, list) else [cond_size] * number
    cond_heights = cond_size if isinstance(cond_size, list) else [cond_size] * number

    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))

    # Get all attention processor names from the transformer to iterate over
    for name in transformer.attn_processors.keys():
        match = re.search(r'\.(\d+)\.', name)
        if not match:
            continue
        layer_index = int(match.group(1))

        if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
            lora_state_dicts = {
                key: value for key, value in checkpoint.items()
                if f"transformer_blocks.{layer_index}." in key
            }

            lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights,
                device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
            )

            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"
                lora_prefix_proj = f"{name}.proj_loras.{n}"

                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
                lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.down.weight')
                lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.up.weight')
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)

        elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
            lora_state_dicts = {
                key: value for key, value in checkpoint.items()
                if f"single_transformer_blocks.{layer_index}." in key
            }

            lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights,
                device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
            )

            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"

                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)
    return lora_attn_procs
src/pipeline_flux_kontext_control.py
ADDED
@@ -0,0 +1,1230 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from torchvision.transforms.functional import pad
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from .lora_helper import prepare_lora_processors, load_checkpoint
from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
import re


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

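# Worked example for calculate_shift, plugging in the defaults above (illustrative numbers,
# not values produced elsewhere in this repo): with image_seq_len = 4096,
# m = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4 and b = 0.5 - m * 256 ≈ 0.4567,
# so mu ≈ 4096 * m + b ≈ 1.15, i.e. the schedule shift saturates at max_shift.
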
def prepare_latent_image_ids_(height, width, device, dtype):
    latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]  # y
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]  # x
    return latent_image_ids


def prepare_latent_subject_ids(height, width, device, dtype):
    latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    return latent_image_ids.to(device=device, dtype=dtype)


def resize_position_encoding(
    batch_size, original_height, original_width, target_height, target_width, device, dtype
):
    latent_image_ids = prepare_latent_image_ids_(original_height // 2, original_width // 2, device, dtype)
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    scale_h = original_height / target_height
    scale_w = original_width / target_width
    latent_image_ids_resized = torch.zeros(target_height // 2, target_width // 2, 3, device=device, dtype=dtype)
    latent_image_ids_resized[..., 1] = (
        latent_image_ids_resized[..., 1] + torch.arange(target_height // 2, device=device)[:, None] * scale_h
    )
    latent_image_ids_resized[..., 2] = (
        latent_image_ids_resized[..., 2] + torch.arange(target_width // 2, device=device)[None, :] * scale_w
    )

    cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = (
        latent_image_ids_resized.shape
    )
    cond_latent_image_ids = latent_image_ids_resized.reshape(
        cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
    )
    return latent_image_ids, cond_latent_image_ids


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
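
# Minimal usage sketch for retrieve_timesteps. The sigma schedule below is an assumption
# for illustration only, not the schedule this pipeline actually computes:
#
#     scheduler = FlowMatchEulerDiscreteScheduler()
#     sigmas = np.linspace(1.0, 1 / 28, 28)
#     timesteps, num_steps = retrieve_timesteps(scheduler, device="cuda", sigmas=list(sigmas))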


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class FluxKontextControlPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    The Flux Kontext pipeline for image-to-image and text-to-image generation with control module.

    Reference: https://bfl.ai/announcements/flux-1-kontext-dev

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=None,
            feature_extractor=None,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are packed into 2x2 patches, so use VAE factor multiplied by patch size for image processing
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 128
        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
        self.control_lora_processors: Dict[str, Dict[str, Any]] = {}
        self.control_lora_cond_sizes: Dict[str, Any] = {}
        self.control_lora_weights: Dict[str, Any] = {}
        self.current_control_type: Optional[Union[str, List[str]]] = None

    def load_control_loras(self, lora_config: Dict[str, Dict[str, Any]]):
        """
        Loads and prepares LoRA attention processors for different control types.
        Args:
            lora_config: A dict where keys are control types (e.g., 'edge') and values are dicts
                containing 'path', 'lora_weights', and 'cond_size'.
        """
        for control_type, config in lora_config.items():
            print(f"Loading LoRA for control type: {control_type}")
            checkpoint = load_checkpoint(config["path"])
            processors = prepare_lora_processors(
                checkpoint=checkpoint,
                lora_weights=config["lora_weights"],
                transformer=self.transformer,
                cond_size=config["cond_size"],
                number=len(config["lora_weights"]) if config.get("lora_weights") is not None else None,
            )
            self.control_lora_processors[control_type] = processors
            self.control_lora_cond_sizes[control_type] = config["cond_size"]
            self.control_lora_weights[control_type] = config["lora_weights"]
        print("All control LoRAs loaded and prepared.")

    def _combine_control_loras(self, control_types: List[str]):
        """
        Combines multiple control LoRAs into a single set of attention processors.
        """
        if not control_types:
            return FluxAttnProcessor2_0()

        try:
            first_param = next(self.transformer.parameters())
            target_device = first_param.device
            target_dtype = first_param.dtype
        except StopIteration:
            target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            target_dtype = torch.float32

        combined_procs = {}
        # LoRA weights must come from configuration, not from gammas (which control strength)
        all_lora_weights = []

        # Determine total number of LoRAs and ranks across all control types
        total_loras = 0
        all_ranks = []
        all_cond_sizes = []

        for control_type in control_types:
            procs = self.control_lora_processors.get(control_type)
            if not procs:
                raise ValueError(f"Control type '{control_type}' not loaded.")
            # Collect configured LoRA weights for this control type
            conf_weights = self.control_lora_weights.get(control_type)
            if conf_weights is None:
                raise ValueError(f"Control type '{control_type}' has no configured lora_weights.")
            all_lora_weights.extend(conf_weights)

            # Get n_loras from the first processor
            first_proc = next(iter(procs.values()))
            n_loras_in_control = first_proc.n_loras
            total_loras += n_loras_in_control

            # Correctly get ranks from the processor's LoRA layers
            proc_ranks = [lora.down.weight.shape[0] for lora in first_proc.q_loras]
            all_ranks.extend(proc_ranks)

            cond_size = self.control_lora_cond_sizes[control_type]
            cond_sizes = [cond_size] * n_loras_in_control if not isinstance(cond_size, list) else cond_size
            all_cond_sizes.extend(cond_sizes)

        for name in self.transformer.attn_processors.keys():
            match = re.search(r'\.(\d+)\.', name)
            if not match:
                continue
            layer_index = int(match.group(1))

            if name.startswith("transformer_blocks"):
                new_proc = MultiDoubleStreamBlockLoraProcessor(
                    dim=3072, ranks=all_ranks, network_alphas=all_ranks, lora_weights=all_lora_weights,
                    device=target_device, dtype=target_dtype,
                    cond_widths=all_cond_sizes, cond_heights=all_cond_sizes, n_loras=total_loras
                )
            elif name.startswith("single_transformer_blocks"):
                new_proc = MultiSingleStreamBlockLoraProcessor(
                    dim=3072, ranks=all_ranks, network_alphas=all_ranks, lora_weights=all_lora_weights,
                    device=target_device, dtype=target_dtype,
                    cond_widths=all_cond_sizes, cond_heights=all_cond_sizes, n_loras=total_loras
                )
            else:
                continue

            lora_idx_offset = 0
            for control_type in control_types:
                source_proc = self.control_lora_processors[control_type][name]
                for i in range(source_proc.n_loras):
                    current_lora_idx = lora_idx_offset + i
                    # Copy weights for q, k, v, proj
                    new_proc.q_loras[current_lora_idx].load_state_dict(source_proc.q_loras[i].state_dict())
                    new_proc.k_loras[current_lora_idx].load_state_dict(source_proc.k_loras[i].state_dict())
                    new_proc.v_loras[current_lora_idx].load_state_dict(source_proc.v_loras[i].state_dict())
                    if hasattr(new_proc, 'proj_loras'):
                        new_proc.proj_loras[current_lora_idx].load_state_dict(source_proc.proj_loras[i].state_dict())

                lora_idx_offset += source_proc.n_loras

            combined_procs[name] = new_proc.to(device=target_device, dtype=target_dtype)

        return combined_procs

    def set_gamma_values(self, gammas: List[float]):
        """
        Set gamma values for bias control modulation on current attention processors and attention modules.
        """
        print(f"Setting gamma values to: {gammas}")
        # Resolve device/dtype robustly from model parameters
        try:
            first_param = next(self.transformer.parameters())
            device = first_param.device
            dtype = first_param.dtype
        except StopIteration:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            dtype = torch.float32
        gamma_tensor = torch.tensor(gammas, device=device, dtype=dtype)
        for name, attn_processor in self.transformer.attn_processors.items():
            if hasattr(attn_processor, 'q_loras'):
                setattr(attn_processor, 'c_factor', gamma_tensor)
                # print(f" Set c_factor {gamma_tensor} on processor {name}")

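    # Usage sketch: pass one gamma per active condition LoRA, e.g. `pipe.set_gamma_values([1.2, 0.8])`
    # to boost attention to the first condition and damp the second. The specific values are
    # assumptions for illustration, not defaults of this pipeline.
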
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
|
| 547 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 548 |
+
|
| 549 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 550 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 551 |
+
prompt=prompt,
|
| 552 |
+
device=device,
|
| 553 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 554 |
+
)
|
| 555 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 556 |
+
prompt=prompt_2,
|
| 557 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 558 |
+
max_sequence_length=max_sequence_length,
|
| 559 |
+
device=device,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
if self.text_encoder is not None:
|
| 563 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 564 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 565 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 566 |
+
|
| 567 |
+
if self.text_encoder_2 is not None:
|
| 568 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 569 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 570 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 571 |
+
|
| 572 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 573 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 574 |
+
|
| 575 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
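# Note: only the pooled CLIP output is kept (`pooled_prompt_embeds`); the per-token sequence
# embeddings come from T5 (`prompt_embeds`), and `text_ids` is a zero-filled (seq_len, 3)
# position grid for the text tokens.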
|
| 576 |
+
|
| 577 |
+
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
|
| 578 |
+
def check_inputs(
|
| 579 |
+
self,
|
| 580 |
+
prompt,
|
| 581 |
+
prompt_2,
|
| 582 |
+
height,
|
| 583 |
+
width,
|
| 584 |
+
prompt_embeds=None,
|
| 585 |
+
pooled_prompt_embeds=None,
|
| 586 |
+
callback_on_step_end_tensor_inputs=None,
|
| 587 |
+
max_sequence_length=None,
|
| 588 |
+
):
|
| 589 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 590 |
+
raise ValueError(
|
| 591 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 595 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 596 |
+
):
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
if prompt is not None and prompt_embeds is not None:
|
| 602 |
+
raise ValueError(
|
| 603 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 604 |
+
" only forward one of the two."
|
| 605 |
+
)
|
| 606 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 607 |
+
raise ValueError(
|
| 608 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 609 |
+
" only forward one of the two."
|
| 610 |
+
)
|
| 611 |
+
elif prompt is None and prompt_embeds is None:
|
| 612 |
+
raise ValueError(
|
| 613 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 614 |
+
)
|
| 615 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 616 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 617 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 618 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 619 |
+
|
| 620 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 621 |
+
raise ValueError(
|
| 622 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 626 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 627 |
+
|
| 628 |
+
    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents
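    # Shape walk-through (assuming the usual Flux setup of vae_scale_factor == 8 and 16 latent
    # channels): a 1024x1024 image encodes to a (B, 16, 128, 128) latent; _pack_latents groups
    # 2x2 latent patches into tokens, giving (B, 64 * 64, 64) = (B, 4096, 64); _unpack_latents
    # inverts this before VAE decoding.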
|
| 668 |
+
|
| 669 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 670 |
+
if isinstance(generator, list):
|
| 671 |
+
image_latents = [
|
| 672 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
| 673 |
+
for i in range(image.shape[0])
|
| 674 |
+
]
|
| 675 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 676 |
+
else:
|
| 677 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
| 678 |
+
|
| 679 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 680 |
+
|
| 681 |
+
return image_latents
|
| 682 |
+
|
| 683 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
|
| 684 |
+
def enable_vae_slicing(self):
|
| 685 |
+
r"""
|
| 686 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 687 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 688 |
+
"""
|
| 689 |
+
self.vae.enable_slicing()
|
| 690 |
+
|
| 691 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
|
| 692 |
+
def disable_vae_slicing(self):
|
| 693 |
+
r"""
|
| 694 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 695 |
+
computing decoding in one step.
|
| 696 |
+
"""
|
| 697 |
+
self.vae.disable_slicing()
|
| 698 |
+
|
| 699 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
|
| 700 |
+
def enable_vae_tiling(self):
|
| 701 |
+
r"""
|
| 702 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 703 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 704 |
+
processing larger images.
|
| 705 |
+
"""
|
| 706 |
+
self.vae.enable_tiling()
|
| 707 |
+
|
| 708 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
|
| 709 |
+
def disable_vae_tiling(self):
|
| 710 |
+
r"""
|
| 711 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 712 |
+
computing decoding in one step.
|
| 713 |
+
"""
|
| 714 |
+
self.vae.disable_tiling()
|
| 715 |
+
|
| 716 |
+
def prepare_latents(
|
| 717 |
+
self,
|
| 718 |
+
batch_size,
|
| 719 |
+
num_channels_latents,
|
| 720 |
+
height,
|
| 721 |
+
width,
|
| 722 |
+
dtype,
|
| 723 |
+
device,
|
| 724 |
+
generator,
|
| 725 |
+
image,
|
| 726 |
+
subject_images,
|
| 727 |
+
spatial_images,
|
| 728 |
+
latents=None,
|
| 729 |
+
cond_size=512,
|
| 730 |
+
num_subject_images: int = 0,
|
| 731 |
+
num_spatial_images: int = 0,
|
| 732 |
+
):
|
| 733 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 734 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 735 |
+
height_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
|
| 736 |
+
width_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
|
| 737 |
+
|
| 738 |
+
image_latents = image_ids = None
|
| 739 |
+
image_latent_h = 0 # Initialize to handle case where image is None
|
| 740 |
+
|
| 741 |
+
# Prepare noise latents
|
| 742 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 743 |
+
if latents is None:
|
| 744 |
+
noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 745 |
+
else:
|
| 746 |
+
noise_latents = latents.to(device=device, dtype=dtype)
|
| 747 |
+
|
| 748 |
+
noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
|
| 749 |
+
# print(noise_latents.shape)
|
| 750 |
+
noise_latent_image_ids, cond_latent_image_ids_resized = resize_position_encoding(
|
| 751 |
+
batch_size, height, width, height_cond, width_cond, device, dtype
|
| 752 |
+
)
|
| 753 |
+
# noise IDs are marked with 0 in the first channel
|
| 754 |
+
noise_latent_image_ids[..., 0] = 0
|
| 755 |
+
|
| 756 |
+
cond_latents_to_concat = []
|
| 757 |
+
latents_ids_to_concat = [noise_latent_image_ids]
|
| 758 |
+
|
| 759 |
+
# 1. Prepare `image` (Kontext) latents
|
| 760 |
+
if image is not None:
|
| 761 |
+
image = image.to(device=device, dtype=dtype)
|
| 762 |
+
if image.shape[1] != self.latent_channels:
|
| 763 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 764 |
+
else:
|
| 765 |
+
image_latents = image
|
| 766 |
+
|
| 767 |
+
image_latent_h, image_latent_w = image_latents.shape[2:]
|
| 768 |
+
image_latents = self._pack_latents(
|
| 769 |
+
image_latents, batch_size, num_channels_latents, image_latent_h, image_latent_w
|
| 770 |
+
)
|
| 771 |
+
image_ids = self._prepare_latent_image_ids(
|
| 772 |
+
batch_size, image_latent_h // 2, image_latent_w // 2, device, dtype
|
| 773 |
+
)
|
| 774 |
+
image_ids[..., 0] = 1 # Mark as condition
|
| 775 |
+
latents_ids_to_concat.append(image_ids)
|
| 776 |
+
|
| 777 |
+
# 2. Prepare `subject_images` latents
|
| 778 |
+
if subject_images is not None and num_subject_images > 0:
|
| 779 |
+
subject_images = subject_images.to(device=device, dtype=dtype)
|
| 780 |
+
subject_image_latents = self._encode_vae_image(image=subject_images, generator=generator)
|
| 781 |
+
subject_latent_h, subject_latent_w = subject_image_latents.shape[2:]
|
| 782 |
+
subject_latents = self._pack_latents(
|
| 783 |
+
subject_image_latents, batch_size, num_channels_latents, subject_latent_h, subject_latent_w
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, device, dtype)
|
| 787 |
+
latent_subject_ids[..., 0] = 1
|
| 788 |
+
latent_subject_ids[:, 1] += image_latent_h // 2
|
| 789 |
+
subject_latent_image_ids = torch.cat([latent_subject_ids for _ in range(num_subject_images)], dim=0)
|
| 790 |
+
|
| 791 |
+
cond_latents_to_concat.append(subject_latents)
|
| 792 |
+
latents_ids_to_concat.append(subject_latent_image_ids)
|
| 793 |
+
|
| 794 |
+
# 3. Prepare `spatial_images` latents
|
| 795 |
+
if spatial_images is not None and num_spatial_images > 0:
|
| 796 |
+
spatial_images = spatial_images.to(device=device, dtype=dtype)
|
| 797 |
+
spatial_image_latents = self._encode_vae_image(image=spatial_images, generator=generator)
|
| 798 |
+
spatial_latent_h, spatial_latent_w = spatial_image_latents.shape[2:]
|
| 799 |
+
cond_latents = self._pack_latents(
|
| 800 |
+
spatial_image_latents, batch_size, num_channels_latents, spatial_latent_h, spatial_latent_w
|
| 801 |
+
)
|
| 802 |
+
cond_latent_image_ids_resized[..., 0] = 2 # Mark as condition
|
| 803 |
+
cond_latent_image_ids = torch.cat(
|
| 804 |
+
[cond_latent_image_ids_resized for _ in range(num_spatial_images)], dim=0
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
cond_latents_to_concat.append(cond_latents)
|
| 808 |
+
latents_ids_to_concat.append(cond_latent_image_ids)
|
| 809 |
+
|
| 810 |
+
cond_latents = torch.cat(cond_latents_to_concat, dim=1) if cond_latents_to_concat else None
|
| 811 |
+
latent_image_ids = torch.cat(latents_ids_to_concat, dim=0)
|
| 812 |
+
|
| 813 |
+
return noise_latents, image_latents, cond_latents, latent_image_ids
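# Note: the first channel of the packed position ids acts as a stream tag in this pipeline:
# 0 for the noise latents, 1 for the Kontext image and subject-image latents, and 2 for the
# spatial condition latents; the remaining two channels carry the (row, column) coordinates.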
|
| 814 |
+
|
| 815 |
+
@property
|
| 816 |
+
def guidance_scale(self):
|
| 817 |
+
return self._guidance_scale
|
| 818 |
+
|
| 819 |
+
@property
|
| 820 |
+
def joint_attention_kwargs(self):
|
| 821 |
+
return self._joint_attention_kwargs
|
| 822 |
+
|
| 823 |
+
@property
|
| 824 |
+
def num_timesteps(self):
|
| 825 |
+
return self._num_timesteps
|
| 826 |
+
|
| 827 |
+
@property
|
| 828 |
+
def current_timestep(self):
|
| 829 |
+
return self._current_timestep
|
| 830 |
+
|
| 831 |
+
@property
|
| 832 |
+
def interrupt(self):
|
| 833 |
+
return self._interrupt
|
| 834 |
+
|
| 835 |
+
@torch.no_grad()
|
| 836 |
+
def __call__(
|
| 837 |
+
self,
|
| 838 |
+
image: Optional[PipelineImageInput] = None,
|
| 839 |
+
prompt: Union[str, List[str]] = None,
|
| 840 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 841 |
+
height: Optional[int] = None,
|
| 842 |
+
width: Optional[int] = None,
|
| 843 |
+
num_inference_steps: int = 28,
|
| 844 |
+
sigmas: Optional[List[float]] = None,
|
| 845 |
+
guidance_scale: float = 3.5,
|
| 846 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 847 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 848 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 849 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 850 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 851 |
+
output_type: Optional[str] = "pil",
|
| 852 |
+
return_dict: bool = True,
|
| 853 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 854 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 855 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 856 |
+
max_sequence_length: int = 512,
|
| 857 |
+
cond_size: int = 512,
|
| 858 |
+
control_dict: Optional[Dict[str, Any]] = None,
|
| 859 |
+
):
|
| 860 |
+
r"""
|
| 861 |
+
Function invoked when calling the pipeline for generation.
|
| 862 |
+
|
| 863 |
+
Args:
|
| 864 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 865 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 866 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
| 867 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 868 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
| 869 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 870 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 871 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
|
| 873 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 874 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
|
| 876 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 877 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 878 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 879 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 880 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 881 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 882 |
+
expense of slower inference.
|
| 883 |
+
sigmas (`List[float]`, *optional*):
|
| 884 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 885 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 886 |
+
will be used.
|
| 887 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 888 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 889 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 890 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 891 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 892 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 893 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 894 |
+
The number of images to generate per prompt.
|
| 895 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 896 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 897 |
+
to make generation deterministic.
|
| 898 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 899 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 900 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 901 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 902 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 903 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 904 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 905 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 906 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 907 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 908 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 909 |
+
The output format of the generated image. Choose between
|
| 910 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 911 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 912 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 913 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 914 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 915 |
+
`self.processor` in
|
| 916 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 917 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 918 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 919 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 920 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 921 |
+
`callback_on_step_end_tensor_inputs`.
|
| 922 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 923 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 924 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 925 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 926 |
+
max_sequence_length (`int` defaults to 512):
|
| 927 |
+
Maximum sequence length to use with the `prompt`.
|
| 928 |
+
cond_size (`int`, *optional*, defaults to 512):
|
| 929 |
+
The size for conditioning images.
|
| 930 |
+
|
| 931 |
+
Examples:
|
| 932 |
+
|
| 933 |
+
Returns:
|
| 934 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 935 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 936 |
+
images.
|
| 937 |
+
"""
|
| 938 |
+
|
| 939 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 940 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 941 |
+
|
| 942 |
+
# 1. Check inputs. Raise error if not correct
|
| 943 |
+
self.check_inputs(
|
| 944 |
+
prompt,
|
| 945 |
+
prompt_2,
|
| 946 |
+
height,
|
| 947 |
+
width,
|
| 948 |
+
prompt_embeds=prompt_embeds,
|
| 949 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 950 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 951 |
+
max_sequence_length=max_sequence_length,
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
self._guidance_scale = guidance_scale
|
| 955 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 956 |
+
self._current_timestep = None
|
| 957 |
+
self._interrupt = False
|
| 958 |
+
|
| 959 |
+
# Normalize control_dict to an empty dict so kontext-only inference works without controls
|
| 960 |
+
control_dict = control_dict or {}
|
| 961 |
+
|
| 962 |
+
spatial_images = control_dict.get("spatial_images", [])
|
| 963 |
+
num_spatial_images = len(spatial_images)
|
| 964 |
+
subject_images = control_dict.get("subject_images", [])
|
| 965 |
+
num_subject_images = len(subject_images)
|
| 966 |
+
|
| 967 |
+
        requested_control_type = control_dict.get("type") or None

        # Normalize to list for unified handling
        if requested_control_type and isinstance(requested_control_type, str):
            requested_control_type = [requested_control_type]

        # Revert to default if no control type is requested and a control is active
        if not requested_control_type and self.current_control_type:
            print("Reverting to default attention processors.")
            self.transformer.set_attn_processor(FluxAttnProcessor2_0())
            self.current_control_type = None
        # Switch processors only if the control type(s) have changed
        elif requested_control_type != self.current_control_type:
            if requested_control_type:
                print(f"Switching to LoRA control type(s): {requested_control_type}")
                processors = self._combine_control_loras(requested_control_type)
                self.transformer.set_attn_processor(processors)
                # For cond_size, we assume they are compatible and just use the first one.
                self.cond_size = self.control_lora_cond_sizes[requested_control_type[0]]
            self.current_control_type = requested_control_type
|
| 987 |
+
|
| 988 |
+
# Align cond_size to selected control type (if any)
|
| 989 |
+
if hasattr(self, "cond_size"):
|
| 990 |
+
selected_cond_size = self.cond_size
|
| 991 |
+
if isinstance(selected_cond_size, list) and len(selected_cond_size) > 0:
|
| 992 |
+
cond_size = int(selected_cond_size[0])
|
| 993 |
+
elif isinstance(selected_cond_size, int):
|
| 994 |
+
cond_size = selected_cond_size
|
| 995 |
+
|
| 996 |
+
# Set gamma values simply based on provided control_dict['gammas'].
|
| 997 |
+
if requested_control_type:
|
| 998 |
+
raw_gammas = control_dict.get("gammas", [])
|
| 999 |
+
if not isinstance(raw_gammas, list):
|
| 1000 |
+
raw_gammas = [raw_gammas]
|
| 1001 |
+
# flatten one level
|
| 1002 |
+
flattened_gammas: List[float] = []
|
| 1003 |
+
for g in raw_gammas:
|
| 1004 |
+
if isinstance(g, (list, tuple)):
|
| 1005 |
+
flattened_gammas.extend([float(x) for x in g])
|
| 1006 |
+
else:
|
| 1007 |
+
flattened_gammas.append(float(g))
|
| 1008 |
+
if len(flattened_gammas) > 0:
|
| 1009 |
+
self.set_gamma_values(flattened_gammas)
|
| 1010 |
+
|
| 1011 |
+
# 2. Define call parameters
|
| 1012 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1013 |
+
batch_size = 1
|
| 1014 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1015 |
+
batch_size = len(prompt)
|
| 1016 |
+
else:
|
| 1017 |
+
batch_size = prompt_embeds.shape[0]
|
| 1018 |
+
|
| 1019 |
+
device = self._execution_device
|
| 1020 |
+
|
| 1021 |
+
lora_scale = (
|
| 1022 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 1023 |
+
)
|
| 1024 |
+
(
|
| 1025 |
+
prompt_embeds,
|
| 1026 |
+
pooled_prompt_embeds,
|
| 1027 |
+
text_ids,
|
| 1028 |
+
) = self.encode_prompt(
|
| 1029 |
+
prompt=prompt,
|
| 1030 |
+
prompt_2=prompt_2,
|
| 1031 |
+
prompt_embeds=prompt_embeds,
|
| 1032 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1033 |
+
device=device,
|
| 1034 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1035 |
+
max_sequence_length=max_sequence_length,
|
| 1036 |
+
lora_scale=lora_scale,
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
# 3. Preprocess images
|
| 1040 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 1041 |
+
img = image[0] if isinstance(image, list) else image
|
| 1042 |
+
image_height, image_width = self.image_processor.get_default_height_width(img)
|
| 1043 |
+
aspect_ratio = image_width / image_height
|
| 1044 |
+
# Kontext is trained on specific resolutions, using one of them is recommended
|
| 1045 |
+
_, image_width, image_height = min(
|
| 1046 |
+
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
| 1047 |
+
)
|
| 1048 |
+
multiple_of = self.vae_scale_factor * 2
|
| 1049 |
+
image_width = image_width // multiple_of * multiple_of
|
| 1050 |
+
image_height = image_height // multiple_of * multiple_of
|
| 1051 |
+
image = self.image_processor.resize(image, image_height, image_width)
|
| 1052 |
+
image = self.image_processor.preprocess(image, image_height, image_width)
|
| 1053 |
+
|
| 1054 |
+
if len(subject_images) > 0:
|
| 1055 |
+
subject_image_ls = []
|
| 1056 |
+
for subject_image in subject_images:
|
| 1057 |
+
w, h = subject_image.size[:2]
|
| 1058 |
+
scale = cond_size / max(h, w)
|
| 1059 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 1060 |
+
subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
|
| 1061 |
+
subject_image = subject_image.to(dtype=self.vae.dtype)
|
| 1062 |
+
pad_h = cond_size - subject_image.shape[-2]
|
| 1063 |
+
pad_w = cond_size - subject_image.shape[-1]
|
| 1064 |
+
subject_image = pad(
|
| 1065 |
+
subject_image, padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)), fill=0
|
| 1066 |
+
)
|
| 1067 |
+
subject_image_ls.append(subject_image)
|
| 1068 |
+
subject_images = torch.cat(subject_image_ls, dim=-2)
|
| 1069 |
+
else:
|
| 1070 |
+
subject_images = None
|
| 1071 |
+
|
| 1072 |
+
if len(spatial_images) > 0:
|
| 1073 |
+
condition_image_ls = []
|
| 1074 |
+
for img in spatial_images:
|
| 1075 |
+
condition_image = self.image_processor.preprocess(img, height=cond_size, width=cond_size)
|
| 1076 |
+
condition_image = condition_image.to(dtype=self.vae.dtype)
|
| 1077 |
+
condition_image_ls.append(condition_image)
|
| 1078 |
+
spatial_images = torch.cat(condition_image_ls, dim=-2)
|
| 1079 |
+
else:
|
| 1080 |
+
spatial_images = None
|
| 1081 |
+
|
| 1082 |
+
# 4. Prepare latent variables
|
| 1083 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 1084 |
+
latents, image_latents, cond_latents, latent_image_ids = self.prepare_latents(
|
| 1085 |
+
batch_size * num_images_per_prompt,
|
| 1086 |
+
num_channels_latents,
|
| 1087 |
+
height,
|
| 1088 |
+
width,
|
| 1089 |
+
prompt_embeds.dtype,
|
| 1090 |
+
device,
|
| 1091 |
+
generator,
|
| 1092 |
+
image,
|
| 1093 |
+
subject_images,
|
| 1094 |
+
spatial_images,
|
| 1095 |
+
latents,
|
| 1096 |
+
cond_size,
|
| 1097 |
+
num_subject_images=num_subject_images,
|
| 1098 |
+
num_spatial_images=num_spatial_images,
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
# 5. Prepare timesteps
|
| 1102 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 1103 |
+
# sigmas = np.array([1.0000, 0.9836, 0.9660, 0.9471, 0.9266, 0.9045, 0.8805, 0.8543, 0.8257, 0.7942, 0.7595, 0.7210, 0.6780, 0.6297, 0.5751, 0.5128, 0.4412, 0.3579, 0.2598, 0.1425])
|
| 1104 |
+
image_seq_len = latents.shape[1]
|
| 1105 |
+
mu = calculate_shift(
|
| 1106 |
+
image_seq_len,
|
| 1107 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1108 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1109 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1110 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 1111 |
+
)
|
| 1112 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1113 |
+
self.scheduler,
|
| 1114 |
+
num_inference_steps,
|
| 1115 |
+
device,
|
| 1116 |
+
sigmas=sigmas,
|
| 1117 |
+
mu=mu,
|
| 1118 |
+
)
|
| 1119 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1120 |
+
self._num_timesteps = len(timesteps)
|
| 1121 |
+
|
| 1122 |
+
# handle guidance
|
| 1123 |
+
if self.transformer.config.guidance_embeds:
|
| 1124 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 1125 |
+
guidance = guidance.expand(latents.shape[0])
|
| 1126 |
+
else:
|
| 1127 |
+
guidance = None
|
| 1128 |
+
|
| 1129 |
+
if self.joint_attention_kwargs is None:
|
| 1130 |
+
self._joint_attention_kwargs = {}
|
| 1131 |
+
|
| 1132 |
+
        # K/V Caching
        for name, attn_processor in self.transformer.attn_processors.items():
            if hasattr(attn_processor, "bank_kv"):
                attn_processor.bank_kv.clear()
            if hasattr(attn_processor, "bank_attn"):
                attn_processor.bank_attn = None

        if cond_latents is not None:
            latent_model_input = latents
            if image_latents is not None:
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
            print(latent_model_input.shape)
            warmup_latents = latent_model_input
            warmup_latent_ids = latent_image_ids
            t = torch.tensor([timesteps[0]], device=device)
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            _ = self.transformer(
                hidden_states=warmup_latents,
                cond_hidden_states=cond_latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=warmup_latent_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
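        # The single warm-up forward above runs once at the first timestep with the condition
        # latents attached, so attention processors that maintain a `bank_kv` / `bank_attn`
        # cache (cleared just before) can store the condition keys/values for reuse during
        # the denoising loop below.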
|
| 1160 |
+
|
| 1161 |
+
# 6. Denoising loop
|
| 1162 |
+
self.scheduler.set_begin_index(0)
|
| 1163 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1164 |
+
for i, t in enumerate(timesteps):
|
| 1165 |
+
if self.interrupt:
|
| 1166 |
+
continue
|
| 1167 |
+
|
| 1168 |
+
latent_model_input = latents
|
| 1169 |
+
if image_latents is not None:
|
| 1170 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
|
| 1171 |
+
|
| 1172 |
+
self._current_timestep = t
|
| 1173 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 1174 |
+
noise_pred = self.transformer(
|
| 1175 |
+
hidden_states=latent_model_input,
|
| 1176 |
+
cond_hidden_states=cond_latents,
|
| 1177 |
+
timestep=timestep / 1000,
|
| 1178 |
+
guidance=guidance,
|
| 1179 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1180 |
+
encoder_hidden_states=prompt_embeds,
|
| 1181 |
+
txt_ids=text_ids,
|
| 1182 |
+
img_ids=latent_image_ids,
|
| 1183 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1184 |
+
return_dict=False,
|
| 1185 |
+
)[0]
|
| 1186 |
+
|
| 1187 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 1188 |
+
|
| 1189 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1190 |
+
latents_dtype = latents.dtype
|
| 1191 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1192 |
+
|
| 1193 |
+
if latents.dtype != latents_dtype:
|
| 1194 |
+
if torch.backends.mps.is_available():
|
| 1195 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1196 |
+
latents = latents.to(latents_dtype)
|
| 1197 |
+
|
| 1198 |
+
if callback_on_step_end is not None:
|
| 1199 |
+
callback_kwargs = {}
|
| 1200 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1201 |
+
callback_kwargs[k] = locals()[k]
|
| 1202 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1203 |
+
|
| 1204 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1205 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1206 |
+
|
| 1207 |
+
# call the callback, if provided
|
| 1208 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1209 |
+
progress_bar.update()
|
| 1210 |
+
|
| 1211 |
+
if XLA_AVAILABLE:
|
| 1212 |
+
xm.mark_step()
|
| 1213 |
+
|
| 1214 |
+
self._current_timestep = None
|
| 1215 |
+
|
| 1216 |
+
if output_type == "latent":
|
| 1217 |
+
image = latents
|
| 1218 |
+
else:
|
| 1219 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 1220 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1221 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1222 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1223 |
+
|
| 1224 |
+
# Offload all models
|
| 1225 |
+
self.maybe_free_model_hooks()
|
| 1226 |
+
|
| 1227 |
+
if not return_dict:
|
| 1228 |
+
return (image,)
|
| 1229 |
+
|
| 1230 |
+
return FluxPipelineOutput(images=image)
|
src/transformer_flux.py
ADDED
|
@@ -0,0 +1,608 @@
from typing import Any, Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
| 10 |
+
from diffusers.models.attention import FeedForward
|
| 11 |
+
from diffusers.models.attention_processor import (
|
| 12 |
+
Attention,
|
| 13 |
+
AttentionProcessor,
|
| 14 |
+
FluxAttnProcessor2_0,
|
| 15 |
+
FluxAttnProcessor2_0_NPU,
|
| 16 |
+
FusedFluxAttnProcessor2_0,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 20 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
| 21 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
| 22 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 23 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
| 24 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 27 |
+
|
| 28 |
+
@maybe_allow_in_graph
|
| 29 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 30 |
+
|
| 31 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 34 |
+
|
| 35 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 36 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 37 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 38 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 39 |
+
|
| 40 |
+
if is_torch_npu_available():
|
| 41 |
+
processor = FluxAttnProcessor2_0_NPU()
|
| 42 |
+
else:
|
| 43 |
+
processor = FluxAttnProcessor2_0()
|
| 44 |
+
self.attn = Attention(
|
| 45 |
+
query_dim=dim,
|
| 46 |
+
cross_attention_dim=None,
|
| 47 |
+
dim_head=attention_head_dim,
|
| 48 |
+
heads=num_attention_heads,
|
| 49 |
+
out_dim=dim,
|
| 50 |
+
bias=True,
|
| 51 |
+
processor=processor,
|
| 52 |
+
qk_norm="rms_norm",
|
| 53 |
+
eps=1e-6,
|
| 54 |
+
pre_only=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(
|
| 58 |
+
self,
|
| 59 |
+
hidden_states: torch.Tensor,
|
| 60 |
+
cond_hidden_states: torch.Tensor,
|
| 61 |
+
temb: torch.Tensor,
|
| 62 |
+
cond_temb: torch.Tensor,
|
| 63 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 64 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 65 |
+
) -> torch.Tensor:
|
| 66 |
+
use_cond = cond_hidden_states is not None
|
| 67 |
+
|
| 68 |
+
residual = hidden_states
|
| 69 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 70 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 71 |
+
|
| 72 |
+
if use_cond:
|
| 73 |
+
residual_cond = cond_hidden_states
|
| 74 |
+
norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
|
| 75 |
+
mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
|
| 76 |
+
norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 77 |
+
else:
|
| 78 |
+
norm_hidden_states_concat = norm_hidden_states
|
| 79 |
+
|
| 80 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 81 |
+
if use_cond:
|
| 82 |
+
attn_output = self.attn(
|
| 83 |
+
hidden_states=norm_hidden_states_concat,
|
| 84 |
+
image_rotary_emb=image_rotary_emb,
|
| 85 |
+
use_cond=use_cond,
|
| 86 |
+
**joint_attention_kwargs,
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
attn_output = self.attn(
|
| 90 |
+
hidden_states=norm_hidden_states_concat,
|
| 91 |
+
image_rotary_emb=image_rotary_emb,
|
| 92 |
+
**joint_attention_kwargs,
|
| 93 |
+
)
|
| 94 |
+
if use_cond:
|
| 95 |
+
attn_output, cond_attn_output = attn_output
|
| 96 |
+
|
| 97 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 98 |
+
gate = gate.unsqueeze(1)
|
| 99 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 100 |
+
hidden_states = residual + hidden_states
|
| 101 |
+
|
| 102 |
+
if use_cond:
|
| 103 |
+
condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
|
| 104 |
+
cond_gate = cond_gate.unsqueeze(1)
|
| 105 |
+
condition_latents = cond_gate * self.proj_out(condition_latents)
|
| 106 |
+
condition_latents = residual_cond + condition_latents
|
| 107 |
+
|
| 108 |
+
if hidden_states.dtype == torch.float16:
|
| 109 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 110 |
+
|
| 111 |
+
return hidden_states, condition_latents if use_cond else None
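# Note: when `cond_hidden_states` is provided, the condition tokens pass through the same
# norm/attention/MLP weights as the image tokens but keep their own residual stream and gate,
# so the block returns (hidden_states, condition_latents); otherwise the second element is None.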
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@maybe_allow_in_graph
|
| 115 |
+
class FluxTransformerBlock(nn.Module):
|
| 116 |
+
def __init__(
|
| 117 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
|
| 121 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 122 |
+
|
| 123 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 124 |
+
|
| 125 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 126 |
+
processor = FluxAttnProcessor2_0()
|
| 127 |
+
else:
|
| 128 |
+
raise ValueError(
|
| 129 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 130 |
+
)
|
| 131 |
+
self.attn = Attention(
|
| 132 |
+
query_dim=dim,
|
| 133 |
+
cross_attention_dim=None,
|
| 134 |
+
added_kv_proj_dim=dim,
|
| 135 |
+
dim_head=attention_head_dim,
|
| 136 |
+
heads=num_attention_heads,
|
| 137 |
+
out_dim=dim,
|
| 138 |
+
context_pre_only=False,
|
| 139 |
+
bias=True,
|
| 140 |
+
processor=processor,
|
| 141 |
+
qk_norm=qk_norm,
|
| 142 |
+
eps=eps,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 146 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 147 |
+
|
| 148 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 149 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 150 |
+
|
| 151 |
+
# let chunk size default to None
|
| 152 |
+
self._chunk_size = None
|
| 153 |
+
self._chunk_dim = 0
|
| 154 |
+
|
| 155 |
+
def forward(
|
| 156 |
+
self,
|
| 157 |
+
hidden_states: torch.Tensor,
|
| 158 |
+
cond_hidden_states: torch.Tensor,
|
| 159 |
+
encoder_hidden_states: torch.Tensor,
|
| 160 |
+
temb: torch.Tensor,
|
| 161 |
+
cond_temb: torch.Tensor,
|
| 162 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 163 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 164 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 165 |
+
use_cond = cond_hidden_states is not None
|
| 166 |
+
|
| 167 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 168 |
+
if use_cond:
|
| 169 |
+
(
|
| 170 |
+
norm_cond_hidden_states,
|
| 171 |
+
cond_gate_msa,
|
| 172 |
+
cond_shift_mlp,
|
| 173 |
+
cond_scale_mlp,
|
| 174 |
+
cond_gate_mlp,
|
| 175 |
+
) = self.norm1(cond_hidden_states, emb=cond_temb)
|
| 176 |
+
norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 177 |
+
|
| 178 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 179 |
+
encoder_hidden_states, emb=temb
|
| 180 |
+
)
|
| 181 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 182 |
+
# Attention.
|
| 183 |
+
if use_cond:
|
| 184 |
+
attention_outputs = self.attn(
|
| 185 |
+
hidden_states=norm_hidden_states,
|
| 186 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 187 |
+
image_rotary_emb=image_rotary_emb,
|
| 188 |
+
use_cond=use_cond,
|
| 189 |
+
**joint_attention_kwargs,
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
attention_outputs = self.attn(
|
| 193 |
+
hidden_states=norm_hidden_states,
|
| 194 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 195 |
+
image_rotary_emb=image_rotary_emb,
|
| 196 |
+
**joint_attention_kwargs,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
attn_output, context_attn_output = attention_outputs[:2]
|
| 200 |
+
cond_attn_output = attention_outputs[2] if use_cond else None
|
| 201 |
+
|
| 202 |
+
# Process attention outputs for the `hidden_states`.
|
| 203 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 204 |
+
hidden_states = hidden_states + attn_output
|
| 205 |
+
|
| 206 |
+
if use_cond:
|
| 207 |
+
cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
|
| 208 |
+
cond_hidden_states = cond_hidden_states + cond_attn_output
|
| 209 |
+
|
| 210 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 211 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 212 |
+
|
| 213 |
+
if use_cond:
|
| 214 |
+
norm_cond_hidden_states = self.norm2(cond_hidden_states)
|
| 215 |
+
norm_cond_hidden_states = (
|
| 216 |
+
norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
|
| 217 |
+
+ cond_shift_mlp[:, None]
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
ff_output = self.ff(norm_hidden_states)
|
| 221 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 222 |
+
hidden_states = hidden_states + ff_output
|
| 223 |
+
|
| 224 |
+
if use_cond:
|
| 225 |
+
cond_ff_output = self.ff(norm_cond_hidden_states)
|
| 226 |
+
cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
|
| 227 |
+
cond_hidden_states = cond_hidden_states + cond_ff_output
|
| 228 |
+
|
| 229 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 230 |
+
|
| 231 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 232 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 233 |
+
|
| 234 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 235 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 236 |
+
|
| 237 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 238 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 239 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 240 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 241 |
+
|
| 242 |
+
return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class FluxTransformer2DModel(
|
| 246 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # Make a copy of the processor dictionary to avoid destructive changes to the original.
        if isinstance(processor, dict):
            processor = processor.copy()

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_hidden_states: torch.Tensor = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        if cond_hidden_states is not None:
            use_condition = True
        else:
            use_condition = False

        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)
        if cond_hidden_states is not None:
            if cond_hidden_states.shape[-1] == self.x_embedder.in_features:
                cond_hidden_states = self.x_embedder(cond_hidden_states)
            elif cond_hidden_states.shape[-1] == 64:
                # use only the first 64 input columns of the weight, plus the bias
                weight = self.x_embedder.weight[:, :64]  # [inner_dim, 64]
                bias = self.x_embedder.bias
                cond_hidden_states = torch.nn.functional.linear(cond_hidden_states, weight, bias)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )

        cond_temb = (
            self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
            if guidance is None
            else self.time_text_embed(
                torch.ones_like(timestep) * 0, guidance, pooled_projections
            )
        )

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    cond_temb=cond_temb if use_condition else None,
                    cond_hidden_states=cond_hidden_states if use_condition else None,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states, cond_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cond_hidden_states=cond_hidden_states if use_condition else None,
                    temb=temb,
                    cond_temb=cond_temb if use_condition else None,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    cond_temb=cond_temb if use_condition else None,
                    cond_hidden_states=cond_hidden_states if use_condition else None,
                    **ckpt_kwargs,
                )

            else:
                hidden_states, cond_hidden_states = block(
                    hidden_states=hidden_states,
                    cond_hidden_states=cond_hidden_states if use_condition else None,
                    temb=temb,
                    cond_temb=cond_temb if use_condition else None,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
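The change of note relative to the stock diffusers transformer is the `cond_hidden_states` branch in `forward`: when the condition latents carry only 64 channels while `x_embedder` was built for a wider input, the model projects them with just the first 64 input columns of `x_embedder`'s weight plus its bias. A minimal sketch of that slicing trick, with illustrative dimensions that are not taken from this commit:

# Illustrative sketch only; dimensions are made up, the real model sizes come from its config.
import torch
import torch.nn as nn

inner_dim = 3072                          # 24 heads * 128 head_dim, as in the defaults above
x_embedder = nn.Linear(128, inner_dim)    # pretend in_channels=128 (latents plus extra channels)

cond = torch.randn(1, 4096, 64)           # condition tokens with only 64 channels
weight = x_embedder.weight[:, :64]        # reuse only the first 64 input columns
bias = x_embedder.bias
cond_embedded = torch.nn.functional.linear(cond, weight, bias)
print(cond_embedded.shape)                # torch.Size([1, 4096, 3072])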
train/default_config.yaml
ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
main_process_port: 14121
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 8
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
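This is a standard Accelerate config for a single-machine, 8-process, fp16 run. A hypothetical launch sketch; the training script name below is a placeholder, not a file taken from this commit:

# Hypothetical launch sketch: replace the entry point with your actual training script.
import subprocess

subprocess.run([
    "accelerate", "launch",
    "--config_file", "train/default_config.yaml",
    "train/your_training_script.py",   # placeholder entry point
], check=True)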
train/src/__init__.py
ADDED
File without changes
train/src/condition/edge_extraction.py
ADDED
@@ -0,0 +1,356 @@
import warnings
import cv2
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
import os

from einops import rearrange

from .util import HWC3, nms, safe_step, resize_image_with_pad, common_input_validate, get_intensity_mask, combine_layers

from .pidi import pidinet
from .ted import TED
from .lineart import Generator as LineartGenerator
from .informative_drawing import Generator
from .hed import ControlNetHED_Apache2

from pathlib import Path

from skimage import morphology
import argparse
from tqdm import tqdm


PREPROCESSORS_ROOT = os.getenv("PREPROCESSORS_ROOT", os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), "models/preprocessors"))


class HEDDetector:
    def __init__(self, netNetwork):
        self.netNetwork = netNetwork
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, filename="ControlNetHED.pth"):
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)

        netNetwork = ControlNetHED_Apache2()
        netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
        netNetwork.float().eval()

        return cls(netNetwork)

    def to(self, device):
        self.netNetwork.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image).float().to(self.device)
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map


class CannyDetector:
    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map


class PidiNetDetector:
    def __init__(self, netNetwork):
        self.netNetwork = netNetwork
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, filename="table5_pidinet.pth"):
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)

        netNetwork = pidinet()
        netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()})
        netNetwork.eval()

        return cls(netNetwork)

    def to(self, device):
        self.netNetwork.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, apply_filter=False, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        detected_map = detected_map[:, :, ::-1].copy()
        with torch.no_grad():
            image_pidi = torch.from_numpy(detected_map).float().to(self.device)
            image_pidi = image_pidi / 255.0
            image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
            edge = self.netNetwork(image_pidi)[-1]
            edge = edge.cpu().numpy()
            if apply_filter:
                edge = edge > 0.5
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge[0, 0]

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map


class TEDDetector:
    def __init__(self, model):
        self.model = model
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, filename="7_model.pth"):
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        model = TED()
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()
        return cls(model)

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe_steps=2, upscale_method="INTER_CUBIC", output_type=None, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        H, W, _ = input_image.shape
        with torch.no_grad():
            image_teed = torch.from_numpy(input_image.copy()).float().to(self.device)
            image_teed = rearrange(image_teed, 'h w c -> 1 c h w')
            edges = self.model(image_teed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe_steps != 0:
                edge = safe_step(edge, safe_steps)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = remove_pad(HWC3(edge))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map[..., :3])

        return detected_map


class LineartStandardDetector:
    def __call__(self, input_image=None, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        x = input_image.astype(np.float32)
        g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
        intensity = np.min(g - x, axis=2).clip(0, 255)
        intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
        intensity *= 127
        detected_map = intensity.clip(0, 255).astype(np.uint8)

        detected_map = HWC3(remove_pad(detected_map))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map


class AnyLinePreprocessor:
    def __init__(self, mteed_model, lineart_standard_detector):
        self.device = "cpu"
        self.mteed_model = mteed_model
        self.lineart_standard_detector = lineart_standard_detector

    @classmethod
    def from_pretrained(cls, mteed_filename="MTEED.pth"):
        mteed_model = TEDDetector.from_pretrained(filename=mteed_filename)
        lineart_standard_detector = LineartStandardDetector()
        return cls(mteed_model, lineart_standard_detector)

    def to(self, device):
        self.mteed_model.to(device)
        self.device = device
        return self

    def __call__(self, image, resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
        # Process the image with MTEED model
        mteed_result = self.mteed_model(image, detect_resolution=resolution)

        # Process the image with the lineart standard preprocessor
        lineart_result = self.lineart_standard_detector(image, guassian_sigma=2, intensity_threshold=3, resolution=resolution)

        _lineart_result = get_intensity_mask(lineart_result, lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
        _cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
        _lineart_result = _lineart_result * _cleaned
        _mteed_result = mteed_result

        result = combine_layers(_mteed_result, _lineart_result)
        # print(result.shape)
        return result


class LineartDetector:
    def __init__(self, model, coarse_model):
        self.model = model
        self.model_coarse = coarse_model
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, filename="sk_model.pth", coarse_filename="sk_model2.pth"):
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        coarse_model_path = os.path.join(PREPROCESSORS_ROOT, coarse_filename)

        model = LineartGenerator(3, 1, 3)
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()

        coarse_model = LineartGenerator(3, 1, 3)
        coarse_model.load_state_dict(torch.load(coarse_model_path, map_location="cpu"))
        coarse_model.eval()

        return cls(model, coarse_model)

    def to(self, device):
        self.model.to(device)
        self.model_coarse.to(device)
        self.device = device
        return self

    def __call__(self, input_image, coarse=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        model = self.model_coarse if coarse else self.model
        assert detected_map.ndim == 3
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(self.device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            line = model(image)[0][0]

            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(line)
        detected_map = remove_pad(255 - detected_map)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map


class InformativeDetector:
    def __init__(self, anime_model, contour_model):
        self.anime_model = anime_model
        self.contour_model = contour_model
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, anime_filename="anime_style.pth", contour_filename="contour_style.pth"):
        anime_model_path = os.path.join(PREPROCESSORS_ROOT, anime_filename)
        contour_model_path = os.path.join(PREPROCESSORS_ROOT, contour_filename)

        # create the two Generator models
        anime_model = Generator(3, 1, 3)  # input_nc=3, output_nc=1, n_blocks=3
        anime_model.load_state_dict(torch.load(anime_model_path, map_location="cpu"))
        anime_model.eval()

        contour_model = Generator(3, 1, 3)  # input_nc=3, output_nc=1, n_blocks=3
        contour_model.load_state_dict(torch.load(contour_model_path, map_location="cpu"))
        contour_model.eval()

        return cls(anime_model, contour_model)

    def to(self, device):
        self.anime_model.to(device)
        self.contour_model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, style="anime", detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """
        Extract a sketch.

        Args:
            input_image: input image
            style: "anime" or "contour"
            detect_resolution: detection resolution
            output_type: output type
            upscale_method: upsampling method
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        # select the model for the requested style
        model = self.anime_model if style == "anime" else self.contour_model

        assert detected_map.ndim == 3
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(self.device)
            image = image / 255.0
            # reshape (h, w, c) -> (1, c, h, w)
            image = image.permute(2, 0, 1).unsqueeze(0)

            # generate the sketch
            sketch = model(image)
            sketch = sketch[0][0]  # first channel of the first batch element

            sketch = sketch.cpu().numpy()
            sketch = (sketch * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(sketch)
        detected_map = remove_pad(255 - detected_map)  # invert colors

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
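All detectors follow the same pattern: construct (or `from_pretrained` for the learned ones), optionally move to a device, then call with an HxWx3 uint8 RGB array. A usage sketch, assuming the preprocessor checkpoints exist under `PREPROCESSORS_ROOT`; the import path is a guess and should be adjusted to your package layout:

# Usage sketch; `img` is a stand-in input image, the import path is an assumption.
import numpy as np
from edge_extraction import CannyDetector, AnyLinePreprocessor  # adjust to your layout

img = np.zeros((512, 512, 3), dtype=np.uint8)

canny_map = CannyDetector()(img, low_threshold=100, high_threshold=200)

anyline = AnyLinePreprocessor.from_pretrained(mteed_filename="MTEED.pth").to("cpu")
anyline_map = anyline(img, resolution=512)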
train/src/condition/hed.py
ADDED
@@ -0,0 +1,56 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
# and in this way it works better for gradio's RGB protocol

import os
import warnings

import cv2
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from .util import HWC3, nms, resize_image_with_pad, safe_step, common_input_validate


class DoubleConvBlock(torch.nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5
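For orientation, the five side outputs of `ControlNetHED_Apache2` come out at strides 1, 2, 4, 8 and 16 (only `block1` skips the pooling step). A quick shape check with untrained weights, intended to be run where the class above is in scope:

# Illustrative shape check only; random input, untrained weights.
import torch

net = ControlNetHED_Apache2().eval()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 256, 256))
print([tuple(o.shape) for o in outs])
# [(1, 1, 256, 256), (1, 1, 128, 128), (1, 1, 64, 64), (1, 1, 32, 32), (1, 1, 16, 16)]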
train/src/condition/informative_drawing.py
ADDED
@@ -0,0 +1,279 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
import functools
from torchvision import models
from torch.autograd import Variable
import numpy as np
import math

norm_layer = nn.InstanceNorm2d

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, 64, 7),
                  norm_layer(64),
                  nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features//2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features//2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out

# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

class GlobalGenerator2(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', use_sig=False, n_UPsampling=0):
        assert(n_blocks >= 0)
        super(GlobalGenerator2, self).__init__()
        activation = nn.ReLU(True)

        mult = 8
        model = [nn.ReflectionPad2d(4), nn.Conv2d(input_nc, ngf*mult, kernel_size=7, padding=0), norm_layer(ngf*mult), activation]

        ### downsample
        for i in range(n_downsampling):
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=4, stride=2, padding=1),
                      norm_layer(ngf * mult // 2), activation]
            mult = mult // 2

        if n_UPsampling <= 0:
            n_UPsampling = n_downsampling

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]

        ### upsample
        for i in range(n_UPsampling):
            next_mult = mult // 2
            if next_mult == 0:
                next_mult = 1
                mult = 1

            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * next_mult), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * next_mult)), activation]
            mult = next_mult

        if use_sig:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
        else:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, cond=None):
        return self.model(input)


class InceptionV3(nn.Module):  # avg pool
    def __init__(self, num_classes, isTrain, use_aux=True, pretrain=False, freeze=True, every_feat=False):
        super(InceptionV3, self).__init__()
        """ Inception v3 expects (299,299) sized images for training and has auxiliary output
        """

        self.every_feat = every_feat

        self.model_ft = models.inception_v3(pretrained=pretrain)
        stop = 0
        if freeze and pretrain:
            for child in self.model_ft.children():
                if stop < 17:
                    for param in child.parameters():
                        param.requires_grad = False
                stop += 1

        num_ftrs = self.model_ft.AuxLogits.fc.in_features  # 768
        self.model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)

        # Handle the primary net
        num_ftrs = self.model_ft.fc.in_features  # 2048
        self.model_ft.fc = nn.Linear(num_ftrs, num_classes)

        self.model_ft.input_size = 299

        self.isTrain = isTrain
        self.use_aux = use_aux

        if self.isTrain:
            self.model_ft.train()
        else:
            self.model_ft.eval()

    def forward(self, x, cond=None, catch_gates=False):
        # N x 3 x 299 x 299
        x = self.model_ft.Conv2d_1a_3x3(x)

        # N x 32 x 149 x 149
        x = self.model_ft.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.model_ft.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.model_ft.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.model_ft.Conv2d_4a_3x3(x)

        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.model_ft.Mixed_5b(x)
        feat1 = x
        # N x 256 x 35 x 35
        x = self.model_ft.Mixed_5c(x)
        feat11 = x
        # N x 288 x 35 x 35
        x = self.model_ft.Mixed_5d(x)
        feat12 = x
        # N x 288 x 35 x 35
        x = self.model_ft.Mixed_6a(x)
        feat2 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6b(x)
        feat21 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6c(x)
        feat22 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6d(x)
        feat23 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6e(x)

        feat3 = x

        # N x 768 x 17 x 17
        aux_defined = self.isTrain and self.use_aux
        if aux_defined:
            aux = self.model_ft.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.model_ft.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.model_ft.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        feats = F.dropout(x, training=self.isTrain)
        # N x 2048 x 1 x 1
        x = torch.flatten(feats, 1)
        # N x 2048
        x = self.model_ft.fc(x)
        # N x 1000 (num_classes)

        if self.every_feat:
            # return feat21, feats, x
            return x, feat21

        return x, aux
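As a quick sanity check on the sketch generator, `Generator(3, 1, 3)` maps an RGB image to a single-channel drawing at the input resolution (two stride-2 downsamples, residual blocks, two stride-2 upsamples, sigmoid output). An illustrative run with untrained weights, assuming the class above is in scope:

# Illustrative shape check only; untrained weights, random input.
import torch

gen = Generator(input_nc=3, output_nc=1, n_residual_blocks=3).eval()
with torch.no_grad():
    out = gen(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1, 256, 256])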
train/src/condition/lineart.py
ADDED
@@ -0,0 +1,86 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
import functools
from torchvision import models
from torch.autograd import Variable
import numpy as np
import math

norm_layer = nn.InstanceNorm2d

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, 64, 7),
                  norm_layer(64),
                  nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features*2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features//2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features//2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out
train/src/condition/pidi.py
ADDED
@@ -0,0 +1,681 @@
"""
Author: Zhuo Su, Wenzhe Liu
Date: Feb 18, 2021
"""

import math

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)

nets = {
    'baseline': {
        'layer0': 'cv', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'cv', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'cv', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'cv', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'c-v15': {
        'layer0': 'cd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'cv', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'cv', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'cv', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'a-v15': {
        'layer0': 'ad', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'cv', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'cv', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'cv', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'r-v15': {
        'layer0': 'rd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'cv', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'cv', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'cv', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'cvvv4': {
        'layer0': 'cd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'cd', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'cd', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'cd', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'avvv4': {
        'layer0': 'ad', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'ad', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'ad', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'ad', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'rvvv4': {
        'layer0': 'rd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv',
        'layer4': 'rd', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv',
        'layer8': 'rd', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv',
        'layer12': 'rd', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv',
    },
    'cccv4': {
        'layer0': 'cd', 'layer1': 'cd', 'layer2': 'cd', 'layer3': 'cv',
        'layer4': 'cd', 'layer5': 'cd', 'layer6': 'cd', 'layer7': 'cv',
        'layer8': 'cd', 'layer9': 'cd', 'layer10': 'cd', 'layer11': 'cv',
        'layer12': 'cd', 'layer13': 'cd', 'layer14': 'cd', 'layer15': 'cv',
    },
    'aaav4': {
        'layer0': 'ad', 'layer1': 'ad', 'layer2': 'ad', 'layer3': 'cv',
        'layer4': 'ad', 'layer5': 'ad', 'layer6': 'ad', 'layer7': 'cv',
        'layer8': 'ad', 'layer9': 'ad', 'layer10': 'ad', 'layer11': 'cv',
        'layer12': 'ad', 'layer13': 'ad', 'layer14': 'ad', 'layer15': 'cv',
    },
    'rrrv4': {
        'layer0': 'rd',
        'layer1': 'rd',
        'layer2': 'rd',
        'layer3': 'cv',
'layer4': 'cd',
|
| 122 |
+
'layer5': 'cv',
|
| 123 |
+
'layer6': 'cv',
|
| 124 |
+
'layer7': 'cv',
|
| 125 |
+
'layer8': 'cd',
|
| 126 |
+
'layer9': 'cv',
|
| 127 |
+
'layer10': 'cv',
|
| 128 |
+
'layer11': 'cv',
|
| 129 |
+
'layer12': 'cd',
|
| 130 |
+
'layer13': 'cv',
|
| 131 |
+
'layer14': 'cv',
|
| 132 |
+
'layer15': 'cv',
|
| 133 |
+
},
|
| 134 |
+
'avvv4': {
|
| 135 |
+
'layer0': 'ad',
|
| 136 |
+
'layer1': 'cv',
|
| 137 |
+
'layer2': 'cv',
|
| 138 |
+
'layer3': 'cv',
|
| 139 |
+
'layer4': 'ad',
|
| 140 |
+
'layer5': 'cv',
|
| 141 |
+
'layer6': 'cv',
|
| 142 |
+
'layer7': 'cv',
|
| 143 |
+
'layer8': 'ad',
|
| 144 |
+
'layer9': 'cv',
|
| 145 |
+
'layer10': 'cv',
|
| 146 |
+
'layer11': 'cv',
|
| 147 |
+
'layer12': 'ad',
|
| 148 |
+
'layer13': 'cv',
|
| 149 |
+
'layer14': 'cv',
|
| 150 |
+
'layer15': 'cv',
|
| 151 |
+
},
|
| 152 |
+
'rvvv4': {
|
| 153 |
+
'layer0': 'rd',
|
| 154 |
+
'layer1': 'cv',
|
| 155 |
+
'layer2': 'cv',
|
| 156 |
+
'layer3': 'cv',
|
| 157 |
+
'layer4': 'rd',
|
| 158 |
+
'layer5': 'cv',
|
| 159 |
+
'layer6': 'cv',
|
| 160 |
+
'layer7': 'cv',
|
| 161 |
+
'layer8': 'rd',
|
| 162 |
+
'layer9': 'cv',
|
| 163 |
+
'layer10': 'cv',
|
| 164 |
+
'layer11': 'cv',
|
| 165 |
+
'layer12': 'rd',
|
| 166 |
+
'layer13': 'cv',
|
| 167 |
+
'layer14': 'cv',
|
| 168 |
+
'layer15': 'cv',
|
| 169 |
+
},
|
| 170 |
+
'cccv4': {
|
| 171 |
+
'layer0': 'cd',
|
| 172 |
+
'layer1': 'cd',
|
| 173 |
+
'layer2': 'cd',
|
| 174 |
+
'layer3': 'cv',
|
| 175 |
+
'layer4': 'cd',
|
| 176 |
+
'layer5': 'cd',
|
| 177 |
+
'layer6': 'cd',
|
| 178 |
+
'layer7': 'cv',
|
| 179 |
+
'layer8': 'cd',
|
| 180 |
+
'layer9': 'cd',
|
| 181 |
+
'layer10': 'cd',
|
| 182 |
+
'layer11': 'cv',
|
| 183 |
+
'layer12': 'cd',
|
| 184 |
+
'layer13': 'cd',
|
| 185 |
+
'layer14': 'cd',
|
| 186 |
+
'layer15': 'cv',
|
| 187 |
+
},
|
| 188 |
+
'aaav4': {
|
| 189 |
+
'layer0': 'ad',
|
| 190 |
+
'layer1': 'ad',
|
| 191 |
+
'layer2': 'ad',
|
| 192 |
+
'layer3': 'cv',
|
| 193 |
+
'layer4': 'ad',
|
| 194 |
+
'layer5': 'ad',
|
| 195 |
+
'layer6': 'ad',
|
| 196 |
+
'layer7': 'cv',
|
| 197 |
+
'layer8': 'ad',
|
| 198 |
+
'layer9': 'ad',
|
| 199 |
+
'layer10': 'ad',
|
| 200 |
+
'layer11': 'cv',
|
| 201 |
+
'layer12': 'ad',
|
| 202 |
+
'layer13': 'ad',
|
| 203 |
+
'layer14': 'ad',
|
| 204 |
+
'layer15': 'cv',
|
| 205 |
+
},
|
| 206 |
+
'rrrv4': {
|
| 207 |
+
'layer0': 'rd',
|
| 208 |
+
'layer1': 'rd',
|
| 209 |
+
'layer2': 'rd',
|
| 210 |
+
'layer3': 'cv',
|
| 211 |
+
'layer4': 'rd',
|
| 212 |
+
'layer5': 'rd',
|
| 213 |
+
'layer6': 'rd',
|
| 214 |
+
'layer7': 'cv',
|
| 215 |
+
'layer8': 'rd',
|
| 216 |
+
'layer9': 'rd',
|
| 217 |
+
'layer10': 'rd',
|
| 218 |
+
'layer11': 'cv',
|
| 219 |
+
'layer12': 'rd',
|
| 220 |
+
'layer13': 'rd',
|
| 221 |
+
'layer14': 'rd',
|
| 222 |
+
'layer15': 'cv',
|
| 223 |
+
},
|
| 224 |
+
'c16': {
|
| 225 |
+
'layer0': 'cd',
|
| 226 |
+
'layer1': 'cd',
|
| 227 |
+
'layer2': 'cd',
|
| 228 |
+
'layer3': 'cd',
|
| 229 |
+
'layer4': 'cd',
|
| 230 |
+
'layer5': 'cd',
|
| 231 |
+
'layer6': 'cd',
|
| 232 |
+
'layer7': 'cd',
|
| 233 |
+
'layer8': 'cd',
|
| 234 |
+
'layer9': 'cd',
|
| 235 |
+
'layer10': 'cd',
|
| 236 |
+
'layer11': 'cd',
|
| 237 |
+
'layer12': 'cd',
|
| 238 |
+
'layer13': 'cd',
|
| 239 |
+
'layer14': 'cd',
|
| 240 |
+
'layer15': 'cd',
|
| 241 |
+
},
|
| 242 |
+
'a16': {
|
| 243 |
+
'layer0': 'ad',
|
| 244 |
+
'layer1': 'ad',
|
| 245 |
+
'layer2': 'ad',
|
| 246 |
+
'layer3': 'ad',
|
| 247 |
+
'layer4': 'ad',
|
| 248 |
+
'layer5': 'ad',
|
| 249 |
+
'layer6': 'ad',
|
| 250 |
+
'layer7': 'ad',
|
| 251 |
+
'layer8': 'ad',
|
| 252 |
+
'layer9': 'ad',
|
| 253 |
+
'layer10': 'ad',
|
| 254 |
+
'layer11': 'ad',
|
| 255 |
+
'layer12': 'ad',
|
| 256 |
+
'layer13': 'ad',
|
| 257 |
+
'layer14': 'ad',
|
| 258 |
+
'layer15': 'ad',
|
| 259 |
+
},
|
| 260 |
+
'r16': {
|
| 261 |
+
'layer0': 'rd',
|
| 262 |
+
'layer1': 'rd',
|
| 263 |
+
'layer2': 'rd',
|
| 264 |
+
'layer3': 'rd',
|
| 265 |
+
'layer4': 'rd',
|
| 266 |
+
'layer5': 'rd',
|
| 267 |
+
'layer6': 'rd',
|
| 268 |
+
'layer7': 'rd',
|
| 269 |
+
'layer8': 'rd',
|
| 270 |
+
'layer9': 'rd',
|
| 271 |
+
'layer10': 'rd',
|
| 272 |
+
'layer11': 'rd',
|
| 273 |
+
'layer12': 'rd',
|
| 274 |
+
'layer13': 'rd',
|
| 275 |
+
'layer14': 'rd',
|
| 276 |
+
'layer15': 'rd',
|
| 277 |
+
},
|
| 278 |
+
'carv4': {
|
| 279 |
+
'layer0': 'cd',
|
| 280 |
+
'layer1': 'ad',
|
| 281 |
+
'layer2': 'rd',
|
| 282 |
+
'layer3': 'cv',
|
| 283 |
+
'layer4': 'cd',
|
| 284 |
+
'layer5': 'ad',
|
| 285 |
+
'layer6': 'rd',
|
| 286 |
+
'layer7': 'cv',
|
| 287 |
+
'layer8': 'cd',
|
| 288 |
+
'layer9': 'ad',
|
| 289 |
+
'layer10': 'rd',
|
| 290 |
+
'layer11': 'cv',
|
| 291 |
+
'layer12': 'cd',
|
| 292 |
+
'layer13': 'ad',
|
| 293 |
+
'layer14': 'rd',
|
| 294 |
+
'layer15': 'cv',
|
| 295 |
+
},
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
def createConvFunc(op_type):
|
| 299 |
+
assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
|
| 300 |
+
if op_type == 'cv':
|
| 301 |
+
return F.conv2d
|
| 302 |
+
|
| 303 |
+
if op_type == 'cd':
|
| 304 |
+
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 305 |
+
assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
|
| 306 |
+
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
|
| 307 |
+
assert padding == dilation, 'padding for cd_conv set wrong'
|
| 308 |
+
|
| 309 |
+
weights_c = weights.sum(dim=[2, 3], keepdim=True)
|
| 310 |
+
yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
|
| 311 |
+
y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 312 |
+
return y - yc
|
| 313 |
+
return func
|
| 314 |
+
elif op_type == 'ad':
|
| 315 |
+
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 316 |
+
assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
|
| 317 |
+
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
|
| 318 |
+
assert padding == dilation, 'padding for ad_conv set wrong'
|
| 319 |
+
|
| 320 |
+
shape = weights.shape
|
| 321 |
+
weights = weights.view(shape[0], shape[1], -1)
|
| 322 |
+
weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
|
| 323 |
+
y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 324 |
+
return y
|
| 325 |
+
return func
|
| 326 |
+
elif op_type == 'rd':
|
| 327 |
+
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 328 |
+
assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
|
| 329 |
+
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
|
| 330 |
+
padding = 2 * dilation
|
| 331 |
+
|
| 332 |
+
shape = weights.shape
|
| 333 |
+
if weights.is_cuda:
|
| 334 |
+
buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
|
| 335 |
+
else:
|
| 336 |
+
buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device)
|
| 337 |
+
weights = weights.view(shape[0], shape[1], -1)
|
| 338 |
+
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
|
| 339 |
+
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
|
| 340 |
+
buffer[:, :, 12] = 0
|
| 341 |
+
buffer = buffer.view(shape[0], shape[1], 5, 5)
|
| 342 |
+
y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 343 |
+
return y
|
| 344 |
+
return func
|
| 345 |
+
else:
|
| 346 |
+
print('impossible to be here unless you force that')
|
| 347 |
+
return None
|
| 348 |
+
|
| 349 |
+
class Conv2d(nn.Module):
|
| 350 |
+
def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
|
| 351 |
+
super(Conv2d, self).__init__()
|
| 352 |
+
if in_channels % groups != 0:
|
| 353 |
+
raise ValueError('in_channels must be divisible by groups')
|
| 354 |
+
if out_channels % groups != 0:
|
| 355 |
+
raise ValueError('out_channels must be divisible by groups')
|
| 356 |
+
self.in_channels = in_channels
|
| 357 |
+
self.out_channels = out_channels
|
| 358 |
+
self.kernel_size = kernel_size
|
| 359 |
+
self.stride = stride
|
| 360 |
+
self.padding = padding
|
| 361 |
+
self.dilation = dilation
|
| 362 |
+
self.groups = groups
|
| 363 |
+
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
|
| 364 |
+
if bias:
|
| 365 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
| 366 |
+
else:
|
| 367 |
+
self.register_parameter('bias', None)
|
| 368 |
+
self.reset_parameters()
|
| 369 |
+
self.pdc = pdc
|
| 370 |
+
|
| 371 |
+
def reset_parameters(self):
|
| 372 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 373 |
+
if self.bias is not None:
|
| 374 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
| 375 |
+
bound = 1 / math.sqrt(fan_in)
|
| 376 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
| 377 |
+
|
| 378 |
+
def forward(self, input):
|
| 379 |
+
|
| 380 |
+
return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
| 381 |
+
|
| 382 |
+
class CSAM(nn.Module):
|
| 383 |
+
"""
|
| 384 |
+
Compact Spatial Attention Module
|
| 385 |
+
"""
|
| 386 |
+
def __init__(self, channels):
|
| 387 |
+
super(CSAM, self).__init__()
|
| 388 |
+
|
| 389 |
+
mid_channels = 4
|
| 390 |
+
self.relu1 = nn.ReLU()
|
| 391 |
+
self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
|
| 392 |
+
self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
|
| 393 |
+
self.sigmoid = nn.Sigmoid()
|
| 394 |
+
nn.init.constant_(self.conv1.bias, 0)
|
| 395 |
+
|
| 396 |
+
def forward(self, x):
|
| 397 |
+
y = self.relu1(x)
|
| 398 |
+
y = self.conv1(y)
|
| 399 |
+
y = self.conv2(y)
|
| 400 |
+
y = self.sigmoid(y)
|
| 401 |
+
|
| 402 |
+
return x * y
|
| 403 |
+
|
| 404 |
+
class CDCM(nn.Module):
|
| 405 |
+
"""
|
| 406 |
+
Compact Dilation Convolution based Module
|
| 407 |
+
"""
|
| 408 |
+
def __init__(self, in_channels, out_channels):
|
| 409 |
+
super(CDCM, self).__init__()
|
| 410 |
+
|
| 411 |
+
self.relu1 = nn.ReLU()
|
| 412 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
| 413 |
+
self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
|
| 414 |
+
self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
|
| 415 |
+
self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
|
| 416 |
+
self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
|
| 417 |
+
nn.init.constant_(self.conv1.bias, 0)
|
| 418 |
+
|
| 419 |
+
def forward(self, x):
|
| 420 |
+
x = self.relu1(x)
|
| 421 |
+
x = self.conv1(x)
|
| 422 |
+
x1 = self.conv2_1(x)
|
| 423 |
+
x2 = self.conv2_2(x)
|
| 424 |
+
x3 = self.conv2_3(x)
|
| 425 |
+
x4 = self.conv2_4(x)
|
| 426 |
+
return x1 + x2 + x3 + x4
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class MapReduce(nn.Module):
|
| 430 |
+
"""
|
| 431 |
+
Reduce feature maps into a single edge map
|
| 432 |
+
"""
|
| 433 |
+
def __init__(self, channels):
|
| 434 |
+
super(MapReduce, self).__init__()
|
| 435 |
+
self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
|
| 436 |
+
nn.init.constant_(self.conv.bias, 0)
|
| 437 |
+
|
| 438 |
+
def forward(self, x):
|
| 439 |
+
return self.conv(x)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class PDCBlock(nn.Module):
|
| 443 |
+
def __init__(self, pdc, inplane, ouplane, stride=1):
|
| 444 |
+
super(PDCBlock, self).__init__()
|
| 445 |
+
self.stride=stride
|
| 446 |
+
|
| 447 |
+
self.stride=stride
|
| 448 |
+
if self.stride > 1:
|
| 449 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 450 |
+
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
|
| 451 |
+
self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
|
| 452 |
+
self.relu2 = nn.ReLU()
|
| 453 |
+
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
| 454 |
+
|
| 455 |
+
def forward(self, x):
|
| 456 |
+
if self.stride > 1:
|
| 457 |
+
x = self.pool(x)
|
| 458 |
+
y = self.conv1(x)
|
| 459 |
+
y = self.relu2(y)
|
| 460 |
+
y = self.conv2(y)
|
| 461 |
+
if self.stride > 1:
|
| 462 |
+
x = self.shortcut(x)
|
| 463 |
+
y = y + x
|
| 464 |
+
return y
|
| 465 |
+
|
| 466 |
+
class PDCBlock_converted(nn.Module):
|
| 467 |
+
"""
|
| 468 |
+
CPDC, APDC can be converted to vanilla 3x3 convolution
|
| 469 |
+
RPDC can be converted to vanilla 5x5 convolution
|
| 470 |
+
"""
|
| 471 |
+
def __init__(self, pdc, inplane, ouplane, stride=1):
|
| 472 |
+
super(PDCBlock_converted, self).__init__()
|
| 473 |
+
self.stride=stride
|
| 474 |
+
|
| 475 |
+
if self.stride > 1:
|
| 476 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 477 |
+
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
|
| 478 |
+
if pdc == 'rd':
|
| 479 |
+
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
|
| 480 |
+
else:
|
| 481 |
+
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
|
| 482 |
+
self.relu2 = nn.ReLU()
|
| 483 |
+
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
| 484 |
+
|
| 485 |
+
def forward(self, x):
|
| 486 |
+
if self.stride > 1:
|
| 487 |
+
x = self.pool(x)
|
| 488 |
+
y = self.conv1(x)
|
| 489 |
+
y = self.relu2(y)
|
| 490 |
+
y = self.conv2(y)
|
| 491 |
+
if self.stride > 1:
|
| 492 |
+
x = self.shortcut(x)
|
| 493 |
+
y = y + x
|
| 494 |
+
return y
|
| 495 |
+
|
| 496 |
+
class PiDiNet(nn.Module):
|
| 497 |
+
def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
|
| 498 |
+
super(PiDiNet, self).__init__()
|
| 499 |
+
self.sa = sa
|
| 500 |
+
if dil is not None:
|
| 501 |
+
assert isinstance(dil, int), 'dil should be an int'
|
| 502 |
+
self.dil = dil
|
| 503 |
+
|
| 504 |
+
self.fuseplanes = []
|
| 505 |
+
|
| 506 |
+
self.inplane = inplane
|
| 507 |
+
if convert:
|
| 508 |
+
if pdcs[0] == 'rd':
|
| 509 |
+
init_kernel_size = 5
|
| 510 |
+
init_padding = 2
|
| 511 |
+
else:
|
| 512 |
+
init_kernel_size = 3
|
| 513 |
+
init_padding = 1
|
| 514 |
+
self.init_block = nn.Conv2d(3, self.inplane,
|
| 515 |
+
kernel_size=init_kernel_size, padding=init_padding, bias=False)
|
| 516 |
+
block_class = PDCBlock_converted
|
| 517 |
+
else:
|
| 518 |
+
self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
|
| 519 |
+
block_class = PDCBlock
|
| 520 |
+
|
| 521 |
+
self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
|
| 522 |
+
self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
|
| 523 |
+
self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
|
| 524 |
+
self.fuseplanes.append(self.inplane) # C
|
| 525 |
+
|
| 526 |
+
inplane = self.inplane
|
| 527 |
+
self.inplane = self.inplane * 2
|
| 528 |
+
self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
|
| 529 |
+
self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
|
| 530 |
+
self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
|
| 531 |
+
self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
|
| 532 |
+
self.fuseplanes.append(self.inplane) # 2C
|
| 533 |
+
|
| 534 |
+
inplane = self.inplane
|
| 535 |
+
self.inplane = self.inplane * 2
|
| 536 |
+
self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
|
| 537 |
+
self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
|
| 538 |
+
self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
|
| 539 |
+
self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
|
| 540 |
+
self.fuseplanes.append(self.inplane) # 4C
|
| 541 |
+
|
| 542 |
+
self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
|
| 543 |
+
self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
|
| 544 |
+
self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
|
| 545 |
+
self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
|
| 546 |
+
self.fuseplanes.append(self.inplane) # 4C
|
| 547 |
+
|
| 548 |
+
self.conv_reduces = nn.ModuleList()
|
| 549 |
+
if self.sa and self.dil is not None:
|
| 550 |
+
self.attentions = nn.ModuleList()
|
| 551 |
+
self.dilations = nn.ModuleList()
|
| 552 |
+
for i in range(4):
|
| 553 |
+
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
| 554 |
+
self.attentions.append(CSAM(self.dil))
|
| 555 |
+
self.conv_reduces.append(MapReduce(self.dil))
|
| 556 |
+
elif self.sa:
|
| 557 |
+
self.attentions = nn.ModuleList()
|
| 558 |
+
for i in range(4):
|
| 559 |
+
self.attentions.append(CSAM(self.fuseplanes[i]))
|
| 560 |
+
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
| 561 |
+
elif self.dil is not None:
|
| 562 |
+
self.dilations = nn.ModuleList()
|
| 563 |
+
for i in range(4):
|
| 564 |
+
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
| 565 |
+
self.conv_reduces.append(MapReduce(self.dil))
|
| 566 |
+
else:
|
| 567 |
+
for i in range(4):
|
| 568 |
+
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
| 569 |
+
|
| 570 |
+
self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
|
| 571 |
+
nn.init.constant_(self.classifier.weight, 0.25)
|
| 572 |
+
nn.init.constant_(self.classifier.bias, 0)
|
| 573 |
+
|
| 574 |
+
# print('initialization done')
|
| 575 |
+
|
| 576 |
+
def get_weights(self):
|
| 577 |
+
conv_weights = []
|
| 578 |
+
bn_weights = []
|
| 579 |
+
relu_weights = []
|
| 580 |
+
for pname, p in self.named_parameters():
|
| 581 |
+
if 'bn' in pname:
|
| 582 |
+
bn_weights.append(p)
|
| 583 |
+
elif 'relu' in pname:
|
| 584 |
+
relu_weights.append(p)
|
| 585 |
+
else:
|
| 586 |
+
conv_weights.append(p)
|
| 587 |
+
|
| 588 |
+
return conv_weights, bn_weights, relu_weights
|
| 589 |
+
|
| 590 |
+
def forward(self, x):
|
| 591 |
+
H, W = x.size()[2:]
|
| 592 |
+
|
| 593 |
+
x = self.init_block(x)
|
| 594 |
+
|
| 595 |
+
x1 = self.block1_1(x)
|
| 596 |
+
x1 = self.block1_2(x1)
|
| 597 |
+
x1 = self.block1_3(x1)
|
| 598 |
+
|
| 599 |
+
x2 = self.block2_1(x1)
|
| 600 |
+
x2 = self.block2_2(x2)
|
| 601 |
+
x2 = self.block2_3(x2)
|
| 602 |
+
x2 = self.block2_4(x2)
|
| 603 |
+
|
| 604 |
+
x3 = self.block3_1(x2)
|
| 605 |
+
x3 = self.block3_2(x3)
|
| 606 |
+
x3 = self.block3_3(x3)
|
| 607 |
+
x3 = self.block3_4(x3)
|
| 608 |
+
|
| 609 |
+
x4 = self.block4_1(x3)
|
| 610 |
+
x4 = self.block4_2(x4)
|
| 611 |
+
x4 = self.block4_3(x4)
|
| 612 |
+
x4 = self.block4_4(x4)
|
| 613 |
+
|
| 614 |
+
x_fuses = []
|
| 615 |
+
if self.sa and self.dil is not None:
|
| 616 |
+
for i, xi in enumerate([x1, x2, x3, x4]):
|
| 617 |
+
x_fuses.append(self.attentions[i](self.dilations[i](xi)))
|
| 618 |
+
elif self.sa:
|
| 619 |
+
for i, xi in enumerate([x1, x2, x3, x4]):
|
| 620 |
+
x_fuses.append(self.attentions[i](xi))
|
| 621 |
+
elif self.dil is not None:
|
| 622 |
+
for i, xi in enumerate([x1, x2, x3, x4]):
|
| 623 |
+
x_fuses.append(self.dilations[i](xi))
|
| 624 |
+
else:
|
| 625 |
+
x_fuses = [x1, x2, x3, x4]
|
| 626 |
+
|
| 627 |
+
e1 = self.conv_reduces[0](x_fuses[0])
|
| 628 |
+
e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
|
| 629 |
+
|
| 630 |
+
e2 = self.conv_reduces[1](x_fuses[1])
|
| 631 |
+
e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
|
| 632 |
+
|
| 633 |
+
e3 = self.conv_reduces[2](x_fuses[2])
|
| 634 |
+
e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
|
| 635 |
+
|
| 636 |
+
e4 = self.conv_reduces[3](x_fuses[3])
|
| 637 |
+
e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
|
| 638 |
+
|
| 639 |
+
outputs = [e1, e2, e3, e4]
|
| 640 |
+
|
| 641 |
+
output = self.classifier(torch.cat(outputs, dim=1))
|
| 642 |
+
#if not self.training:
|
| 643 |
+
# return torch.sigmoid(output)
|
| 644 |
+
|
| 645 |
+
outputs.append(output)
|
| 646 |
+
outputs = [torch.sigmoid(r) for r in outputs]
|
| 647 |
+
return outputs
|
| 648 |
+
|
| 649 |
+
def config_model(model):
|
| 650 |
+
model_options = list(nets.keys())
|
| 651 |
+
assert model in model_options, \
|
| 652 |
+
'unrecognized model, please choose from %s' % str(model_options)
|
| 653 |
+
|
| 654 |
+
# print(str(nets[model]))
|
| 655 |
+
|
| 656 |
+
pdcs = []
|
| 657 |
+
for i in range(16):
|
| 658 |
+
layer_name = 'layer%d' % i
|
| 659 |
+
op = nets[model][layer_name]
|
| 660 |
+
pdcs.append(createConvFunc(op))
|
| 661 |
+
|
| 662 |
+
return pdcs
|
| 663 |
+
|
| 664 |
+
def pidinet():
|
| 665 |
+
pdcs = config_model('carv4')
|
| 666 |
+
dil = 24 #if args.dil else None
|
| 667 |
+
return PiDiNet(60, pdcs, dil=dil, sa=True)
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
if __name__ == '__main__':
|
| 671 |
+
model = pidinet()
|
| 672 |
+
ckp = torch.load('table5_pidinet.pth')['state_dict']
|
| 673 |
+
model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
|
| 674 |
+
im = cv2.imread('examples/test_my/cat_v4.png')
|
| 675 |
+
im = img2tensor(im).unsqueeze(0)/255.
|
| 676 |
+
res = model(im)[-1]
|
| 677 |
+
res = res>0.5
|
| 678 |
+
res = res.float()
|
| 679 |
+
res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
|
| 680 |
+
print(res.shape)
|
| 681 |
+
cv2.imwrite('edge.png', res)
|
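A minimal usage sketch for the PiDiNet defined above, following its own `__main__` block: the checkpoint name `table5_pidinet.pth` and the example image path come from that block, while the output handling shown here is illustrative rather than an API of this Space.

# Sketch: run the 'carv4' PiDiNet on one image (paths and the 0.5 threshold are illustrative).
import cv2
import torch

model = pidinet()                                          # carv4 config, dil=24, sa=True
ckp = torch.load('table5_pidinet.pth', map_location='cpu')['state_dict']
model.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
model.eval()

im = cv2.imread('examples/test_my/cat_v4.png')             # BGR uint8 image, as in __main__
x = img2tensor(im).unsqueeze(0) / 255.                     # BGR->RGB, HWC->CHW, add batch dim
with torch.no_grad():
    edge = model(x)[-1]                                    # fused edge map, values in [0, 1]
edge_u8 = ((edge[0, 0] > 0.5).float().numpy() * 255.).astype('uint8')
cv2.imwrite('edge.png', edge_u8)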
train/src/condition/ted.py
ADDED
@@ -0,0 +1,296 @@
# TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3
# with a Slightly modification
# LDC parameters:
# 155665
# TED > 58K

import torch
import torch.nn as nn
import torch.nn.functional as F

from .util import smish as Fsmish
from .util import Smish


def weight_init(m):
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)

        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    # for fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class CoFusion(nn.Module):
    # from LDC

    def __init__(self, in_ch, out_ch):
        super(CoFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
                               stride=1, padding=1)  # before 64
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
                               stride=1, padding=1)  # before 64 instead of 32
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 32)  # before 64

    def forward(self, x):
        # fusecat = torch.cat(x, dim=1)
        attn = self.relu(self.norm_layer1(self.conv1(x)))
        attn = F.softmax(self.conv3(attn), dim=1)
        return ((x * attn).sum(1)).unsqueeze(1)


class CoFusion2(nn.Module):
    # TEDv14-3
    def __init__(self, in_ch, out_ch):
        super(CoFusion2, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
                               stride=1, padding=1)  # before 64
        # self.conv2 = nn.Conv2d(32, 32, kernel_size=3,
        #                        stride=1, padding=1)  # before 64
        self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
                               stride=1, padding=1)  # before 64 instead of 32
        self.smish = Smish()  # nn.ReLU(inplace=True)

    def forward(self, x):
        # fusecat = torch.cat(x, dim=1)
        attn = self.conv1(self.smish(x))
        attn = self.conv3(self.smish(attn))  # before , )dim=1)

        # return ((fusecat * attn).sum(1)).unsqueeze(1)
        return ((x * attn).sum(1)).unsqueeze(1)


class DoubleFusion(nn.Module):
    # TED fusion before the final edge map prediction
    def __init__(self, in_ch, out_ch):
        super(DoubleFusion, self).__init__()
        self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3,
                                 stride=1, padding=1, groups=in_ch)  # before 64
        self.PSconv1 = nn.PixelShuffle(1)

        self.DWconv2 = nn.Conv2d(24, 24 * 1, kernel_size=3,
                                 stride=1, padding=1, groups=24)  # before 64 instead of 32

        self.AF = Smish()  # XAF() #nn.Tanh()# XAF() # # Smish()#

    def forward(self, x):
        # fusecat = torch.cat(x, dim=1)
        attn = self.PSconv1(self.DWconv1(self.AF(x)))  # #TEED best res TEDv14 [8, 32, 352, 352]

        attn2 = self.PSconv1(self.DWconv2(self.AF(attn)))  # #TEED best res TEDv14[8, 3, 352, 352]

        return Fsmish(((attn2 + attn).sum(1)).unsqueeze(1))  # TED best res


class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()

        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                                           kernel_size=3, stride=1, padding=2, bias=True)),
        self.add_module('smish1', Smish()),
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                                           kernel_size=3, stride=1, bias=True))

    def forward(self, x):
        x1, x2 = x

        new_features = super(_DenseLayer, self).forward(Fsmish(x1))  # F.relu()

        return 0.5 * (new_features + x2), x2


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_features, out_features)
            self.add_module('denselayer%d' % (i + 1), layer)
            input_features = out_features


class UpConvBlock(nn.Module):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        for i in range(up_scale):
            kernel_size = 2 ** up_scale
            pad = all_pads[up_scale]  # kernel_size-1
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(Smish())
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2, padding=pad))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)


class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride, use_ac=False):
        super(SingleConvBlock, self).__init__()
        # self.use_bn = use_bs
        self.use_ac = use_ac
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
                              bias=True)
        if self.use_ac:
            self.smish = Smish()

    def forward(self, x):
        x = self.conv(x)
        if self.use_ac:
            return self.smish(x)
        else:
            return x


class DoubleConvBlock(nn.Module):
    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(in_features, mid_features,
                               3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.smish = Smish()  # nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.smish(x)
        x = self.conv2(x)
        if self.use_act:
            x = self.smish(x)
        return x


class TED(nn.Module):
    """ Definition of Tiny and Efficient Edge Detector
    model
    """

    def __init__(self):
        super(TED, self).__init__()
        self.block_1 = DoubleConvBlock(3, 16, 16, stride=2,)
        self.block_2 = DoubleConvBlock(16, 32, use_act=False)
        self.dblock_3 = _DenseBlock(1, 32, 48)  # [32,48,100,100] before (2, 32, 64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # skip1 connection, see fig. 2
        self.side_1 = SingleConvBlock(16, 32, 2)

        # skip2 connection, see fig. 2
        self.pre_dense_3 = SingleConvBlock(32, 48, 1)  # before (32, 64, 1)

        # USNet
        self.up_block_1 = UpConvBlock(16, 1)
        self.up_block_2 = UpConvBlock(32, 1)
        self.up_block_3 = UpConvBlock(48, 2)  # (32, 64, 1)

        self.block_cat = DoubleFusion(3, 3)  # TEED: DoubleFusion

        self.apply(weight_init)

    def slice(self, tensor, slice_shape):
        t_shape = tensor.shape
        img_h, img_w = slice_shape
        if img_w != t_shape[-1] or img_h != t_shape[2]:
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)

        else:
            new_tensor = tensor
        # tensor[..., :height, :width]
        return new_tensor

    def resize_input(self, tensor):
        t_shape = tensor.shape
        if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
            img_w = ((t_shape[3] // 8) + 1) * 8
            img_h = ((t_shape[2] // 8) + 1) * 8
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
        else:
            new_tensor = tensor
        return new_tensor

    def crop_bdcn(data1, h, w, crop_h, crop_w):
        # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
        _, _, h1, w1 = data1.size()
        assert (h <= h1 and w <= w1)
        data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w]
        return data

    def forward(self, x, single_test=False):
        assert x.ndim == 4, x.shape
        # supose the image size is 352x352

        # Block 1
        block_1 = self.block_1(x)  # [8,16,176,176]
        block_1_side = self.side_1(block_1)  # 16 [8,32,88,88]

        # Block 2
        block_2 = self.block_2(block_1)  # 32 # [8,32,176,176]
        block_2_down = self.maxpool(block_2)  # [8,32,88,88]
        block_2_add = block_2_down + block_1_side  # [8,32,88,88]

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)  # [8,64,88,88] block 3 L connection
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])  # [8,64,88,88]

        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)

        results = [out_1, out_2, out_3]

        # concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW DoubleFusion

        results.append(block_cat)
        return results


if __name__ == '__main__':
    batch_size = 8
    img_height = 352
    img_width = 352

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    input = torch.rand(batch_size, 3, img_height, img_width).to(device)
    # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
    print(f"input shape: {input.shape}")
    model = TED().to(device)
    output = model(input)
    print(f"output shapes: {[t.shape for t in output]}")

    # for i in range(20000):
    #     print(i)
    #     output = model(input)
    #     loss = nn.MSELoss()(output[-1], target)
    #     loss.backward()
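A minimal sketch of running the TED detector above on a single tensor, modeled on its `__main__` demo. The checkpoint path is hypothetical, and applying a sigmoid to the fused output for visualization is an assumption, not something this file prescribes.

# Sketch: forward one dummy image through TED (checkpoint path and sigmoid step are assumptions).
import torch

model = TED()
# model.load_state_dict(torch.load('teed_checkpoint.pth', map_location='cpu'))  # hypothetical path
model.eval()

x = torch.rand(1, 3, 352, 352)               # NCHW float tensor, as in the __main__ demo
x = model.resize_input(x)                    # pad H/W up to a multiple of 8 if needed
with torch.no_grad():
    outs = model(x)                          # [out_1, out_2, out_3, fused]
edge = torch.sigmoid(outs[-1])               # sigmoid for visualization (assumption)
print(edge.shape)                            # 1 x 1 x H x W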
train/src/condition/util.py
ADDED
@@ -0,0 +1,202 @@
import os
import random
import tempfile
import warnings
from contextlib import suppress
from pathlib import Path

import cv2
import numpy as np
import torch
from huggingface_hub import constants, hf_hub_download
from torch.hub import get_dir, download_url_to_file
from ast import literal_eval

import torch.nn.functional as F
import torch.nn as nn


def safe_step(x, step=2):
    y = x.astype(np.float32) * float(step + 1)
    y = y.astype(np.int32).astype(np.float32) / float(step)
    return y


def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z


def safer_memory(x):
    # Fix many MAC/AMD problems
    return np.ascontiguousarray(x.copy()).copy()


UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]


def get_upscale_method(method_str):
    assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
    return getattr(cv2, method_str)


def pad64(x):
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def resize_image_with_pad(input_image, resolution, upscale_method="", skip_hwc3=False, mode='edge'):
    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    if resolution == 0:
        return img, lambda x: x
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x):
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad


def common_input_validate(input_image, output_type, **kwargs):
    if "img" in kwargs:
        warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
        input_image = kwargs.pop("img")

    if "return_pil" in kwargs:
        warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
        output_type = "pil" if kwargs["return_pil"] else "np"

    if type(output_type) is bool:
        warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
        if output_type:
            output_type = "pil"

    if input_image is None:
        raise ValueError("input_image must be defined.")

    if not isinstance(input_image, np.ndarray):
        input_image = np.array(input_image, dtype=np.uint8)
        output_type = output_type or "pil"
    else:
        output_type = output_type or "np"

    return (input_image, output_type)


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def get_intensity_mask(image_array, lower_bound, upper_bound):
    mask = image_array[:, :, 0]
    mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
    mask = np.expand_dims(mask, 2).repeat(3, axis=2)
    return mask


def combine_layers(base_layer, top_layer):
    mask = top_layer.astype(bool)
    temp = 1 - (1 - top_layer) * (1 - base_layer)
    result = base_layer * (~mask) + temp * mask
    return result


@torch.jit.script
def mish(input):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    See additional documentation for mish class.
    """
    return input * torch.tanh(F.softplus(input))


@torch.jit.script
def smish(input):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x))))
    See additional documentation for mish class.
    """
    return input * torch.tanh(torch.log(1 + torch.sigmoid(input)))


class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.
        """
        if torch.__version__ >= "1.9":
            return F.mish(input)
        else:
            return mish(input)


class Smish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.
        """
        return smish(input)
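A minimal sketch of the resize/pad round-trip provided by the helpers above, as used by edge-detection preprocessors; the input path and the 512 target resolution are illustrative assumptions.

# Sketch: pad an image to a multiple of 64 for a detector, then undo the padding.
import cv2

img = HWC3(cv2.imread('photo.png'))                        # force HxWx3 uint8 ('photo.png' is hypothetical)
padded, remove_pad = resize_image_with_pad(img, 512, upscale_method="INTER_CUBIC")
# ... run an edge detector on `padded` here; H and W are now multiples of 64 ...
result = remove_pad(padded)                                # crop back to the pre-pad size
print(padded.shape, result.shape)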
train/src/generate_diff_mask.py
ADDED
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
"""
Standalone script: Given two images, generate a final difference mask using the
same pipeline as visualize_mask_diff (without any visualization output).

Pipeline:
1) Align images to a preferred resolution/crop so they share the same size.
2) Pixel-diff screening across parameter combinations; skip if any hull ratio is
   outside [hull_min_allowed, hull_max_allowed].
3) Color-diff to produce the final mask; remove small areas and re-check hull
   ratio. Save final mask to output path.
"""

import os
import json
import argparse
from typing import Tuple, Optional

import numpy as np
from PIL import Image
import cv2


PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]


def choose_preferred_resolution(image_width: int, image_height: int) -> Tuple[int, int]:
    aspect_ratio = image_width / max(1, image_height)
    best = min(((abs(aspect_ratio - (w / h)), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS), key=lambda x: x[0])
    _, w_best, h_best = best
    return int(w_best), int(h_best)


def align_images(source_path: str, target_path: str) -> Tuple[Image.Image, Image.Image]:
    source_img = Image.open(source_path).convert("RGB")
    target_img = Image.open(target_path).convert("RGB")

    pref_w, pref_h = choose_preferred_resolution(source_img.width, source_img.height)
    source_resized = source_img.resize((pref_w, pref_h), Image.Resampling.LANCZOS)

    tgt_w, tgt_h = target_img.width, target_img.height
    crop_w = min(source_resized.width, tgt_w)
    crop_h = min(source_resized.height, tgt_h)

    source_aligned = source_resized.crop((0, 0, crop_w, crop_h))
    target_aligned = target_img.crop((0, 0, crop_w, crop_h))
    return source_aligned, target_aligned


def pil_to_cv_gray(img: Image.Image) -> np.ndarray:
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    return gray


def generate_pixel_diff_mask(img1: Image.Image, img2: Image.Image, threshold: Optional[int] = None, clean_kernel_size: Optional[int] = 11) -> np.ndarray:
    img1_gray = pil_to_cv_gray(img1)
    img2_gray = pil_to_cv_gray(img2)
    diff = cv2.absdiff(img1_gray, img2_gray)
    if threshold is None:
        mask = cv2.threshold(diff, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    else:
        mask = cv2.threshold(diff, int(threshold), 255, cv2.THRESH_BINARY)[1]
    if clean_kernel_size and clean_kernel_size > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask


def generate_color_diff_mask(img1: Image.Image, img2: Image.Image, threshold: Optional[int] = None, clean_kernel_size: Optional[int] = 21) -> np.ndarray:
    bgr1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR)
    bgr2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR)
    lab1 = cv2.cvtColor(bgr1, cv2.COLOR_BGR2LAB).astype("float32")
    lab2 = cv2.cvtColor(bgr2, cv2.COLOR_BGR2LAB).astype("float32")
    diff = lab1 - lab2
    dist = np.sqrt(np.sum(diff * diff, axis=2))
    dist_u8 = cv2.normalize(dist, None, 0, 255, cv2.NORM_MINMAX).astype("uint8")
    if threshold is None:
        mask = cv2.threshold(dist_u8, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    else:
        mask = cv2.threshold(dist_u8, int(threshold), 255, cv2.THRESH_BINARY)[1]
    if clean_kernel_size and clean_kernel_size > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask


def compute_unified_contour(mask_bin: np.ndarray, contours: list, min_area: int = 40, method: str = "morph", morph_kernel: int = 15, morph_iters: int = 1, approx_epsilon_ratio: float = 0.01):
    valid_cnts = []
    for c in contours:
        if cv2.contourArea(c) >= max(1, min_area):
            valid_cnts.append(c)
    if not valid_cnts:
        return None
    if method == "convex_hull":
        all_points = np.vstack(valid_cnts)
        hull = cv2.convexHull(all_points)
        epsilon = approx_epsilon_ratio * cv2.arcLength(hull, True)
        unified = cv2.approxPolyDP(hull, epsilon, True)
        return unified
    union = np.zeros_like(mask_bin)
    cv2.drawContours(union, valid_cnts, -1, 255, thickness=-1)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel))
    union_closed = union.copy()
    for _ in range(max(1, morph_iters)):
        union_closed = cv2.morphologyEx(union_closed, cv2.MORPH_CLOSE, kernel)
    ext = cv2.findContours(union_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    ext = ext[0] if len(ext) == 2 else ext[1]
    if not ext:
        return None
    largest = max(ext, key=cv2.contourArea)
    epsilon = approx_epsilon_ratio * cv2.arcLength(largest, True)
    unified = cv2.approxPolyDP(largest, epsilon, True)
    return unified


def compute_hull_area_ratio(mask: np.ndarray, min_area: int = 40) -> float:
    mask_bin = (mask > 0).astype("uint8") * 255
    cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    if not cnts:
        return 0.0
    hull_cnt = compute_unified_contour(mask_bin, cnts, min_area=min_area, method="convex_hull", morph_kernel=15, morph_iters=1)
    if hull_cnt is None or len(hull_cnt) < 3:
        return 0.0
    hull_area = float(cv2.contourArea(hull_cnt))
    img_area = float(mask_bin.shape[0] * mask_bin.shape[1])
    return hull_area / max(1.0, img_area)


def clean_and_fill_mask(mask: np.ndarray, min_area: int = 40) -> np.ndarray:
    mask_bin = (mask > 0).astype("uint8") * 255
    cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    cleaned = np.zeros_like(mask_bin)
    for c in cnts:
        if cv2.contourArea(c) >= max(1, min_area):
            cv2.drawContours(cleaned, [c], 0, 255, -1)
    return cleaned


def generate_final_difference_mask(source_path: str,
                                   target_path: str,
                                   hull_min_allowed: float = 0.001,
                                   hull_max_allowed: float = 0.75,
                                   pixel_parameters: Optional[list] = None,
                                   pixel_clean_kernel_default: int = 11,
                                   color_clean_kernel: int = 3,
                                   roll_radius: int = 0,
                                   roll_iters: int = 1) -> Optional[np.ndarray]:
    if pixel_parameters is None:
        # Mirrors the tuned combinations used in visualization script
        pixel_parameters = [(None, 5), (None, 11), (50, 5)]

    src_img, tgt_img = align_images(source_path, target_path)

    # Pixel screening across parameter combinations
    violation = False
    for thr, ksize in pixel_parameters:
        pm = generate_pixel_diff_mask(src_img, tgt_img, threshold=thr, clean_kernel_size=ksize)
        r = compute_hull_area_ratio(pm, min_area=40)
        if r < hull_min_allowed or r > hull_max_allowed:
            violation = True
            break
    if violation:
        # Failure: do not produce any mask
        return None

    # Color-based final mask → cleaned small areas
    color_mask = generate_color_diff_mask(src_img, tgt_img, threshold=None, clean_kernel_size=color_clean_kernel)
    cleaned = clean_and_fill_mask(color_mask, min_area=40)

    # Produce binary mask from the convex hull contour of the cleaned mask
    mask_bin = (cleaned > 0).astype("uint8") * 255
    cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    hull_cnt = compute_unified_contour(mask_bin, cnts, min_area=40, method="convex_hull", morph_kernel=15, morph_iters=1)
    if hull_cnt is None or len(hull_cnt) < 3:
        return None

    h_mask = np.zeros_like(mask_bin)
    cv2.drawContours(h_mask, [hull_cnt], -1, 255, thickness=-1)

    # Rolling-circle smoothing: closing then opening with a disk of radius R
    if roll_radius and roll_radius > 0 and roll_iters and roll_iters > 0:
        ksize = max(1, 2 * int(roll_radius) + 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        for _ in range(max(1, roll_iters)):
            h_mask = cv2.morphologyEx(h_mask, cv2.MORPH_CLOSE, kernel)
            h_mask = cv2.morphologyEx(h_mask, cv2.MORPH_OPEN, kernel)

    # Final hull ratio check on the hull-filled binary mask
    r_final = compute_hull_area_ratio(h_mask, min_area=40)
    if r_final > hull_max_allowed or r_final < hull_min_allowed:
        return None

    return h_mask


def main():
    parser = argparse.ArgumentParser(description="Generate final difference mask (single pair or whole dataset)")
    # Single-pair mode (optional): if provided, runs single pair; otherwise runs dataset mode
    parser.add_argument("--source", help="Path to source image")
    parser.add_argument("--target", help="Path to target image")
    parser.add_argument("--output", help="Path to write the final mask (PNG)")
    # Dataset mode (defaults to user's dataset paths)
    parser.add_argument("--dataset_dir", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset", help="Base dataset dir with source_images/ and target_images/")
    parser.add_argument("--dataset_output_dir", default="/home/lzc/KontextFill/visualizations_masks/inference_masks_smoothing", help="Output directory for batch masks")
    parser.add_argument("--json_path", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset/extracted_data.json", help="Dataset JSON mapping with fields 'source_image' and 'target_image'")
    # Common params
    parser.add_argument("--hull_min_allowed", type=float, default=0.001)
    parser.add_argument("--hull_max_allowed", type=float, default=0.75)
    parser.add_argument("--color_clean_kernel", type=int, default=3)
    parser.add_argument("--roll_radius", type=int, default=15, help="Rolling-circle smoothing radius (pixels); 0 disables")
    parser.add_argument("--roll_iters", type=int, default=5, help="Rolling smoothing iterations")

    args = parser.parse_args()

    pixel_parameters = [(None, 5), (None, 11), (50, 5)]

    # Decide mode: single or dataset
    if args.source and args.target and args.output:
        mask = generate_final_difference_mask(
            source_path=args.source,
            target_path=args.target,
            hull_min_allowed=args.hull_min_allowed,
            hull_max_allowed=args.hull_max_allowed,
            pixel_parameters=pixel_parameters,
            color_clean_kernel=args.color_clean_kernel,
            roll_radius=args.roll_radius,
            roll_iters=args.roll_iters,
        )
        if mask is None:
            print("Single-pair inference failed; no output saved.")
            return
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        cv2.imwrite(args.output, mask)
        return

    # Dataset mode using JSON mapping
    out_dir = args.dataset_output_dir
    os.makedirs(out_dir, exist_ok=True)

    processed = 0
    skipped = 0
    failed = 0
    missing_files = 0
    try:
        with open(args.json_path, "r", encoding="utf-8") as f:
            entries = json.load(f)
    except Exception as e:
        print(f"Failed to read JSON mapping at {args.json_path}: {e}")
        entries = []

    for item in entries:
        try:
            src_rel = item.get("source_image")
            tgt_rel = item.get("target_image")
            edit_id = item.get("id")
            if not src_rel or not tgt_rel:
                skipped += 1
                continue
|
| 270 |
+
s = os.path.join(args.dataset_dir, src_rel)
|
| 271 |
+
t = os.path.join(args.dataset_dir, tgt_rel)
|
| 272 |
+
if not (os.path.exists(s) and os.path.exists(t)):
|
| 273 |
+
missing_files += 1
|
| 274 |
+
continue
|
| 275 |
+
mask = generate_final_difference_mask(
|
| 276 |
+
source_path=s,
|
| 277 |
+
target_path=t,
|
| 278 |
+
hull_min_allowed=args.hull_min_allowed,
|
| 279 |
+
hull_max_allowed=args.hull_max_allowed,
|
| 280 |
+
pixel_parameters=pixel_parameters,
|
| 281 |
+
color_clean_kernel=args.color_clean_kernel,
|
| 282 |
+
roll_radius=args.roll_radius,
|
| 283 |
+
roll_iters=args.roll_iters,
|
| 284 |
+
)
|
| 285 |
+
if mask is None:
|
| 286 |
+
failed += 1
|
| 287 |
+
continue
|
| 288 |
+
name = f"edit_{int(edit_id):04d}" if isinstance(edit_id, int) or (isinstance(edit_id, str) and edit_id.isdigit()) else os.path.splitext(os.path.basename(src_rel))[0]
|
| 289 |
+
out_path = os.path.join(out_dir, f"{name}.png")
|
| 290 |
+
cv2.imwrite(out_path, mask)
|
| 291 |
+
processed += 1
|
| 292 |
+
except Exception as e:
|
| 293 |
+
skipped += 1
|
| 294 |
+
continue
|
| 295 |
+
print(f"Batch done. Processed={processed}, Failed={failed}, Skipped={skipped}, MissingFiles={missing_files}, OutputDir={out_dir}")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
if __name__ == "__main__":
|
| 299 |
+
main()
|
| 300 |
+
|
| 301 |
+
|
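For reference, a minimal sketch of driving the single-pair path from Python instead of the CLI. This is not part of the commit: the module import path and the image/output file names are placeholder assumptions; only `generate_final_difference_mask` and its defaults come from the file above.

```python
# Hypothetical usage sketch: assumes the file above is importable as
# train.src.generate_diff_mask and that the .jpg/.png paths exist on disk.
import cv2

from train.src.generate_diff_mask import generate_final_difference_mask

mask = generate_final_difference_mask(
    source_path="example_source.jpg",   # original image (placeholder path)
    target_path="example_edited.jpg",   # edited image (placeholder path)
    roll_radius=15,                     # same smoothing defaults as the CLI
    roll_iters=5,
)
if mask is None:
    # The hull-area-ratio screening rejected the pair, so no mask is written.
    print("Pair rejected by hull-ratio screening; no mask produced.")
else:
    cv2.imwrite("diff_mask.png", mask)  # 0/255 uint8 mask
```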
train/src/jsonl_datasets_kontext_color.py
ADDED
@@ -0,0 +1,166 @@
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import torch
import os
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
import numpy as np
from .jsonl_datasets_colorization import FlexibleColorDetector

Image.MAX_IMAGE_PIXELS = None

def multiple_16(num: float):
    return int(round(num / 16) * 16)

def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
    image_path = os.path.join(root, image_path)
    try:
        image = Image.open(image_path).convert("RGB")
        return image
    except Exception:
        print("file error: " + image_path)
        with open("failed_images.txt", "a") as f:
            f.write(f"{image_path}\n")
        return Image.new("RGB", (size, size), (255, 255, 255))

def choose_kontext_resolution_from_wh(width: int, height: int):
    aspect_ratio = width / max(1, height)
    _, best_w, best_h = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    return best_w, best_h

color_detector = FlexibleColorDetector()

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    # source_pixel_values has been removed; keep returning None for compatibility
    source_pixel_values = None

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }


def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
    # Load the CSV dataset: three columns; column 0 is the relative image path, column 2 is the caption
    if args.train_data_dir is not None:
        dataset = load_dataset('csv', data_files=args.train_data_dir)

    # Column-name compatibility: use column 0 as the image path and column 2 as the caption
    column_names = dataset["train"].column_names
    image_col = column_names[0]
    caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]

    size = args.cond_size

    # Device setup (interface kept for future use)
    if accelerator is not None:
        device = accelerator.device
    else:
        device = "cpu"

    # Transforms
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # cond matches the colorization pipeline: CenterCrop -> ToTensor -> Normalize
    cond_train_transforms = transforms.Compose([
        transforms.CenterCrop((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]

    def tokenize_prompt_clip_t5(examples):
        captions_raw = examples[caption_col]
        captions = []
        for c in captions_raw:
            if isinstance(c, str):
                if random.random() < 0.25:
                    captions.append("")
                else:
                    captions.append(c)
            else:
                captions.append("")

        text_inputs_clip = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs_clip.input_ids

        text_inputs_t5 = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_t5.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        batch = {}

        img_paths = examples[image_col]

        target_tensors = []
        cond_tensors = []

        for p in img_paths:
            # Load image by joining with root in load_image_safely
            img = load_image_safely(p, size)
            img = img.convert("RGB")

            # Resize to the preferred Kontext resolution for the target
            w, h = img.size
            best_w, best_h = choose_kontext_resolution_from_wh(w, h)
            img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(img_rs)

            # Build the color-block condition
            color_blocks = color_detector(input_image=img, block_size=32, output_size=size)
            edge_tensor = cond_train_transforms(color_blocks)

            target_tensors.append(target_tensor)
            cond_tensors.append(edge_tensor)

        batch["pixel_values"] = target_tensors
        batch["cond_pixel_values"] = cond_tensors

        batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
        return batch

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset
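A sketch of how this color-block dataset could be wired into a PyTorch DataLoader. This is not from the commit: the `args` namespace, the CSV path, and the tokenizer checkpoint (assumed to be a FLUX-style repo with `tokenizer`/`tokenizer_2` subfolders) are all placeholder assumptions; `make_train_dataset_inpaint_mask` and `collate_fn` are the functions above. Batch size 1 is used because targets are resized per-image to different Kontext resolutions, so larger batches cannot be stacked directly.

```python
# Hypothetical wiring sketch, assuming a FLUX-style checkpoint layout and a CSV at data/color_train.csv.
from types import SimpleNamespace

import torch
from transformers import CLIPTokenizer, T5TokenizerFast

from train.src.jsonl_datasets_kontext_color import make_train_dataset_inpaint_mask, collate_fn

args = SimpleNamespace(train_data_dir="data/color_train.csv", cond_size=512)  # placeholders
tokenizers = [
    CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="tokenizer"),
    T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="tokenizer_2"),
]

train_dataset = make_train_dataset_inpaint_mask(args, tokenizers)
loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch["pixel_values"].shape, batch["cond_pixel_values"].shape)  # target vs. 512x512 color-block condition
```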
train/src/jsonl_datasets_kontext_complete_lora.py
ADDED
@@ -0,0 +1,363 @@
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
import torch
import os
from datasets import load_dataset
import numpy as np
import json

Image.MAX_IMAGE_PIXELS = None

def _prepend_caption(description: str, obj_name: str) -> str:
    """Build the instruction with a stochastic OBJECT choice.

    OBJECT choice (equal probability):
    - the literal string "object"
    - the JSON field `object` with '_' replaced by spaces
    - the JSON field `description`
    """
    # Prepare options for the OBJECT slot
    cleaned_obj = (obj_name or "object").replace("_", " ").strip() or "object"
    desc_opt = (description or "object").strip() or "object"
    object_slot = random.choice(["object", cleaned_obj, desc_opt])

    instruction = f"Complete the {object_slot}'s missing parts if necessary. White Background;"

    return instruction

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None

    if examples[0].get("source_pixel_values") is not None:
        source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
        source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        source_pixel_values = None

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])

    mask_values = None
    if examples[0].get("mask_values") is not None:
        mask_values = torch.stack([example["mask_values"] for example in examples])
        mask_values = mask_values.to(memory_format=torch.contiguous_format).float()

    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
        "mask_values": mask_values,
    }


def _resolve_jsonl(path_str: str):
    if path_str is None or str(path_str).strip() == "":
        raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
    if os.path.isdir(path_str):
        files = [
            os.path.join(path_str, f)
            for f in os.listdir(path_str)
            if f.lower().endswith((".jsonl", ".json"))
        ]
        if not files:
            raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
        return {"train": sorted(files)}
    if not os.path.exists(path_str):
        raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
    return {"train": [path_str]}


def _tokenize(tokenizers, caption: str):
    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]
    text_inputs_clip = tokenizer_clip(
        [caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    )
    text_inputs_t5 = tokenizer_t5(
        [caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
    )
    return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]

def _apply_white_brushstrokes(image_np: np.ndarray, mask_bin: np.ndarray = None) -> np.ndarray:
    """Draw random white brushstrokes on the RGB image array and return the modified array.

    Strokes preferentially start within mask_bin if provided.
    """
    import cv2
    h, w = image_np.shape[:2]
    rng = random.Random()

    # Determine stroke counts and sizes based on image size
    ref = max(1, min(h, w))
    num_strokes = rng.randint(1, 5)
    max_offset = max(5, ref // 40)
    min_th = max(2, ref // 40)
    max_th = max(min_th + 1, ref // 5)

    out = image_np.copy()
    prefer_mask_p = 0.33 if mask_bin is not None and mask_bin.any() else 0.0

    def rand_point_inside_mask():
        ys, xs = np.where(mask_bin > 0)
        if len(xs) == 0:
            return rng.randrange(w), rng.randrange(h)
        i = rng.randrange(len(xs))
        return int(xs[i]), int(ys[i])

    def rand_point_any():
        return rng.randrange(w), rng.randrange(h)

    for _ in range(num_strokes):
        # Choose the stroke start point, preferring the mask region with probability prefer_mask_p
        if rng.random() < prefer_mask_p:
            px, py = rand_point_inside_mask()
        else:
            px, py = rand_point_any()

        # Polyline with several jittered segments
        segments = rng.randint(40, 80)
        thickness = rng.randint(min_th, max_th)
        for _ in range(segments):
            dx = rng.randint(-max_offset, max_offset)
            dy = rng.randint(-max_offset, max_offset)
            nx = int(np.clip(px + dx, 0, w - 1))
            ny = int(np.clip(py + dy, 0, h - 1))
            cv2.line(out, (px, py), (nx, ny), (255, 255, 255), thickness)
            px, py = nx, ny

    return out


def make_train_dataset_subjects(args, tokenizers, accelerator=None):
    """
    Dataset for JSONL with fields (one JSON object per line):
    - white_image_path: absolute path to the base image used for both pixel_values and source_pixel_values
    - mask_path: absolute path to the mask image (grayscale)
    - img_width: target width
    - img_height: target height
    - description: caption text

    Behavior:
    - pixel_values = white_image_path resized to (img_width, img_height)
    - source_pixel_values = the same image with random white brushstrokes overlaid
    - mask_values = binarized mask from mask_path, resized with nearest neighbor
    - captions tokenized from description
    """
    data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
    file_paths = data_files.get("train", [])
    records = []
    for p in file_paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception:
                    # Best-effort: strip any trailing commas and retry
                    try:
                        obj = json.loads(line.rstrip(","))
                    except Exception:
                        continue
                # Keep only the fields we need for this dataset schema
                pruned = {
                    "white_image_path": obj.get("white_image_path"),
                    "mask_path": obj.get("mask_path"),
                    "img_width": obj.get("img_width"),
                    "img_height": obj.get("img_height"),
                    "description": obj.get("description"),
                    "object": obj.get("object"),
                }
                records.append(pruned)

    size = int(getattr(args, "cond_size", 512))

    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Repeat each record with independent random brushstrokes
    REPEATS_PER_IMAGE = 5

    class SubjectsDataset(torch.utils.data.Dataset):
        def __init__(self, hf_ds):
            self.ds = hf_ds
            self.repeats = REPEATS_PER_IMAGE

        def __len__(self):
            if self.repeats and self.repeats > 1:
                return len(self.ds) * self.repeats
            return len(self.ds)

        def __getitem__(self, idx):
            if self.repeats and self.repeats > 1:
                base_idx = idx % len(self.ds)
            else:
                base_idx = idx
            rec = self.ds[base_idx]

            white_p = rec.get("white_image_path", "") or ""
            mask_p = rec.get("mask_path", "") or ""

            if not os.path.isabs(white_p):
                # Allow absolute paths only, to avoid ambiguity
                raise ValueError("white_image_path must be absolute")
            if not os.path.isabs(mask_p):
                raise ValueError("mask_path must be absolute")

            import cv2
            mask_loaded = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                raise ValueError(f"Failed to read mask: {mask_p}")

            base_img = Image.open(white_p).convert("RGB")

            # Desired output size
            fw = int(rec.get("img_width") or base_img.width)
            fh = int(rec.get("img_height") or base_img.height)
            base_img = base_img.resize((fw, fh), resample=Image.BILINEAR)
            mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)

            # Tensors: the target is the clean white-background image
            target_tensor = to_tensor_and_norm(base_img)

            # Binary mask at the final size
            mask_np = np.array(mask_img)
            mask_bin = (mask_np > 127).astype(np.uint8)

            # Build the source by drawing random white brushstrokes on top of the white image
            base_np = np.array(base_img).astype(np.uint8)
            stroked_np = _apply_white_brushstrokes(base_np, mask_bin)

            # Build tensors
            source_tensor = to_tensor_and_norm(Image.fromarray(stroked_np.astype(np.uint8)))
            mask_tensor = torch.from_numpy(mask_bin.astype(np.float32)).unsqueeze(0)

            # Caption: build the instruction from description and object
            description = rec.get("description", "")
            obj_name = rec.get("object", "")
            cap = _prepend_caption(description, obj_name)
            ids1, ids2 = _tokenize(tokenizers, cap)

            return {
                "source_pixel_values": source_tensor,
                "pixel_values": target_tensor,
                "token_ids_clip": ids1,
                "token_ids_t5": ids2,
                "mask_values": mask_tensor,
            }

    return SubjectsDataset(records)


def _run_test_mode(test_jsonl: str, output_dir: str, num_samples: int = 50):
    """Utility to visualize the augmentation: saves pairs of (target, source) images.

    Reads the JSONL directly, applies the same logic as the dataset to produce
    pixel_values (target) and source_pixel_values (with white strokes),
    then writes them to output_dir for manual inspection.
    """
    os.makedirs(output_dir, exist_ok=True)
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Minimal tokenizer shim to reuse the dataset tokenization pipeline
    class _NoOpTokenizer:
        def __call__(self, texts, padding=None, max_length=None, truncation=None, return_tensors=None):
            return type("T", (), {"input_ids": torch.zeros((1, 1), dtype=torch.long)})()

    tokenizers = [_NoOpTokenizer(), _NoOpTokenizer()]

    saved = 0
    line_idx = 0
    import cv2
    with open(test_jsonl, "r", encoding="utf-8") as f:
        for raw in f:
            if saved >= num_samples:
                break
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except Exception:
                try:
                    obj = json.loads(raw.rstrip(","))
                except Exception:
                    continue

            rec = {
                "white_image_path": obj.get("white_image_path"),
                "mask_path": obj.get("mask_path"),
                "img_width": obj.get("img_width"),
                "img_height": obj.get("img_height"),
                "description": obj.get("description"),
            }

            white_p = rec.get("white_image_path", "") or ""
            mask_p = rec.get("mask_path", "") or ""
            if not white_p or not mask_p:
                continue
            if not (os.path.isabs(white_p) and os.path.isabs(mask_p)):
                continue

            mask_loaded = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
            if mask_loaded is None:
                continue

            try:
                base_img = Image.open(white_p).convert("RGB")
            except Exception:
                continue

            fw = int(rec.get("img_width") or base_img.width)
            fh = int(rec.get("img_height") or base_img.height)
            base_img = base_img.resize((fw, fh), resample=Image.BILINEAR)
            mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)

            mask_np = np.array(mask_img)
            mask_bin = (mask_np > 127).astype(np.uint8)

            base_np = np.array(base_img).astype(np.uint8)
            stroked_np = _apply_white_brushstrokes(base_np, mask_bin)

            # Save images
            idx_str = f"{line_idx:05d}"
            try:
                Image.fromarray(base_np).save(os.path.join(output_dir, f"{idx_str}_target.jpg"))
                Image.fromarray(stroked_np).save(os.path.join(output_dir, f"{idx_str}_source.jpg"))
                Image.fromarray((mask_bin * 255).astype(np.uint8)).save(os.path.join(output_dir, f"{idx_str}_mask.png"))
                saved += 1
            except Exception:
                pass
            line_idx += 1


def _parse_test_args():
    import argparse
    parser = argparse.ArgumentParser(description="Test visualization for the Kontext complete dataset")
    parser.add_argument("--test_jsonl", type=str, default="/robby/share/Editing/lzc/subject_completion/white_bg_picked/results_picked_filtered.jsonl", help="Path to the JSONL to preview")
    parser.add_argument("--output_dir", type=str, default="/robby/share/Editing/lzc/subject_completion/train_test", help="Output directory for the saved pairs")
    parser.add_argument("--num_samples", type=int, default=50, help="Number of pairs to save")
    return parser.parse_args()


if __name__ == "__main__":
    try:
        args = _parse_test_args()
        _run_test_mode(args.test_jsonl, args.output_dir, args.num_samples)
    except SystemExit:
        # Allow import usage without triggering test mode
        pass
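A short sketch of previewing the white-brushstroke corruption on a single image/mask pair, outside of `_run_test_mode`. It is not part of the commit: the two file paths are placeholders, while `_apply_white_brushstrokes` and the >127 binarization mirror the dataset code above.

```python
# Hypothetical preview sketch; subject_white_bg.png and subject_mask.png are placeholder files.
import numpy as np
from PIL import Image

from train.src.jsonl_datasets_kontext_complete_lora import _apply_white_brushstrokes

base = np.array(Image.open("subject_white_bg.png").convert("RGB"))
mask = np.array(Image.open("subject_mask.png").convert("L"))
mask_bin = (mask > 127).astype(np.uint8)          # same binarization as the dataset

corrupted = _apply_white_brushstrokes(base, mask_bin)   # source image before normalization
Image.fromarray(corrupted).save("subject_with_strokes.png")
```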
train/src/jsonl_datasets_kontext_edge.py
ADDED
@@ -0,0 +1,225 @@
from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import torch
import os
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
import numpy as np
from src.condition.edge_extraction import (
    CannyDetector, PidiNetDetector, TEDDetector, LineartStandardDetector, HEDdetector,
    AnyLinePreprocessor, LineartDetector, InformativeDetector
)

Image.MAX_IMAGE_PIXELS = None

def multiple_16(num: float):
    return int(round(num / 16) * 16)

def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
    image_path = os.path.join(root, image_path)
    try:
        image = Image.open(image_path).convert("RGB")
        return image
    except Exception:
        print("file error: " + image_path)
        with open("failed_images.txt", "a") as f:
            f.write(f"{image_path}\n")
        return Image.new("RGB", (size, size), (255, 255, 255))

def choose_kontext_resolution_from_wh(width: int, height: int):
    aspect_ratio = width / max(1, height)
    _, best_w, best_h = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    return best_w, best_h

class EdgeExtractorManager:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(EdgeExtractorManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.edge_extractors = None
            self.device = None
            self._initialized = True

    def set_device(self, device):
        self.device = device

    def get_edge_extractors(self, device=None):
        # Force CPU initialization to avoid triggering CUDA init inside DataLoader worker processes
        current_device = "cpu"
        if device is not None:
            self.set_device(current_device)

        if self.edge_extractors is None or len(self.edge_extractors) == 0:
            self.edge_extractors = [
                ("canny", CannyDetector()),
                ("pidinet", PidiNetDetector.from_pretrained().to(current_device)),
                ("ted", TEDDetector.from_pretrained().to(current_device)),
                # ("lineart_standard", LineartStandardDetector()),
                ("hed", HEDdetector.from_pretrained().to(current_device)),
                ("anyline", AnyLinePreprocessor.from_pretrained().to(current_device)),
                # ("lineart", LineartDetector.from_pretrained().to(current_device)),
                ("informative", InformativeDetector.from_pretrained().to(current_device)),
            ]

        return self.edge_extractors

edge_extractor_manager = EdgeExtractorManager()

def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None
    source_pixel_values = None

    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
    token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "source_pixel_values": source_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }


def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
    # Load the CSV dataset: three columns; column 0 is the relative image path, column 2 is the caption
    if args.train_data_dir is not None:
        dataset = load_dataset('csv', data_files=args.train_data_dir)

    # Column-name compatibility: use column 0 as the image path and column 2 as the caption
    column_names = dataset["train"].column_names
    image_col = column_names[0]
    caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]

    size = args.cond_size

    # Device setup (used to place some detectors on the corresponding GPU when running distributed)
    if accelerator is not None:
        device = accelerator.device
        edge_extractor_manager.set_device(device)
    else:
        device = "cpu"

    # Transforms
    to_tensor_and_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Matches jsonl_datasets_edge.py: Resize -> CenterCrop -> ToTensor -> Normalize
    cond_train_transforms = transforms.Compose([
        transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    tokenizer_clip = tokenizers[0]
    tokenizer_t5 = tokenizers[1]

    def tokenize_prompt_clip_t5(examples):
        captions_raw = examples[caption_col]
        captions = []
        for c in captions_raw:
            if isinstance(c, str):
                if random.random() < 0.25:
                    captions.append("")
                else:
                    captions.append(c)
            else:
                captions.append("")

        text_inputs_clip = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs_clip.input_ids

        text_inputs_t5 = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=128,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_t5.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        batch = {}

        img_paths = examples[image_col]

        target_tensors = []
        cond_tensors = []

        for p in img_paths:
            # Load image by joining with root in load_image_safely
            img = load_image_safely(p, size)
            img = img.convert("RGB")

            # Resize to the preferred Kontext resolution for the target
            w, h = img.size
            best_w, best_h = choose_kontext_resolution_from_wh(w, h)
            img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
            target_tensor = to_tensor_and_norm(img_rs)

            # Build the edge condition with a randomly chosen extractor
            extractor_name, extractor = random.choice(edge_extractor_manager.get_edge_extractors())
            img_np = np.array(img)
            if extractor_name == "informative":
                edge = extractor(img_np, style="contour")
            else:
                edge = extractor(img_np)

            # TED outputs need a higher binarization threshold than the other detectors
            if extractor_name == "ted":
                th = 128
            else:
                th = 32

            edge_np = np.array(edge) if isinstance(edge, Image.Image) else edge
            if edge_np.ndim == 3:
                edge_np = edge_np[..., 0]
            edge_bin = (edge_np > th).astype(np.float32)
            edge_pil = Image.fromarray((edge_bin * 255).astype(np.uint8))
            edge_tensor = cond_train_transforms(edge_pil)
            edge_tensor = edge_tensor.repeat(3, 1, 1)

            target_tensors.append(target_tensor)
            cond_tensors.append(edge_tensor)

        batch["pixel_values"] = target_tensors
        batch["cond_pixel_values"] = cond_tensors

        batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
        return batch

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset
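A sketch of extracting and binarizing one edge map the way `preprocess_train` does, useful for inspecting a single detector. It is not from the commit: the image path is a placeholder, and the detector call follows exactly the usage in the file above (including the per-detector thresholds of 128 for "ted" and 32 otherwise).

```python
# Hypothetical single-image preview; sample.jpg is a placeholder path.
import numpy as np
from PIL import Image

from train.src.jsonl_datasets_kontext_edge import edge_extractor_manager

img = np.array(Image.open("sample.jpg").convert("RGB"))
name, extractor = edge_extractor_manager.get_edge_extractors()[0]   # e.g. ("canny", CannyDetector())

# Same call pattern as preprocess_train: only "informative" takes a style argument.
edge = extractor(img, style="contour") if name == "informative" else extractor(img)
edge_np = np.array(edge) if isinstance(edge, Image.Image) else edge
if edge_np.ndim == 3:
    edge_np = edge_np[..., 0]

th = 128 if name == "ted" else 32
edge_bin = (edge_np > th).astype(np.uint8) * 255
Image.fromarray(edge_bin).save(f"sample_{name}_edges.png")
```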
train/src/jsonl_datasets_kontext_interactive_lora.py
ADDED
|
@@ -0,0 +1,1332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
import torchvision.transforms.functional as TF
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
import numpy as np
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def collate_fn(examples):
|
| 15 |
+
if examples[0].get("cond_pixel_values") is not None:
|
| 16 |
+
cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
|
| 17 |
+
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 18 |
+
else:
|
| 19 |
+
cond_pixel_values = None
|
| 20 |
+
|
| 21 |
+
if examples[0].get("source_pixel_values") is not None:
|
| 22 |
+
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
|
| 23 |
+
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 24 |
+
else:
|
| 25 |
+
source_pixel_values = None
|
| 26 |
+
|
| 27 |
+
target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 28 |
+
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 29 |
+
token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
|
| 30 |
+
token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
|
| 31 |
+
|
| 32 |
+
mask_values = None
|
| 33 |
+
if examples[0].get("mask_values") is not None:
|
| 34 |
+
mask_values = torch.stack([example["mask_values"] for example in examples])
|
| 35 |
+
mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
|
| 36 |
+
|
| 37 |
+
return {
|
| 38 |
+
"cond_pixel_values": cond_pixel_values,
|
| 39 |
+
"source_pixel_values": source_pixel_values,
|
| 40 |
+
"pixel_values": target_pixel_values,
|
| 41 |
+
"text_ids_1": token_ids_clip,
|
| 42 |
+
"text_ids_2": token_ids_t5,
|
| 43 |
+
"mask_values": mask_values,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _resolve_jsonl(path_str: str):
|
| 48 |
+
if path_str is None or str(path_str).strip() == "":
|
| 49 |
+
raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
|
| 50 |
+
if os.path.isdir(path_str):
|
| 51 |
+
files = [
|
| 52 |
+
os.path.join(path_str, f)
|
| 53 |
+
for f in os.listdir(path_str)
|
| 54 |
+
if f.lower().endswith((".jsonl", ".json"))
|
| 55 |
+
]
|
| 56 |
+
if not files:
|
| 57 |
+
raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
|
| 58 |
+
return {"train": sorted(files)}
|
| 59 |
+
if not os.path.exists(path_str):
|
| 60 |
+
raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
|
| 61 |
+
return {"train": [path_str]}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _tokenize(tokenizers, caption: str):
|
| 65 |
+
tokenizer_clip = tokenizers[0]
|
| 66 |
+
tokenizer_t5 = tokenizers[1]
|
| 67 |
+
text_inputs_clip = tokenizer_clip(
|
| 68 |
+
[caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
|
| 69 |
+
)
|
| 70 |
+
text_inputs_t5 = tokenizer_t5(
|
| 71 |
+
[caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
|
| 72 |
+
)
|
| 73 |
+
return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _prepend_caption(caption: str) -> str:
|
| 77 |
+
"""Prepend instruction and keep only instruction with 20% prob."""
|
| 78 |
+
instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
|
| 79 |
+
if random.random() < 0.2:
|
| 80 |
+
return instruction
|
| 81 |
+
caption = caption or ""
|
| 82 |
+
if caption.strip():
|
| 83 |
+
return f"{instruction} {caption.strip()}"
|
| 84 |
+
return instruction
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _color_augment(pil_img: Image.Image) -> Image.Image:
|
| 88 |
+
brightness = random.uniform(0.8, 1.2)
|
| 89 |
+
contrast = random.uniform(0.8, 1.2)
|
| 90 |
+
saturation = random.uniform(0.8, 1.2)
|
| 91 |
+
hue = random.uniform(-0.05, 0.05)
|
| 92 |
+
img = TF.adjust_brightness(pil_img, brightness)
|
| 93 |
+
img = TF.adjust_contrast(img, contrast)
|
| 94 |
+
img = TF.adjust_saturation(img, saturation)
|
| 95 |
+
img = TF.adjust_hue(img, hue)
|
| 96 |
+
return img
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _dilate_mask(mask_bin: np.ndarray, min_px: int = 5, max_px: int = 100) -> np.ndarray:
|
| 100 |
+
"""Grow binary mask by a random radius in [min_px, max_px]. Expects values {0,1}."""
|
| 101 |
+
import cv2
|
| 102 |
+
radius = int(max(min_px, min(max_px, random.randint(min_px, max_px))))
|
| 103 |
+
if radius <= 0:
|
| 104 |
+
return mask_bin.astype(np.uint8)
|
| 105 |
+
ksize = 2 * radius + 1
|
| 106 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
|
| 107 |
+
grown = cv2.dilate(mask_bin.astype(np.uint8), kernel, iterations=1)
|
| 108 |
+
return (grown > 0).astype(np.uint8)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _random_point_inside_mask(mask_bin: np.ndarray) -> tuple:
|
| 112 |
+
ys, xs = np.where(mask_bin > 0)
|
| 113 |
+
if len(xs) == 0:
|
| 114 |
+
h, w = mask_bin.shape
|
| 115 |
+
return w // 2, h // 2
|
| 116 |
+
idx = random.randrange(len(xs))
|
| 117 |
+
return int(xs[idx]), int(ys[idx])
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _bbox_containing_mask(mask_bin: np.ndarray, img_w: int, img_h: int) -> tuple:
|
| 121 |
+
ys, xs = np.where(mask_bin > 0)
|
| 122 |
+
if len(xs) == 0:
|
| 123 |
+
return 0, 0, img_w - 1, img_h - 1
|
| 124 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 125 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 126 |
+
# Random padding
|
| 127 |
+
max_pad = int(0.25 * min(img_w, img_h))
|
| 128 |
+
pad_x1 = random.randint(0, max_pad)
|
| 129 |
+
pad_x2 = random.randint(0, max_pad)
|
| 130 |
+
pad_y1 = random.randint(0, max_pad)
|
| 131 |
+
pad_y2 = random.randint(0, max_pad)
|
| 132 |
+
x1 = max(0, x1 - pad_x1)
|
| 133 |
+
y1 = max(0, y1 - pad_y1)
|
| 134 |
+
x2 = min(img_w - 1, x2 + pad_x2)
|
| 135 |
+
y2 = min(img_h - 1, y2 + pad_y2)
|
| 136 |
+
return x1, y1, x2, y2
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _constrained_random_mask(mask_bin: np.ndarray, image_h: int, image_w: int, aug_prob: float = 0.7) -> np.ndarray:
|
| 140 |
+
"""Generate random mask whose box contains or starts in m_p, and brush strokes start inside m_p.
|
| 141 |
+
Returns binary 0/1 array of shape (H,W).
|
| 142 |
+
"""
|
| 143 |
+
import cv2
|
| 144 |
+
if random.random() >= aug_prob:
|
| 145 |
+
return np.zeros((image_h, image_w), dtype=np.uint8)
|
| 146 |
+
|
| 147 |
+
# Scale similar to reference
|
| 148 |
+
ref_size = 1024
|
| 149 |
+
scale_factor = max(1.0, min(image_h, image_w) / float(ref_size))
|
| 150 |
+
|
| 151 |
+
out = np.zeros((image_h, image_w), dtype=np.uint8)
|
| 152 |
+
|
| 153 |
+
# Choose exactly one augmentation: bbox OR stroke
|
| 154 |
+
if random.random() < 0.2:
|
| 155 |
+
# BBox augmentation: draw N boxes (randomized), first box often contains mask
|
| 156 |
+
num_boxes = random.randint(1, 6)
|
| 157 |
+
for b in range(num_boxes):
|
| 158 |
+
if b == 0 and random.random() < 0.5:
|
| 159 |
+
x1, y1, x2, y2 = _bbox_containing_mask(mask_bin, image_w, image_h)
|
| 160 |
+
else:
|
| 161 |
+
sx, sy = _random_point_inside_mask(mask_bin)
|
| 162 |
+
max_w = int(500 * scale_factor)
|
| 163 |
+
min_w = int(100 * scale_factor)
|
| 164 |
+
bw = random.randint(max(1, min_w), max(2, max_w))
|
| 165 |
+
bh = random.randint(max(1, min_w), max(2, max_w))
|
| 166 |
+
x1 = max(0, sx - random.randint(0, bw))
|
| 167 |
+
y1 = max(0, sy - random.randint(0, bh))
|
| 168 |
+
x2 = min(image_w - 1, x1 + bw)
|
| 169 |
+
y2 = min(image_h - 1, y1 + bh)
|
| 170 |
+
out[y1:y2 + 1, x1:x2 + 1] = 1
|
| 171 |
+
else:
|
| 172 |
+
# Stroke augmentation: draw N strokes starting inside mask
|
| 173 |
+
num_strokes = random.randint(1, 6)
|
| 174 |
+
for _ in range(num_strokes):
|
| 175 |
+
num_points = random.randint(10, 30)
|
| 176 |
+
stroke_width = random.randint(max(1, int(100 * scale_factor)), max(2, int(400 * scale_factor)))
|
| 177 |
+
max_offset = max(1, int(100 * scale_factor))
|
| 178 |
+
start_x, start_y = _random_point_inside_mask(mask_bin)
|
| 179 |
+
px, py = start_x, start_y
|
| 180 |
+
for _ in range(num_points):
|
| 181 |
+
dx = random.randint(-max_offset, max_offset)
|
| 182 |
+
dy = random.randint(-max_offset, max_offset)
|
| 183 |
+
nx = int(np.clip(px + dx, 0, image_w - 1))
|
| 184 |
+
ny = int(np.clip(py + dy, 0, image_h - 1))
|
| 185 |
+
cv2.line(out, (px, py), (nx, ny), 1, stroke_width)
|
| 186 |
+
px, py = nx, ny
|
| 187 |
+
|
| 188 |
+
return (out > 0).astype(np.uint8)
|
| 189 |
+
|
| 190 |
+
|
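For reference, a minimal standalone sketch of how the dataset classes below combine these helpers: the object mask is dilated by a random radius and unioned with an optional box/stroke mask to form the hole that gets punched into the target. The toy mask, sizes, and radii here are illustrative only, not the training defaults.

import random

import cv2
import numpy as np

h, w = 256, 256
obj_mask = np.zeros((h, w), np.uint8)
obj_mask[100:160, 80:140] = 1  # toy object region

# grow the object mask (cf. _dilate_mask)
radius = random.randint(20, 40)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
grown = (cv2.dilate(obj_mask, kernel) > 0).astype(np.uint8)

# add a random box anchored around the object (cf. _constrained_random_mask)
ys, xs = np.where(obj_mask > 0)
cy, cx = int(ys.mean()), int(xs.mean())
box = np.zeros_like(obj_mask)
box[max(0, cy - 30):cy + 30, max(0, cx - 30):cx + 30] = 1

hole_mask = np.clip(grown | box, 0, 1)  # final union mask used to punch the hole
print(int(hole_mask.sum()), "pixels masked")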
| 191 |
+
def make_placement_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
|
| 192 |
+
"""
|
| 193 |
+
Dataset for JSONL with fields:
|
| 194 |
+
- generated_image_path: relative to base_dir (target image with object)
|
| 195 |
+
- mask_path: relative to base_dir (mask of object)
|
| 196 |
+
- generated_width, generated_height: image dimensions
|
| 197 |
+
- final_prompt: caption
|
| 198 |
+
- relight_images: list of {mode, path} for relighted versions
|
| 199 |
+
|
| 200 |
+
source image construction:
|
| 201 |
+
- background is target_image with a hole punched by grown mask
|
| 202 |
+
- foreground is the target image (20%), a color-augmented target (20%), or a weighted pick from relight_images (60%)
|
| 203 |
+
- includes a random perspective transformation of the foreground ROI (applied with 50% probability)
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
base_dir: Base directory for resolving relative paths. If None, uses args.placement_base_dir.
|
| 207 |
+
"""
|
| 208 |
+
if base_dir is None:
|
| 209 |
+
base_dir = getattr(args, "placement_base_dir")
|
| 210 |
+
|
| 211 |
+
data_files = _resolve_jsonl(getattr(args, "placement_data_jsonl", None))
|
| 212 |
+
file_paths = data_files.get("train", [])
|
| 213 |
+
records = []
|
| 214 |
+
for p in file_paths:
|
| 215 |
+
with open(p, "r", encoding="utf-8") as f:
|
| 216 |
+
for line in f:
|
| 217 |
+
line = line.strip()
|
| 218 |
+
if not line:
|
| 219 |
+
continue
|
| 220 |
+
try:
|
| 221 |
+
obj = json.loads(line)
|
| 222 |
+
except Exception:
|
| 223 |
+
try:
|
| 224 |
+
obj = json.loads(line.rstrip(","))
|
| 225 |
+
except Exception:
|
| 226 |
+
continue
|
| 227 |
+
# Keep only fields we need
|
| 228 |
+
pruned = {
|
| 229 |
+
"generated_image_path": obj.get("generated_image_path"),
|
| 230 |
+
"mask_path": obj.get("mask_path"),
|
| 231 |
+
"generated_width": obj.get("generated_width"),
|
| 232 |
+
"generated_height": obj.get("generated_height"),
|
| 233 |
+
"final_prompt": obj.get("final_prompt"),
|
| 234 |
+
"relight_images": obj.get("relight_images"),
|
| 235 |
+
}
|
| 236 |
+
records.append(pruned)
|
| 237 |
+
|
| 238 |
+
size = int(getattr(args, "cond_size", 512))
|
| 239 |
+
|
| 240 |
+
to_tensor_and_norm = transforms.Compose([
|
| 241 |
+
transforms.ToTensor(),
|
| 242 |
+
transforms.Normalize([0.5], [0.5]),
|
| 243 |
+
])
|
| 244 |
+
|
| 245 |
+
class PlacementDataset(torch.utils.data.Dataset):
|
| 246 |
+
def __init__(self, hf_ds, base_dir):
|
| 247 |
+
self.ds = hf_ds
|
| 248 |
+
self.base_dir = base_dir
|
| 249 |
+
def __len__(self):
|
| 250 |
+
# One sample per record
|
| 251 |
+
return len(self.ds)
|
| 252 |
+
def __getitem__(self, idx):
|
| 253 |
+
rec = self.ds[idx % len(self.ds)]
|
| 254 |
+
|
| 255 |
+
t_rel = rec.get("generated_image_path", "")
|
| 256 |
+
m_rel = rec.get("mask_path", "")
|
| 257 |
+
|
| 258 |
+
# Both are relative paths
|
| 259 |
+
t_p = os.path.join(self.base_dir, t_rel)
|
| 260 |
+
m_p = os.path.join(self.base_dir, m_rel)
|
| 261 |
+
|
| 262 |
+
import cv2
|
| 263 |
+
mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
|
| 264 |
+
if mask_loaded is None:
|
| 265 |
+
raise ValueError(f"Failed to read mask: {m_p}")
|
| 266 |
+
|
| 267 |
+
tgt_img = Image.open(t_p).convert("RGB")
|
| 268 |
+
|
| 269 |
+
fw = int(rec.get("generated_width", tgt_img.width))
|
| 270 |
+
fh = int(rec.get("generated_height", tgt_img.height))
|
| 271 |
+
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
|
| 272 |
+
mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
|
| 273 |
+
|
| 274 |
+
target_tensor = to_tensor_and_norm(tgt_img)
|
| 275 |
+
|
| 276 |
+
# Binary mask at final_size
|
| 277 |
+
mask_np = np.array(mask_img)
|
| 278 |
+
mask_bin = (mask_np > 127).astype(np.uint8)
|
| 279 |
+
|
| 280 |
+
# 1) Grow mask by a random 50-200 pixel radius
|
| 281 |
+
grown_mask = _dilate_mask(mask_bin, 50, 200)
|
| 282 |
+
|
| 283 |
+
# 2) Optional random augmentation mask constrained by mask
|
| 284 |
+
rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
|
| 285 |
+
|
| 286 |
+
# 3) Final union mask
|
| 287 |
+
union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
|
| 288 |
+
tgt_np = np.array(tgt_img)
|
| 289 |
+
|
| 290 |
+
# Helper: choose relighted image from relight_images with weights
|
| 291 |
+
def _choose_relight_image(rec, width, height):
|
| 292 |
+
relight_list = rec.get("relight_images") or []
|
| 293 |
+
# Build map mode -> path
|
| 294 |
+
mode_to_path = {}
|
| 295 |
+
for it in relight_list:
|
| 296 |
+
try:
|
| 297 |
+
mode = str(it.get("mode", "")).strip().lower()
|
| 298 |
+
path = it.get("path")
|
| 299 |
+
except Exception:
|
| 300 |
+
continue
|
| 301 |
+
if not mode or not path:
|
| 302 |
+
continue
|
| 303 |
+
mode_to_path[mode] = path
|
| 304 |
+
|
| 305 |
+
weighted_order = [
|
| 306 |
+
("grayscale", 0.5),
|
| 307 |
+
("low", 0.3),
|
| 308 |
+
("high", 0.2),
|
| 309 |
+
]
|
| 310 |
+
|
| 311 |
+
# Filter to available
|
| 312 |
+
available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
|
| 313 |
+
chosen_path = None
|
| 314 |
+
if available:
|
| 315 |
+
rnd = random.random()
|
| 316 |
+
cum = 0.0
|
| 317 |
+
total_w = sum(w for _, w in available)
|
| 318 |
+
for m, w in available:
|
| 319 |
+
cum += w / total_w
|
| 320 |
+
if rnd <= cum:
|
| 321 |
+
chosen_path = mode_to_path.get(m)
|
| 322 |
+
break
|
| 323 |
+
if chosen_path is None:
|
| 324 |
+
chosen_path = mode_to_path.get(available[-1][0])
|
| 325 |
+
else:
|
| 326 |
+
# Fallback to any provided path
|
| 327 |
+
if mode_to_path:
|
| 328 |
+
chosen_path = next(iter(mode_to_path.values()))
|
| 329 |
+
|
| 330 |
+
# Open chosen image
|
| 331 |
+
if chosen_path is not None:
|
| 332 |
+
try:
|
| 333 |
+
# Paths are relative to base_dir
|
| 334 |
+
open_path = os.path.join(self.base_dir, chosen_path)
|
| 335 |
+
img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
|
| 336 |
+
return img
|
| 337 |
+
except Exception:
|
| 338 |
+
pass
|
| 339 |
+
|
| 340 |
+
# Fallback: return target image
|
| 341 |
+
return Image.open(t_p).convert("RGB").resize((width, height), resample=Image.BILINEAR)
|
| 342 |
+
|
| 343 |
+
# Choose base image with probabilities:
|
| 344 |
+
# 20%: original target, 20%: color augment(target), 60%: relight augment
|
| 345 |
+
rsel = random.random()
|
| 346 |
+
if rsel < 0.2:
|
| 347 |
+
base_img = tgt_img
|
| 348 |
+
elif rsel < 0.4:
|
| 349 |
+
base_img = _color_augment(tgt_img)
|
| 350 |
+
else:
|
| 351 |
+
base_img = _choose_relight_image(rec, fw, fh)
|
| 352 |
+
base_np = np.array(base_img)
|
| 353 |
+
fore_np = base_np.copy()
|
| 354 |
+
|
| 355 |
+
# Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
|
| 356 |
+
perspective_applied = False
|
| 357 |
+
roi_update = None
|
| 358 |
+
paste_mask_bool = mask_bin.astype(bool)
|
| 359 |
+
if random.random() < 0.5:
|
| 360 |
+
try:
|
| 361 |
+
import cv2
|
| 362 |
+
ys, xs = np.where(mask_bin > 0)
|
| 363 |
+
if len(xs) > 0 and len(ys) > 0:
|
| 364 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 365 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 366 |
+
if x2 > x1 and y2 > y1:
|
| 367 |
+
roi = base_np[y1:y2 + 1, x1:x2 + 1]
|
| 368 |
+
roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
|
| 369 |
+
bh, bw = roi.shape[:2]
|
| 370 |
+
# Random perturbation relative to ROI size
|
| 371 |
+
max_ratio = random.uniform(0.1, 0.3)
|
| 372 |
+
dx = bw * max_ratio
|
| 373 |
+
dy = bh * max_ratio
|
| 374 |
+
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
|
| 375 |
+
dst = np.float32([
|
| 376 |
+
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
|
| 377 |
+
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
|
| 378 |
+
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
|
| 379 |
+
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
|
| 380 |
+
])
|
| 381 |
+
M = cv2.getPerspectiveTransform(src, dst)
|
| 382 |
+
warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
|
| 383 |
+
warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
|
| 384 |
+
# Build a fresh foreground canvas
|
| 385 |
+
fore_np = np.zeros_like(base_np)
|
| 386 |
+
h_warp, w_warp = warped_mask_roi.shape
|
| 387 |
+
y2c = y1 + h_warp
|
| 388 |
+
x2c = x1 + w_warp
|
| 389 |
+
fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
|
| 390 |
+
paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
|
| 391 |
+
paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
|
| 392 |
+
roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
|
| 393 |
+
perspective_applied = True
|
| 394 |
+
except Exception:
|
| 395 |
+
perspective_applied = False
|
| 396 |
+
paste_mask_bool = mask_bin.astype(bool)
|
| 397 |
+
fore_np = base_np
|
| 398 |
+
|
| 399 |
+
# Optional: simulate resolution artifacts
|
| 400 |
+
if random.random() < 0.7:
|
| 401 |
+
ys, xs = np.where(paste_mask_bool)
|
| 402 |
+
if len(xs) > 0 and len(ys) > 0:
|
| 403 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 404 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 405 |
+
if x2 > x1 and y2 > y1:
|
| 406 |
+
crop = fore_np[y1:y2 + 1, x1:x2 + 1]
|
| 407 |
+
ch, cw = crop.shape[:2]
|
| 408 |
+
scale = random.uniform(0.15, 0.9)
|
| 409 |
+
dw = max(1, int(cw * scale))
|
| 410 |
+
dh = max(1, int(ch * scale))
|
| 411 |
+
try:
|
| 412 |
+
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
|
| 413 |
+
back = small.resize((cw, ch), Image.BICUBIC)
|
| 414 |
+
crop_blurred = np.array(back).astype(np.uint8)
|
| 415 |
+
fore_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
|
| 416 |
+
except Exception:
|
| 417 |
+
pass
|
| 418 |
+
|
| 419 |
+
# Build masked target and compose
|
| 420 |
+
union_mask_for_target = union_mask.copy()
|
| 421 |
+
if roi_update is not None:
|
| 422 |
+
rx, ry, rh, rw, warped_mask_roi = roi_update
|
| 423 |
+
um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
|
| 424 |
+
union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
|
| 425 |
+
masked_t_np = tgt_np.copy()
|
| 426 |
+
masked_t_np[union_mask_for_target.astype(bool)] = 255
|
| 427 |
+
composed_np = masked_t_np.copy()
|
| 428 |
+
m_fore = paste_mask_bool
|
| 429 |
+
composed_np[m_fore] = fore_np[m_fore]
|
| 430 |
+
|
| 431 |
+
# Build tensors
|
| 432 |
+
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
|
| 433 |
+
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
|
| 434 |
+
|
| 435 |
+
# Caption: prepend instruction
|
| 436 |
+
cap_orig = rec.get("final_prompt", "") or ""
|
| 437 |
+
# Handle list format in final_prompt
|
| 438 |
+
if isinstance(cap_orig, list) and len(cap_orig) > 0:
|
| 439 |
+
cap_orig = cap_orig[0] if isinstance(cap_orig[0], str) else str(cap_orig[0])
|
| 440 |
+
cap = _prepend_caption(cap_orig)
|
| 441 |
+
if perspective_applied:
|
| 442 |
+
cap = f"{cap} Fix the perspective if necessary."
|
| 443 |
+
ids1, ids2 = _tokenize(tokenizers, cap)
|
| 444 |
+
|
| 445 |
+
return {
|
| 446 |
+
"source_pixel_values": source_tensor,
|
| 447 |
+
"pixel_values": target_tensor,
|
| 448 |
+
"token_ids_clip": ids1,
|
| 449 |
+
"token_ids_t5": ids2,
|
| 450 |
+
"mask_values": mask_tensor,
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
return PlacementDataset(records, base_dir)
|
| 454 |
+
|
| 455 |
+
|
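The weighted relight-mode selection inside _choose_relight_image (grayscale 0.5, low 0.3, high 0.2, renormalized over whatever modes are present) can be sketched compactly with random.choices; the file names below are hypothetical placeholders, not paths from the dataset.

import random

mode_to_path = {"grayscale": "relight_gray.png", "high": "relight_high.png"}  # hypothetical
weights = {"grayscale": 0.5, "low": 0.3, "high": 0.2}

available = [(m, weights[m]) for m in weights if m in mode_to_path]
modes, ws = zip(*available)
# random.choices normalizes the weights, matching the manual cumulative-sum loop above
chosen_mode = random.choices(modes, weights=ws, k=1)[0]
print(chosen_mode, mode_to_path[chosen_mode])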
| 456 |
+
def make_interactive_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
|
| 457 |
+
"""
|
| 458 |
+
Dataset for JSONL with fields:
|
| 459 |
+
- input_path: relative to base_dir (target image)
|
| 460 |
+
- output_path: absolute path to image with foreground
|
| 461 |
+
- mask_after_completion: absolute path to mask
|
| 462 |
+
- img_width, img_height: resize dimensions
|
| 463 |
+
- prompt: caption
|
| 464 |
+
|
| 465 |
+
source image construction:
|
| 466 |
+
- background is target_image with a hole punched by grown `mask_after_completion`
|
| 467 |
+
- foreground is from `output_path` image, pasted using original `mask_after_completion`
|
| 468 |
+
- foreground source chosen with weights: 20% original, 20% color augment, 60% relighted (from generated_images)
|
| 469 |
+
- random perspective transform applied to the foreground ROI with 50% probability
|
| 470 |
+
|
| 471 |
+
Args:
|
| 472 |
+
base_dir: Base directory for resolving relative paths. If None, uses args.interactive_base_dir.
|
| 473 |
+
"""
|
| 474 |
+
if base_dir is None:
|
| 475 |
+
base_dir = getattr(args, "interactive_base_dir")
|
| 476 |
+
|
| 477 |
+
data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
|
| 478 |
+
file_paths = data_files.get("train", [])
|
| 479 |
+
records = []
|
| 480 |
+
for p in file_paths:
|
| 481 |
+
with open(p, "r", encoding="utf-8") as f:
|
| 482 |
+
for line in f:
|
| 483 |
+
line = line.strip()
|
| 484 |
+
if not line:
|
| 485 |
+
continue
|
| 486 |
+
try:
|
| 487 |
+
obj = json.loads(line)
|
| 488 |
+
except Exception:
|
| 489 |
+
# Best-effort: strip any trailing commas and retry
|
| 490 |
+
try:
|
| 491 |
+
obj = json.loads(line.rstrip(","))
|
| 492 |
+
except Exception:
|
| 493 |
+
continue
|
| 494 |
+
# Keep only fields we actually need to avoid schema issues
|
| 495 |
+
pruned = {
|
| 496 |
+
"input_path": obj.get("input_path"),
|
| 497 |
+
"output_path": obj.get("output_path"),
|
| 498 |
+
"mask_after_completion": obj.get("mask_after_completion"),
|
| 499 |
+
"img_width": obj.get("img_width"),
|
| 500 |
+
"img_height": obj.get("img_height"),
|
| 501 |
+
"prompt": obj.get("prompt"),
|
| 502 |
+
# New optional fields
|
| 503 |
+
"generated_images": obj.get("generated_images"),
|
| 504 |
+
"positive_prompt_used": obj.get("positive_prompt_used"),
|
| 505 |
+
"negative_caption_used": obj.get("negative_caption_used"),
|
| 506 |
+
}
|
| 507 |
+
records.append(pruned)
|
| 508 |
+
|
| 509 |
+
size = int(getattr(args, "cond_size", 512))
|
| 510 |
+
|
| 511 |
+
to_tensor_and_norm = transforms.Compose([
|
| 512 |
+
transforms.ToTensor(),
|
| 513 |
+
transforms.Normalize([0.5], [0.5]),
|
| 514 |
+
])
|
| 515 |
+
|
| 516 |
+
class SubjectsDataset(torch.utils.data.Dataset):
|
| 517 |
+
def __init__(self, hf_ds, base_dir):
|
| 518 |
+
self.ds = hf_ds
|
| 519 |
+
self.base_dir = base_dir
|
| 520 |
+
def __len__(self):
|
| 521 |
+
# One sample per record
|
| 522 |
+
return len(self.ds)
|
| 523 |
+
def __getitem__(self, idx):
|
| 524 |
+
rec = self.ds[idx % len(self.ds)]
|
| 525 |
+
|
| 526 |
+
t_rel = rec.get("input_path", "")
|
| 527 |
+
foreground_p = rec.get("output_path", "")
|
| 528 |
+
m_abs = rec.get("mask_after_completion", "")
|
| 529 |
+
|
| 530 |
+
if not os.path.isabs(m_abs):
|
| 531 |
+
raise ValueError("mask_after_completion must be absolute")
|
| 532 |
+
if not os.path.isabs(foreground_p):
|
| 533 |
+
raise ValueError("output_path must be absolute")
|
| 534 |
+
|
| 535 |
+
t_p = os.path.join(self.base_dir, t_rel)
|
| 536 |
+
m_p = m_abs
|
| 537 |
+
|
| 538 |
+
import cv2
|
| 539 |
+
mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
|
| 540 |
+
if mask_loaded is None:
|
| 541 |
+
raise ValueError(f"Failed to read mask: {m_p}")
|
| 542 |
+
|
| 543 |
+
tgt_img = Image.open(t_p).convert("RGB")
|
| 544 |
+
foreground_source_img = Image.open(foreground_p).convert("RGB")
|
| 545 |
+
|
| 546 |
+
fw = int(rec.get("img_width", tgt_img.width))
|
| 547 |
+
fh = int(rec.get("img_height", tgt_img.height))
|
| 548 |
+
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
|
| 549 |
+
foreground_source_img = foreground_source_img.resize((fw, fh), resample=Image.BILINEAR)
|
| 550 |
+
mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
|
| 551 |
+
|
| 552 |
+
# Convert the target PIL image to a normalized tensor; the source tensor is built later from the composed image
|
| 553 |
+
target_tensor = to_tensor_and_norm(tgt_img)
|
| 554 |
+
|
| 555 |
+
# Binary mask at final_size
|
| 556 |
+
mask_np = np.array(mask_img)
|
| 557 |
+
mask_bin = (mask_np > 127).astype(np.uint8)
|
| 558 |
+
|
| 559 |
+
# 1) Grow the completion mask by a random 50-200 pixel radius
|
| 560 |
+
grown_mask = _dilate_mask(mask_bin, 50, 200)
|
| 561 |
+
|
| 562 |
+
# 2) Optional random augmentation mask constrained by m_p
|
| 563 |
+
rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
|
| 564 |
+
|
| 565 |
+
# 3) Final union mask
|
| 566 |
+
union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
|
| 567 |
+
tgt_np = np.array(tgt_img)
|
| 568 |
+
|
| 569 |
+
# Helper: choose relighted image from generated_images with weights
|
| 570 |
+
def _choose_relight_image(rec, default_img, width, height):
|
| 571 |
+
gen_list = rec.get("generated_images") or []
|
| 572 |
+
# Build map mode -> path
|
| 573 |
+
mode_to_path = {}
|
| 574 |
+
for it in gen_list:
|
| 575 |
+
try:
|
| 576 |
+
mode = str(it.get("mode", "")).strip().lower()
|
| 577 |
+
path = it.get("path")
|
| 578 |
+
except Exception:
|
| 579 |
+
continue
|
| 580 |
+
if not mode or not path:
|
| 581 |
+
continue
|
| 582 |
+
mode_to_path[mode] = path
|
| 583 |
+
|
| 584 |
+
# Weighted selection among available modes
|
| 585 |
+
weighted_order = [
|
| 586 |
+
("grayscale", 0.5),
|
| 587 |
+
("low", 0.3),
|
| 588 |
+
("high", 0.2),
|
| 589 |
+
]
|
| 590 |
+
|
| 591 |
+
# Filter to available
|
| 592 |
+
available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
|
| 593 |
+
chosen_path = None
|
| 594 |
+
if available:
|
| 595 |
+
rnd = random.random()
|
| 596 |
+
cum = 0.0
|
| 597 |
+
total_w = sum(w for _, w in available)
|
| 598 |
+
for m, w in available:
|
| 599 |
+
cum += w / total_w
|
| 600 |
+
if rnd <= cum:
|
| 601 |
+
chosen_path = mode_to_path.get(m)
|
| 602 |
+
break
|
| 603 |
+
if chosen_path is None:
|
| 604 |
+
chosen_path = mode_to_path.get(available[-1][0])
|
| 605 |
+
else:
|
| 606 |
+
# Fallback to any provided path
|
| 607 |
+
if mode_to_path:
|
| 608 |
+
chosen_path = next(iter(mode_to_path.values()))
|
| 609 |
+
|
| 610 |
+
# Open chosen image
|
| 611 |
+
if chosen_path is not None:
|
| 612 |
+
try:
|
| 613 |
+
open_path = chosen_path
|
| 614 |
+
# generated paths are typically absolute; if not, use as-is
|
| 615 |
+
img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
|
| 616 |
+
return img
|
| 617 |
+
except Exception:
|
| 618 |
+
pass
|
| 619 |
+
|
| 620 |
+
return default_img
|
| 621 |
+
|
| 622 |
+
# 5) Choose base image with probabilities:
|
| 623 |
+
# 20%: original, 20%: color augment(original), 60%: relight augment
|
| 624 |
+
rsel = random.random()
|
| 625 |
+
if rsel < 0.2:
|
| 626 |
+
base_img = foreground_source_img
|
| 627 |
+
elif rsel < 0.4:
|
| 628 |
+
base_img = _color_augment(foreground_source_img)
|
| 629 |
+
else:
|
| 630 |
+
base_img = _choose_relight_image(rec, foreground_source_img, fw, fh)
|
| 631 |
+
base_np = np.array(base_img)
|
| 632 |
+
|
| 633 |
+
# 5.1) Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
|
| 634 |
+
perspective_applied = False
|
| 635 |
+
roi_update = None
|
| 636 |
+
paste_mask_bool = mask_bin.astype(bool)
|
| 637 |
+
if random.random() < 0.5:
|
| 638 |
+
try:
|
| 639 |
+
import cv2
|
| 640 |
+
ys, xs = np.where(mask_bin > 0)
|
| 641 |
+
if len(xs) > 0 and len(ys) > 0:
|
| 642 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 643 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 644 |
+
if x2 > x1 and y2 > y1:
|
| 645 |
+
roi = base_np[y1:y2 + 1, x1:x2 + 1]
|
| 646 |
+
roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
|
| 647 |
+
bh, bw = roi.shape[:2]
|
| 648 |
+
# Random perturbation relative to ROI size
|
| 649 |
+
max_ratio = random.uniform(0.1, 0.3)
|
| 650 |
+
dx = bw * max_ratio
|
| 651 |
+
dy = bh * max_ratio
|
| 652 |
+
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
|
| 653 |
+
dst = np.float32([
|
| 654 |
+
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
|
| 655 |
+
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
|
| 656 |
+
[np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
|
| 657 |
+
[np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
|
| 658 |
+
])
|
| 659 |
+
M = cv2.getPerspectiveTransform(src, dst)
|
| 660 |
+
warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
|
| 661 |
+
warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
|
| 662 |
+
# Build a fresh foreground canvas
|
| 663 |
+
fore_np = np.zeros_like(base_np)
|
| 664 |
+
h_warp, w_warp = warped_mask_roi.shape
|
| 665 |
+
y2c = y1 + h_warp
|
| 666 |
+
x2c = x1 + w_warp
|
| 667 |
+
fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
|
| 668 |
+
paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
|
| 669 |
+
paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
|
| 670 |
+
roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
|
| 671 |
+
perspective_applied = True
|
| 672 |
+
base_np = fore_np
|
| 673 |
+
except Exception:
|
| 674 |
+
perspective_applied = False
|
| 675 |
+
paste_mask_bool = mask_bin.astype(bool)
|
| 676 |
+
|
| 677 |
+
# Optional: simulate cut-out foregrounds coming from different resolutions by
|
| 678 |
+
# downscaling the masked foreground region and upscaling back to original size.
|
| 679 |
+
# This introduces realistic blur/aliasing seen in real inpaint workflows.
|
| 680 |
+
if random.random() < 0.7:
|
| 681 |
+
ys, xs = np.where(mask_bin > 0)
|
| 682 |
+
if len(xs) > 0 and len(ys) > 0:
|
| 683 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 684 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 685 |
+
# Ensure valid box
|
| 686 |
+
if x2 > x1 and y2 > y1:
|
| 687 |
+
crop = base_np[y1:y2 + 1, x1:x2 + 1]
|
| 688 |
+
ch, cw = crop.shape[:2]
|
| 689 |
+
scale = random.uniform(0.2, 0.9)
|
| 690 |
+
dw = max(1, int(cw * scale))
|
| 691 |
+
dh = max(1, int(ch * scale))
|
| 692 |
+
try:
|
| 693 |
+
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
|
| 694 |
+
back = small.resize((cw, ch), Image.BICUBIC)
|
| 695 |
+
crop_blurred = np.array(back).astype(np.uint8)
|
| 696 |
+
base_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
|
| 697 |
+
except Exception:
|
| 698 |
+
# Fallback: skip if resize fails
|
| 699 |
+
pass
|
| 700 |
+
|
| 701 |
+
# 6) Build masked target using (possibly) updated union mask; then paste
|
| 702 |
+
union_mask_for_target = union_mask.copy()
|
| 703 |
+
if roi_update is not None:
|
| 704 |
+
rx, ry, rh, rw, warped_mask_roi = roi_update
|
| 705 |
+
# Ensure union covers the warped foreground area inside ROI using warped shape
|
| 706 |
+
um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
|
| 707 |
+
union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
|
| 708 |
+
masked_t_np = tgt_np.copy()
|
| 709 |
+
masked_t_np[union_mask_for_target.astype(bool)] = 255
|
| 710 |
+
composed_np = masked_t_np.copy()
|
| 711 |
+
m_fore = paste_mask_bool
|
| 712 |
+
composed_np[m_fore] = base_np[m_fore]
|
| 713 |
+
|
| 714 |
+
# 7) Build tensors
|
| 715 |
+
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
|
| 716 |
+
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
|
| 717 |
+
|
| 718 |
+
# 8) Caption: prepend instruction, 20% keep only instruction
|
| 719 |
+
cap_orig = rec.get("prompt", "") or ""
|
| 720 |
+
cap = _prepend_caption(cap_orig)
|
| 721 |
+
if perspective_applied:
|
| 722 |
+
cap = f"{cap} Fix the perspective if necessary."
|
| 723 |
+
ids1, ids2 = _tokenize(tokenizers, cap)
|
| 724 |
+
|
| 725 |
+
return {
|
| 726 |
+
"source_pixel_values": source_tensor,
|
| 727 |
+
"pixel_values": target_tensor,
|
| 728 |
+
"token_ids_clip": ids1,
|
| 729 |
+
"token_ids_t5": ids2,
|
| 730 |
+
"mask_values": mask_tensor,
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
return SubjectsDataset(records, base_dir)
|
| 734 |
+
|
| 735 |
+
|
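A toy illustration of the ROI perspective jitter used by both dataset classes above: the four corners of the mask's bounding box are perturbed by up to roughly 30% of the ROI size, and image and mask are warped with the same homography. The random ROI contents and magnitudes here are illustrative, not the training code.

import random

import cv2
import numpy as np

roi = np.random.randint(0, 255, (120, 160, 3), dtype=np.uint8)  # toy foreground crop
roi_mask = np.ones(roi.shape[:2], np.uint8)

bh, bw = roi.shape[:2]
ratio = random.uniform(0.1, 0.3)
dx, dy = bw * ratio, bh * ratio
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
dst = src + np.float32([[random.uniform(-dx, dx), random.uniform(-dy, dy)] for _ in range(4)])
dst[:, 0] = np.clip(dst[:, 0], 0, bw - 1)
dst[:, 1] = np.clip(dst[:, 1], 0, bh - 1)

M = cv2.getPerspectiveTransform(src, dst)
warped = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR)
warped_mask = cv2.warpPerspective(roi_mask * 255, M, (bw, bh), flags=cv2.INTER_NEAREST) > 127
print(warped.shape, float(warped_mask.mean()))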
| 736 |
+
def make_pexels_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
|
| 737 |
+
"""
|
| 738 |
+
Dataset for JSONL with fields:
|
| 739 |
+
- input_path: relative to base_dir (target image)
|
| 740 |
+
- output_path: relative to relight_base_dir (relighted image)
|
| 741 |
+
- final_size: {width, height} resize applied
|
| 742 |
+
- caption: text caption
|
| 743 |
+
|
| 744 |
+
Modified to use segmentation maps instead of raw_mask_path.
|
| 745 |
+
Randomly selects 2-5 segments from segmentation map, applies augmentation to each, and takes union.
|
| 746 |
+
This simulates multiple foreground objects being placed like a puzzle.
|
| 747 |
+
|
| 748 |
+
Each segment independently uses: 20% original, 80% color_augment (the relighted branch is currently disabled in code).
|
| 749 |
+
|
| 750 |
+
Args:
|
| 751 |
+
base_dir: Base directory for resolving relative paths. If None, uses args.pexels_base_dir.
|
| 752 |
+
"""
|
| 753 |
+
if base_dir is None:
|
| 754 |
+
base_dir = getattr(args, "pexels_base_dir", "/mnt/robby-b1/common/datasets")
|
| 755 |
+
|
| 756 |
+
relight_base_dir = getattr(args, "pexels_relight_base_dir", "/robby/share/Editing/lzc/data/relight_outputs")
|
| 757 |
+
seg_base_dir = getattr(args, "seg_base_dir", "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182")
|
| 758 |
+
|
| 759 |
+
data_files = _resolve_jsonl(getattr(args, "pexels_data_jsonl", None))
|
| 760 |
+
file_paths = data_files.get("train", [])
|
| 761 |
+
records = []
|
| 762 |
+
for p in file_paths:
|
| 763 |
+
with open(p, "r", encoding="utf-8") as f:
|
| 764 |
+
for line in f:
|
| 765 |
+
line = line.strip()
|
| 766 |
+
if not line:
|
| 767 |
+
continue
|
| 768 |
+
try:
|
| 769 |
+
obj = json.loads(line)
|
| 770 |
+
except Exception:
|
| 771 |
+
try:
|
| 772 |
+
obj = json.loads(line.rstrip(","))
|
| 773 |
+
except Exception:
|
| 774 |
+
continue
|
| 775 |
+
pruned = {
|
| 776 |
+
"input_path": obj.get("input_path"),
|
| 777 |
+
"output_path": obj.get("output_path"),
|
| 778 |
+
"final_size": obj.get("final_size"),
|
| 779 |
+
"caption": obj.get("caption"),
|
| 780 |
+
}
|
| 781 |
+
records.append(pruned)
|
| 782 |
+
|
| 783 |
+
to_tensor_and_norm = transforms.Compose([
|
| 784 |
+
transforms.ToTensor(),
|
| 785 |
+
transforms.Normalize([0.5], [0.5]),
|
| 786 |
+
])
|
| 787 |
+
|
| 788 |
+
class PexelsDataset(torch.utils.data.Dataset):
|
| 789 |
+
def __init__(self, hf_ds, base_dir, relight_base_dir, seg_base_dir):
|
| 790 |
+
self.ds = hf_ds
|
| 791 |
+
self.base_dir = base_dir
|
| 792 |
+
self.relight_base_dir = relight_base_dir
|
| 793 |
+
self.seg_base_dir = seg_base_dir
|
| 794 |
+
|
| 795 |
+
def __len__(self):
|
| 796 |
+
return len(self.ds)
|
| 797 |
+
|
| 798 |
+
def _extract_hash_from_filename(self, filename: str) -> str:
|
| 799 |
+
"""Extract hash from input filename for segmentation map lookup."""
|
| 800 |
+
stem = os.path.splitext(os.path.basename(filename))[0]
|
| 801 |
+
if '_' in stem:
|
| 802 |
+
parts = stem.split('_')
|
| 803 |
+
return parts[-1]
|
| 804 |
+
return stem
|
| 805 |
+
|
| 806 |
+
def _build_segmap_path(self, input_filename: str) -> str:
|
| 807 |
+
"""Build path to segmentation map from input filename."""
|
| 808 |
+
hash_id = self._extract_hash_from_filename(input_filename)
|
| 809 |
+
return os.path.join(self.seg_base_dir, f"{hash_id}_map.png")
|
| 810 |
+
|
| 811 |
+
def _load_segmap_uint32(self, seg_path: str):
|
| 812 |
+
"""Load segmentation map as uint32 array."""
|
| 813 |
+
import cv2
|
| 814 |
+
try:
|
| 815 |
+
with Image.open(seg_path) as im:
|
| 816 |
+
if im.mode == 'P':
|
| 817 |
+
seg = np.array(im)
|
| 818 |
+
elif im.mode in ('I;16', 'I', 'L'):
|
| 819 |
+
seg = np.array(im)
|
| 820 |
+
else:
|
| 821 |
+
seg = np.array(im.convert('L'))
|
| 822 |
+
except Exception:
|
| 823 |
+
return None
|
| 824 |
+
|
| 825 |
+
if seg.ndim == 3:
|
| 826 |
+
seg = cv2.cvtColor(seg, cv2.COLOR_BGR2GRAY)
|
| 827 |
+
return seg.astype(np.uint32)
|
| 828 |
+
|
| 829 |
+
def _extract_multiple_segments(
|
| 830 |
+
self,
|
| 831 |
+
image_h: int,
|
| 832 |
+
image_w: int,
|
| 833 |
+
seg_path: str,
|
| 834 |
+
min_area_ratio: float = 0.02,
|
| 835 |
+
max_area_ratio: float = 0.4,
|
| 836 |
+
):
|
| 837 |
+
"""Extract 2-5 individual segment masks from segmentation map."""
|
| 838 |
+
import cv2
|
| 839 |
+
seg = self._load_segmap_uint32(seg_path)
|
| 840 |
+
if seg is None:
|
| 841 |
+
return []
|
| 842 |
+
|
| 843 |
+
if seg.shape != (image_h, image_w):
|
| 844 |
+
seg = cv2.resize(seg.astype(np.uint16), (image_w, image_h), interpolation=cv2.INTER_NEAREST).astype(np.uint32)
|
| 845 |
+
|
| 846 |
+
labels, counts = np.unique(seg, return_counts=True)
|
| 847 |
+
if labels.size == 0:
|
| 848 |
+
return []
|
| 849 |
+
|
| 850 |
+
# Exclude background label 0
|
| 851 |
+
bg_mask = labels == 0
|
| 852 |
+
labels = labels[~bg_mask]
|
| 853 |
+
counts = counts[~bg_mask]
|
| 854 |
+
if labels.size == 0:
|
| 855 |
+
return []
|
| 856 |
+
|
| 857 |
+
area = image_h * image_w
|
| 858 |
+
min_px = int(round(min_area_ratio * area))
|
| 859 |
+
max_px = int(round(max_area_ratio * area))
|
| 860 |
+
keep = (counts >= min_px) & (counts <= max_px)
|
| 861 |
+
cand_labels = labels[keep]
|
| 862 |
+
if cand_labels.size == 0:
|
| 863 |
+
return []
|
| 864 |
+
|
| 865 |
+
# Select 2-5 labels randomly
|
| 866 |
+
max_sel = min(5, cand_labels.size)
|
| 867 |
+
min_sel = min(2, cand_labels.size)
|
| 868 |
+
num_to_select = random.randint(min_sel, max_sel)
|
| 869 |
+
chosen = np.random.choice(cand_labels, size=num_to_select, replace=False)
|
| 870 |
+
|
| 871 |
+
# Create individual masks for each chosen label
|
| 872 |
+
individual_masks = []
|
| 873 |
+
for lab in chosen:
|
| 874 |
+
binm = (seg == int(lab)).astype(np.uint8)
|
| 875 |
+
# Apply opening operation to clean up mask
|
| 876 |
+
k = max(3, int(round(max(image_h, image_w) * 0.01)))
|
| 877 |
+
if k % 2 == 0:
|
| 878 |
+
k += 1
|
| 879 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
| 880 |
+
eroded = cv2.erode(binm, kernel, iterations=1)
|
| 881 |
+
opened = cv2.dilate(eroded, kernel, iterations=1)
|
| 882 |
+
individual_masks.append(opened)
|
| 883 |
+
|
| 884 |
+
return individual_masks
|
| 885 |
+
|
| 886 |
+
def __getitem__(self, idx):
|
| 887 |
+
rec = self.ds[idx % len(self.ds)]
|
| 888 |
+
|
| 889 |
+
t_rel = rec.get("input_path", "")
|
| 890 |
+
r_rel = rec.get("output_path", "")
|
| 891 |
+
|
| 892 |
+
t_p = os.path.join(self.base_dir, t_rel)
|
| 893 |
+
relight_p = os.path.join(self.relight_base_dir, r_rel)
|
| 894 |
+
|
| 895 |
+
import cv2
|
| 896 |
+
tgt_img = Image.open(t_p).convert("RGB")
|
| 897 |
+
|
| 898 |
+
# Load relighted image, fallback to target if not available
|
| 899 |
+
try:
|
| 900 |
+
relighted_img = Image.open(relight_p).convert("RGB")
|
| 901 |
+
except Exception:
|
| 902 |
+
relighted_img = tgt_img.copy()
|
| 903 |
+
|
| 904 |
+
final_size = rec.get("final_size", {}) or {}
|
| 905 |
+
fw = int(final_size.get("width", tgt_img.width))
|
| 906 |
+
fh = int(final_size.get("height", tgt_img.height))
|
| 907 |
+
tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
|
| 908 |
+
relighted_img = relighted_img.resize((fw, fh), resample=Image.BILINEAR)
|
| 909 |
+
|
| 910 |
+
target_tensor = to_tensor_and_norm(tgt_img)
|
| 911 |
+
|
| 912 |
+
# Build segmentation map path and extract multiple segments
|
| 913 |
+
input_filename = os.path.basename(t_rel)
|
| 914 |
+
seg_path = self._build_segmap_path(input_filename)
|
| 915 |
+
individual_masks = self._extract_multiple_segments(fh, fw, seg_path)
|
| 916 |
+
|
| 917 |
+
if not individual_masks:
|
| 918 |
+
# Fallback: create empty mask (will be handled gracefully)
|
| 919 |
+
union_mask = np.zeros((fh, fw), dtype=np.uint8)
|
| 920 |
+
individual_masks = []
|
| 921 |
+
else:
|
| 922 |
+
# Apply augmentation to each segment mask and take union
|
| 923 |
+
augmented_masks = []
|
| 924 |
+
for seg_mask in individual_masks:
|
| 925 |
+
# 1) Grow mask by random 50-200 pixels
|
| 926 |
+
grown = _dilate_mask(seg_mask, 50, 200)
|
| 927 |
+
# 2) Optional random augmentation mask constrained by this segment
|
| 928 |
+
rand_mask = _constrained_random_mask(seg_mask, fh, fw, aug_prob=0.7)
|
| 929 |
+
# 3) Union for this segment
|
| 930 |
+
seg_union = np.clip(grown | rand_mask, 0, 1).astype(np.uint8)
|
| 931 |
+
augmented_masks.append(seg_union)
|
| 932 |
+
|
| 933 |
+
# Take union of all augmented masks
|
| 934 |
+
union_mask = np.zeros((fh, fw), dtype=np.uint8)
|
| 935 |
+
for m in augmented_masks:
|
| 936 |
+
union_mask = np.clip(union_mask | m, 0, 1).astype(np.uint8)
|
| 937 |
+
|
| 938 |
+
tgt_np = np.array(tgt_img)
|
| 939 |
+
|
| 940 |
+
# Build masked target first
|
| 941 |
+
masked_t_np = tgt_np.copy()
|
| 942 |
+
masked_t_np[union_mask.astype(bool)] = 255
|
| 943 |
+
composed_np = masked_t_np.copy()
|
| 944 |
+
|
| 945 |
+
# Process each segment independently with different augmentations
|
| 946 |
+
# This simulates multiple foreground objects from different sources
|
| 947 |
+
for seg_mask in individual_masks:
|
| 948 |
+
# 1) Choose source for this segment: 20% original, 80% color_augment (relighted branch disabled below)
|
| 949 |
+
r = random.random()
|
| 950 |
+
if r < 0.2:
|
| 951 |
+
# Original image
|
| 952 |
+
seg_source_img = tgt_img
|
| 953 |
+
else:
|
| 954 |
+
seg_source_img = _color_augment(tgt_img)
|
| 955 |
+
# elif r < 0.4:
|
| 956 |
+
# # Color augmentation
|
| 957 |
+
# seg_source_img = _color_augment(tgt_img)
|
| 958 |
+
# else:
|
| 959 |
+
# # Relighted image
|
| 960 |
+
# seg_source_img = relighted_img
|
| 961 |
+
|
| 962 |
+
seg_source_np = np.array(seg_source_img)
|
| 963 |
+
|
| 964 |
+
# 2) Apply resolution augmentation to this segment's region
|
| 965 |
+
if random.random() < 0.7:
|
| 966 |
+
ys, xs = np.where(seg_mask > 0)
|
| 967 |
+
if len(xs) > 0 and len(ys) > 0:
|
| 968 |
+
x1, x2 = int(xs.min()), int(xs.max())
|
| 969 |
+
y1, y2 = int(ys.min()), int(ys.max())
|
| 970 |
+
if x2 > x1 and y2 > y1:
|
| 971 |
+
crop = seg_source_np[y1:y2 + 1, x1:x2 + 1]
|
| 972 |
+
ch, cw = crop.shape[:2]
|
| 973 |
+
scale = random.uniform(0.2, 0.9)
|
| 974 |
+
dw = max(1, int(cw * scale))
|
| 975 |
+
dh = max(1, int(ch * scale))
|
| 976 |
+
try:
|
| 977 |
+
small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
|
| 978 |
+
back = small.resize((cw, ch), Image.BICUBIC)
|
| 979 |
+
crop_blurred = np.array(back).astype(np.uint8)
|
| 980 |
+
seg_source_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
|
| 981 |
+
except Exception:
|
| 982 |
+
pass
|
| 983 |
+
|
| 984 |
+
# 3) Paste this segment onto composed image
|
| 985 |
+
m_fore = seg_mask.astype(bool)
|
| 986 |
+
composed_np[m_fore] = seg_source_np[m_fore]
|
| 987 |
+
|
| 988 |
+
# Build tensors
|
| 989 |
+
source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
|
| 990 |
+
mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
|
| 991 |
+
|
| 992 |
+
# Caption: prepend instruction
|
| 993 |
+
cap_orig = rec.get("caption", "") or ""
|
| 994 |
+
cap = _prepend_caption(cap_orig)
|
| 995 |
+
ids1, ids2 = _tokenize(tokenizers, cap)
|
| 996 |
+
|
| 997 |
+
return {
|
| 998 |
+
"source_pixel_values": source_tensor,
|
| 999 |
+
"pixel_values": target_tensor,
|
| 1000 |
+
"token_ids_clip": ids1,
|
| 1001 |
+
"token_ids_t5": ids2,
|
| 1002 |
+
"mask_values": mask_tensor,
|
| 1003 |
+
}
|
| 1004 |
+
|
| 1005 |
+
return PexelsDataset(records, base_dir, relight_base_dir, seg_base_dir)
|
| 1006 |
+
|
| 1007 |
+
|
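The down/up-scaling augmentation shared by the three dataset classes above reduces to a few lines: resize the pasted crop down by a random factor and back up, which reproduces the blur/aliasing of cut-outs taken at a different resolution. Toy data below, illustrative only.

import random

import numpy as np
from PIL import Image

crop = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)  # toy pasted region
ch, cw = crop.shape[:2]
scale = random.uniform(0.2, 0.9)
small = Image.fromarray(crop).resize((max(1, int(cw * scale)), max(1, int(ch * scale))), Image.BICUBIC)
degraded = np.array(small.resize((cw, ch), Image.BICUBIC), dtype=np.uint8)
print(degraded.shape)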
| 1008 |
+
def make_mixed_dataset(args, tokenizers, interactive_jsonl_path=None, placement_jsonl_path=None,
|
| 1009 |
+
pexels_jsonl_path=None, interactive_base_dir=None, placement_base_dir=None,
|
| 1010 |
+
pexels_base_dir=None, interactive_weight=1.0, placement_weight=1.0,
|
| 1011 |
+
pexels_weight=1.0, accelerator=None):
|
| 1012 |
+
"""
|
| 1013 |
+
Create a mixed dataset combining interactive, placement, and pexels datasets.
|
| 1014 |
+
|
| 1015 |
+
Args:
|
| 1016 |
+
args: Arguments object with dataset configuration
|
| 1017 |
+
tokenizers: Tuple of tokenizers for text encoding
|
| 1018 |
+
interactive_jsonl_path: Path to interactive dataset JSONL (optional)
|
| 1019 |
+
placement_jsonl_path: Path to placement dataset JSONL (optional)
|
| 1020 |
+
pexels_jsonl_path: Path to pexels dataset JSONL (optional)
|
| 1021 |
+
interactive_base_dir: Base directory for interactive dataset paths (optional)
|
| 1022 |
+
placement_base_dir: Base directory for placement dataset paths (optional)
|
| 1023 |
+
pexels_base_dir: Base directory for pexels dataset paths (optional)
|
| 1024 |
+
interactive_weight: Sampling weight for interactive dataset (default: 1.0)
|
| 1025 |
+
placement_weight: Sampling weight for placement dataset (default: 1.0)
|
| 1026 |
+
pexels_weight: Sampling weight for pexels dataset (default: 1.0)
|
| 1027 |
+
accelerator: Optional accelerator object
|
| 1028 |
+
|
| 1029 |
+
Returns:
|
| 1030 |
+
Mixed dataset that samples from all provided datasets with specified weights
|
| 1031 |
+
"""
|
| 1032 |
+
datasets = []
|
| 1033 |
+
dataset_names = []
|
| 1034 |
+
dataset_weights = []
|
| 1035 |
+
|
| 1036 |
+
# Create interactive dataset if path provided
|
| 1037 |
+
if interactive_jsonl_path:
|
| 1038 |
+
interactive_args = type('Args', (), {})()
|
| 1039 |
+
for k, v in vars(args).items():
|
| 1040 |
+
setattr(interactive_args, k, v)
|
| 1041 |
+
interactive_args.train_data_jsonl = interactive_jsonl_path
|
| 1042 |
+
if interactive_base_dir:
|
| 1043 |
+
interactive_args.interactive_base_dir = interactive_base_dir
|
| 1044 |
+
interactive_ds = make_interactive_dataset_subjects(interactive_args, tokenizers, accelerator)
|
| 1045 |
+
datasets.append(interactive_ds)
|
| 1046 |
+
dataset_names.append("interactive")
|
| 1047 |
+
dataset_weights.append(interactive_weight)
|
| 1048 |
+
|
| 1049 |
+
# Create placement dataset if path provided
|
| 1050 |
+
if placement_jsonl_path:
|
| 1051 |
+
placement_args = type('Args', (), {})()
|
| 1052 |
+
for k, v in vars(args).items():
|
| 1053 |
+
setattr(placement_args, k, v)
|
| 1054 |
+
placement_args.placement_data_jsonl = placement_jsonl_path
|
| 1055 |
+
if placement_base_dir:
|
| 1056 |
+
placement_args.placement_base_dir = placement_base_dir
|
| 1057 |
+
placement_ds = make_placement_dataset_subjects(placement_args, tokenizers, accelerator)
|
| 1058 |
+
datasets.append(placement_ds)
|
| 1059 |
+
dataset_names.append("placement")
|
| 1060 |
+
dataset_weights.append(placement_weight)
|
| 1061 |
+
|
| 1062 |
+
# Create pexels dataset if path provided
|
| 1063 |
+
if pexels_jsonl_path:
|
| 1064 |
+
pexels_args = type('Args', (), {})()
|
| 1065 |
+
for k, v in vars(args).items():
|
| 1066 |
+
setattr(pexels_args, k, v)
|
| 1067 |
+
pexels_args.pexels_data_jsonl = pexels_jsonl_path
|
| 1068 |
+
if pexels_base_dir:
|
| 1069 |
+
pexels_args.pexels_base_dir = pexels_base_dir
|
| 1070 |
+
pexels_ds = make_pexels_dataset_subjects(pexels_args, tokenizers, accelerator)
|
| 1071 |
+
datasets.append(pexels_ds)
|
| 1072 |
+
dataset_names.append("pexels")
|
| 1073 |
+
dataset_weights.append(pexels_weight)
|
| 1074 |
+
|
| 1075 |
+
if not datasets:
|
| 1076 |
+
raise ValueError("At least one dataset path must be provided")
|
| 1077 |
+
|
| 1078 |
+
if len(datasets) == 1:
|
| 1079 |
+
return datasets[0]
|
| 1080 |
+
|
| 1081 |
+
# Mixed dataset class with balanced sampling (based on smallest dataset)
|
| 1082 |
+
class MixedDataset(torch.utils.data.Dataset):
|
| 1083 |
+
def __init__(self, datasets, dataset_names, dataset_weights):
|
| 1084 |
+
self.datasets = datasets
|
| 1085 |
+
self.dataset_names = dataset_names
|
| 1086 |
+
self.lengths = [len(ds) for ds in datasets]
|
| 1087 |
+
|
| 1088 |
+
# Normalize weights
|
| 1089 |
+
total_weight = sum(dataset_weights)
|
| 1090 |
+
self.weights = [w / total_weight for w in dataset_weights]
|
| 1091 |
+
|
| 1092 |
+
# Calculate samples per dataset based on smallest dataset and weights
|
| 1093 |
+
# Find the minimum weighted size
|
| 1094 |
+
min_weighted_size = min(length / weight for length, weight in zip(self.lengths, dataset_weights))
|
| 1095 |
+
|
| 1096 |
+
# Each dataset contributes samples proportional to its weight, scaled by min_weighted_size
|
| 1097 |
+
self.samples_per_dataset = [int(min_weighted_size * w) for w in dataset_weights]
|
| 1098 |
+
self.total_length = sum(self.samples_per_dataset)
|
| 1099 |
+
|
| 1100 |
+
# Build cumulative sample counts for indexing
|
| 1101 |
+
self.cumsum_samples = [0]
|
| 1102 |
+
for count in self.samples_per_dataset:
|
| 1103 |
+
self.cumsum_samples.append(self.cumsum_samples[-1] + count)
|
| 1104 |
+
|
| 1105 |
+
print(f"Balanced mixed dataset created:")
|
| 1106 |
+
for i, name in enumerate(dataset_names):
|
| 1107 |
+
print(f" {name}: {self.lengths[i]} total, {self.samples_per_dataset[i]} per epoch")
|
| 1108 |
+
print(f" Total samples per epoch: {self.total_length}")
|
| 1109 |
+
|
| 1110 |
+
def __len__(self):
|
| 1111 |
+
return self.total_length
|
| 1112 |
+
|
| 1113 |
+
def __getitem__(self, idx):
|
| 1114 |
+
# Determine which dataset this idx belongs to
|
| 1115 |
+
dataset_idx = 0
|
| 1116 |
+
for i in range(len(self.cumsum_samples) - 1):
|
| 1117 |
+
if self.cumsum_samples[i] <= idx < self.cumsum_samples[i + 1]:
|
| 1118 |
+
dataset_idx = i
|
| 1119 |
+
break
|
| 1120 |
+
|
| 1121 |
+
# Randomly sample from the selected dataset (enables different samples each epoch)
|
| 1122 |
+
local_idx = random.randint(0, self.lengths[dataset_idx] - 1)
|
| 1123 |
+
sample = self.datasets[dataset_idx][local_idx]
|
| 1124 |
+
# Add dataset source information
|
| 1125 |
+
sample["dataset_source"] = self.dataset_names[dataset_idx]
|
| 1126 |
+
return sample
|
| 1127 |
+
|
| 1128 |
+
return MixedDataset(datasets, dataset_names, dataset_weights)
|
| 1129 |
+
|
| 1130 |
+
|
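The per-epoch budgeting inside MixedDataset boils down to scaling every source by the smallest length/weight ratio; a worked example with hypothetical sizes follows (note the computation assumes all weights are strictly positive, since a zero weight would divide by zero).

lengths = [10_000, 4_000, 50_000]  # hypothetical dataset sizes
weights = [1.0, 1.0, 0.5]          # hypothetical sampling weights

min_weighted = min(n / w for n, w in zip(lengths, weights))      # 4000 / 1.0 = 4000
samples_per_dataset = [int(min_weighted * w) for w in weights]   # [4000, 4000, 2000]
print(samples_per_dataset, sum(samples_per_dataset))             # 10000 samples per epoch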
| 1131 |
+
def _run_test_mode(
|
| 1132 |
+
interactive_jsonl: str = None,
|
| 1133 |
+
placement_jsonl: str = None,
|
| 1134 |
+
pexels_jsonl: str = None,
|
| 1135 |
+
interactive_base_dir: str = None,
|
| 1136 |
+
placement_base_dir: str = None,
|
| 1137 |
+
pexels_base_dir: str = None,
|
| 1138 |
+
pexels_relight_base_dir: str = None,
|
| 1139 |
+
seg_base_dir: str = None,
|
| 1140 |
+
interactive_weight: float = 1.0,
|
| 1141 |
+
placement_weight: float = 1.0,
|
| 1142 |
+
pexels_weight: float = 1.0,
|
| 1143 |
+
output_dir: str = "test_output",
|
| 1144 |
+
num_samples: int = 100
|
| 1145 |
+
):
|
| 1146 |
+
"""Test dataset by saving samples with source labels.
|
| 1147 |
+
|
| 1148 |
+
Args:
|
| 1149 |
+
interactive_jsonl: Path to interactive dataset JSONL (optional)
|
| 1150 |
+
placement_jsonl: Path to placement dataset JSONL (optional)
|
| 1151 |
+
pexels_jsonl: Path to pexels dataset JSONL (optional)
|
| 1152 |
+
interactive_base_dir: Base directory for interactive dataset
|
| 1153 |
+
placement_base_dir: Base directory for placement dataset
|
| 1154 |
+
pexels_base_dir: Base directory for pexels dataset
|
| 1155 |
+
pexels_relight_base_dir: Base directory for pexels relighted images
|
| 1156 |
+
seg_base_dir: Directory containing segmentation maps for pexels dataset
|
| 1157 |
+
interactive_weight: Sampling weight for interactive dataset (default: 1.0)
|
| 1158 |
+
placement_weight: Sampling weight for placement dataset (default: 1.0)
|
| 1159 |
+
pexels_weight: Sampling weight for pexels dataset (default: 1.0)
|
| 1160 |
+
output_dir: Output directory for test images
|
| 1161 |
+
num_samples: Number of samples to save
|
| 1162 |
+
"""
|
| 1163 |
+
if not interactive_jsonl and not placement_jsonl and not pexels_jsonl:
|
| 1164 |
+
raise ValueError("At least one dataset path must be provided")
|
| 1165 |
+
|
| 1166 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1167 |
+
|
| 1168 |
+
# Create dummy tokenizers for testing
|
| 1169 |
+
class DummyTokenizer:
|
| 1170 |
+
def __call__(self, text, **kwargs):
|
| 1171 |
+
class Result:
|
| 1172 |
+
input_ids = torch.zeros(1, 77, dtype=torch.long)
|
| 1173 |
+
return Result()
|
| 1174 |
+
|
| 1175 |
+
tokenizers = (DummyTokenizer(), DummyTokenizer())
|
| 1176 |
+
|
| 1177 |
+
# Create args object
|
| 1178 |
+
class Args:
|
| 1179 |
+
cond_size = 512
|
| 1180 |
+
|
| 1181 |
+
args = Args()
|
| 1182 |
+
args.train_data_jsonl = interactive_jsonl
|
| 1183 |
+
args.placement_data_jsonl = placement_jsonl
|
| 1184 |
+
args.pexels_data_jsonl = pexels_jsonl
|
| 1185 |
+
args.interactive_base_dir = interactive_base_dir
|
| 1186 |
+
args.placement_base_dir = placement_base_dir
|
| 1187 |
+
args.pexels_base_dir = pexels_base_dir
|
| 1188 |
+
args.pexels_relight_base_dir = pexels_relight_base_dir if pexels_relight_base_dir else "/robby/share/Editing/lzc/data/relight_outputs"
|
| 1189 |
+
args.seg_base_dir = seg_base_dir if seg_base_dir else "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182"
|
| 1190 |
+
|
| 1191 |
+
# Create dataset (single or mixed)
|
| 1192 |
+
try:
|
| 1193 |
+
# Count how many datasets are provided
|
| 1194 |
+
num_datasets = sum([bool(interactive_jsonl), bool(placement_jsonl), bool(pexels_jsonl)])
|
| 1195 |
+
|
| 1196 |
+
if num_datasets > 1:
|
| 1197 |
+
dataset = make_mixed_dataset(
|
| 1198 |
+
args, tokenizers,
|
| 1199 |
+
interactive_jsonl_path=interactive_jsonl,
|
| 1200 |
+
placement_jsonl_path=placement_jsonl,
|
| 1201 |
+
pexels_jsonl_path=pexels_jsonl,
|
| 1202 |
+
interactive_base_dir=args.interactive_base_dir,
|
| 1203 |
+
placement_base_dir=args.placement_base_dir,
|
| 1204 |
+
pexels_base_dir=args.pexels_base_dir,
|
| 1205 |
+
interactive_weight=interactive_weight,
|
| 1206 |
+
placement_weight=placement_weight,
|
| 1207 |
+
pexels_weight=pexels_weight
|
| 1208 |
+
)
|
| 1209 |
+
print(f"Created mixed dataset with {len(dataset)} samples")
|
| 1210 |
+
weights_str = []
|
| 1211 |
+
if interactive_jsonl:
|
| 1212 |
+
weights_str.append(f"Interactive: {interactive_weight:.2f}")
|
| 1213 |
+
if placement_jsonl:
|
| 1214 |
+
weights_str.append(f"Placement: {placement_weight:.2f}")
|
| 1215 |
+
if pexels_jsonl:
|
| 1216 |
+
weights_str.append(f"Pexels: {pexels_weight:.2f}")
|
| 1217 |
+
print(f"Sampling weights - {', '.join(weights_str)}")
|
| 1218 |
+
elif pexels_jsonl:
|
| 1219 |
+
dataset = make_pexels_dataset_subjects(args, tokenizers, base_dir=pexels_base_dir)
|
| 1220 |
+
print(f"Created pexels dataset with {len(dataset)} samples")
|
| 1221 |
+
elif placement_jsonl:
|
| 1222 |
+
dataset = make_placement_dataset_subjects(args, tokenizers, base_dir=args.placement_base_dir)
|
| 1223 |
+
print(f"Created placement dataset with {len(dataset)} samples")
|
| 1224 |
+
else:
|
| 1225 |
+
dataset = make_interactive_dataset_subjects(args, tokenizers, base_dir=args.interactive_base_dir)
|
| 1226 |
+
print(f"Created interactive dataset with {len(dataset)} samples")
|
| 1227 |
+
except Exception as e:
|
| 1228 |
+
print(f"Failed to create dataset: {e}")
|
| 1229 |
+
import traceback
|
| 1230 |
+
traceback.print_exc()
|
| 1231 |
+
return
|
| 1232 |
+
|
| 1233 |
+
# Sample and save
|
| 1234 |
+
saved = 0
|
| 1235 |
+
counts = {}
|
| 1236 |
+
|
| 1237 |
+
for attempt in range(min(num_samples * 3, len(dataset))):
|
| 1238 |
+
try:
|
| 1239 |
+
idx = random.randint(0, len(dataset) - 1)
|
| 1240 |
+
sample = dataset[idx]
|
| 1241 |
+
|
| 1242 |
+
source_name = sample.get("dataset_source", "single")
|
| 1243 |
+
counts[source_name] = counts.get(source_name, 0) + 1
|
| 1244 |
+
|
| 1245 |
+
# Denormalize tensors from [-1, 1] to [0, 255]
|
| 1246 |
+
source_np = ((sample["source_pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
|
| 1247 |
+
target_np = ((sample["pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
|
| 1248 |
+
|
| 1249 |
+
# Save images
|
| 1250 |
+
idx_str = f"{saved:05d}"
|
| 1251 |
+
Image.fromarray(source_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_source.jpg"))
|
| 1252 |
+
Image.fromarray(target_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_target.jpg"))
|
| 1253 |
+
|
| 1254 |
+
saved += 1
|
| 1255 |
+
if saved % 10 == 0:
|
| 1256 |
+
print(f"Saved {saved}/{num_samples} samples - {counts}")
|
| 1257 |
+
if saved >= num_samples:
|
| 1258 |
+
break
|
| 1259 |
+
|
| 1260 |
+
except Exception as e:
|
| 1261 |
+
print(f"Failed to process sample: {e}")
|
| 1262 |
+
continue
|
| 1263 |
+
|
| 1264 |
+
print(f"\nTest complete. Saved {saved} samples to {output_dir}")
|
| 1265 |
+
print(f"Distribution: {counts}")
|
| 1266 |
+
|
| 1267 |
+
|
| 1268 |
+
def _parse_test_args():
|
| 1269 |
+
import argparse
|
| 1270 |
+
parser = argparse.ArgumentParser(description="Test visualization for Kontext datasets")
|
| 1271 |
+
parser.add_argument("--interactive_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
|
| 1272 |
+
help="Path to interactive dataset JSONL")
|
| 1273 |
+
parser.add_argument("--placement_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
|
| 1274 |
+
help="Path to placement dataset JSONL")
|
| 1275 |
+
parser.add_argument("--pexels_jsonl", type=str, default=None,
|
| 1276 |
+
help="Path to pexels dataset JSONL")
|
| 1277 |
+
parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
|
| 1278 |
+
help="Base directory for interactive dataset")
|
| 1279 |
+
parser.add_argument("--placement_base_dir", type=str, default=None,
|
| 1280 |
+
help="Base directory for placement dataset")
|
| 1281 |
+
parser.add_argument("--pexels_base_dir", type=str, default=None,
|
| 1282 |
+
help="Base directory for pexels dataset")
|
| 1283 |
+
parser.add_argument("--pexels_relight_base_dir", type=str, default="/robby/share/Editing/lzc/data/relight_outputs",
|
| 1284 |
+
help="Base directory for pexels relighted images")
|
| 1285 |
+
parser.add_argument("--seg_base_dir", type=str, default=None,
|
| 1286 |
+
help="Directory containing segmentation maps for pexels dataset")
|
| 1287 |
+
parser.add_argument("--interactive_weight", type=float, default=1.0,
|
| 1288 |
+
help="Sampling weight for interactive dataset (default: 1.0)")
|
| 1289 |
+
parser.add_argument("--placement_weight", type=float, default=1.0,
|
| 1290 |
+
help="Sampling weight for placement dataset (default: 1.0)")
|
| 1291 |
+
parser.add_argument("--pexels_weight", type=float, default=0,
|
| 1292 |
+
help="Sampling weight for pexels dataset (default: 1.0)")
|
| 1293 |
+
parser.add_argument("--output_dir", type=str, default="visualize_output",
|
| 1294 |
+
help="Output directory to save pairs")
|
| 1295 |
+
parser.add_argument("--num_samples", type=int, default=100,
|
| 1296 |
+
help="Number of pairs to save")
|
| 1297 |
+
|
| 1298 |
+
# Legacy arguments
|
| 1299 |
+
parser.add_argument("--test_jsonl", type=str, default=None,
|
| 1300 |
+
help="Legacy: Path to JSONL (uses as interactive_jsonl)")
|
| 1301 |
+
parser.add_argument("--base_dir", type=str, default=None,
|
| 1302 |
+
help="Legacy: Base directory (uses as interactive_base_dir)")
|
| 1303 |
+
return parser.parse_args()
|
| 1304 |
+
|
| 1305 |
+
|
| 1306 |
+
if __name__ == "__main__":
|
| 1307 |
+
try:
|
| 1308 |
+
args = _parse_test_args()
|
| 1309 |
+
|
| 1310 |
+
# Handle legacy args
|
| 1311 |
+
interactive_jsonl = args.interactive_jsonl or args.test_jsonl
|
| 1312 |
+
interactive_base_dir = args.interactive_base_dir or args.base_dir
|
| 1313 |
+
|
| 1314 |
+
_run_test_mode(
|
| 1315 |
+
interactive_jsonl=interactive_jsonl,
|
| 1316 |
+
placement_jsonl=args.placement_jsonl,
|
| 1317 |
+
pexels_jsonl=args.pexels_jsonl,
|
| 1318 |
+
interactive_base_dir=interactive_base_dir,
|
| 1319 |
+
placement_base_dir=args.placement_base_dir,
|
| 1320 |
+
pexels_base_dir=args.pexels_base_dir,
|
| 1321 |
+
pexels_relight_base_dir=args.pexels_relight_base_dir,
|
| 1322 |
+
seg_base_dir=args.seg_base_dir,
|
| 1323 |
+
interactive_weight=args.interactive_weight,
|
| 1324 |
+
placement_weight=args.placement_weight,
|
| 1325 |
+
pexels_weight=args.pexels_weight,
|
| 1326 |
+
output_dir=args.output_dir,
|
| 1327 |
+
num_samples=args.num_samples
|
| 1328 |
+
)
|
| 1329 |
+
except SystemExit:
|
| 1330 |
+
# Allow import usage without triggering test mode
|
| 1331 |
+
pass
|
| 1332 |
+
|
train/src/jsonl_datasets_kontext_local.py
ADDED
|
@@ -0,0 +1,312 @@
| 1 |
+
from PIL import Image
|
| 2 |
+
from datasets import Dataset
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
|
| 8 |
+
from .jsonl_datasets_kontext import make_train_dataset_inpaint_mask
|
| 9 |
+
import numpy as np
|
| 10 |
+
import json
|
| 11 |
+
from .generate_diff_mask import generate_final_difference_mask, align_images
|
| 12 |
+
|
| 13 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 14 |
+
BLEND_PIXEL_VALUES = True
|
| 15 |
+
|
| 16 |
+
def multiple_16(num: float):
|
| 17 |
+
return int(round(num / 16) * 16)
|
| 18 |
+
|
| 19 |
+
def choose_kontext_resolution_from_wh(width: int, height: int):
|
| 20 |
+
aspect_ratio = width / max(1, height)
|
| 21 |
+
_, best_w, best_h = min(
|
| 22 |
+
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
| 23 |
+
)
|
| 24 |
+
return best_w, best_h
|
| 25 |
+
|
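For illustration, the aspect-ratio matching above picks the preferred (w, h) whose ratio is closest to the input image; the resolution list in this sketch is a small hypothetical subset, not the actual PREFERRED_KONTEXT_RESOLUTIONS table.

preferred = [(1024, 1024), (1216, 832), (832, 1216), (1344, 768)]  # hypothetical subset

def nearest_resolution(width: int, height: int):
    aspect = width / max(1, height)
    _, best_w, best_h = min((abs(aspect - w / h), w, h) for w, h in preferred)
    return best_w, best_h

print(nearest_resolution(1920, 1080))  # -> (1344, 768), the closest to 16:9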
| 26 |
+
def collate_fn(examples):
|
| 27 |
+
if examples[0].get("cond_pixel_values") is not None:
|
| 28 |
+
cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
|
| 29 |
+
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 30 |
+
else:
|
| 31 |
+
cond_pixel_values = None
|
| 32 |
+
if examples[0].get("source_pixel_values") is not None:
|
| 33 |
+
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
|
| 34 |
+
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 35 |
+
else:
|
| 36 |
+
source_pixel_values = None
|
| 37 |
+
|
| 38 |
+
target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
| 39 |
+
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
|
| 40 |
+
token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
|
| 41 |
+
token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
|
| 42 |
+
|
| 43 |
+
mask_values = None
|
| 44 |
+
if examples[0].get("mask_values") is not None:
|
| 45 |
+
mask_values = torch.stack([example["mask_values"] for example in examples])
|
| 46 |
+
mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
|
| 47 |
+
|
| 48 |
+
return {
|
| 49 |
+
"cond_pixel_values": cond_pixel_values,
|
| 50 |
+
"source_pixel_values": source_pixel_values,
|
| 51 |
+
"pixel_values": target_pixel_values,
|
| 52 |
+
"text_ids_1": token_ids_clip,
|
| 53 |
+
"text_ids_2": token_ids_t5,
|
| 54 |
+
"mask_values": mask_values,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# New dataset for local_edits JSON mapping with on-the-fly diff-mask generation
|
| 59 |
+
def make_train_dataset_local_edits(args, tokenizers, accelerator=None):
|
| 60 |
+
# Read JSON entries
|
| 61 |
+
with open(args.local_edits_json, "r", encoding="utf-8") as f:
|
| 62 |
+
entries = json.load(f)
|
| 63 |
+
|
| 64 |
+
samples = []
|
| 65 |
+
for item in entries:
|
| 66 |
+
rel_path = item.get("path", "")
|
| 67 |
+
local_edits = item.get("local_edits", []) or []
|
| 68 |
+
if not rel_path or not local_edits:
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
base_name = os.path.basename(rel_path)
|
| 72 |
+
prefix = os.path.splitext(base_name)[0]
|
| 73 |
+
group_dir = os.path.basename(os.path.dirname(rel_path))
|
| 74 |
+
gid_int = None
|
| 75 |
+
try:
|
| 76 |
+
gid_int = int(group_dir)
|
| 77 |
+
except Exception:
|
| 78 |
+
try:
|
| 79 |
+
digits = "".join([ch for ch in group_dir if ch.isdigit()])
|
| 80 |
+
gid_int = int(digits) if digits else None
|
| 81 |
+
except Exception:
|
| 82 |
+
gid_int = None
|
| 83 |
+
|
| 84 |
+
group_str = group_dir # e.g., "0139" from the JSON path segment
|
| 85 |
+
|
| 86 |
+
# Resolve source/target directories strictly as base/<0139>
|
| 87 |
+
src_dir_candidates = [os.path.join(args.source_frames_dir, group_str)]
|
| 88 |
+
tgt_dir_candidates = [os.path.join(args.target_frames_dir, group_str)]
|
| 89 |
+
src_dir = next((d for d in src_dir_candidates if d and os.path.isdir(d)), None)
|
| 90 |
+
tgt_dir = next((d for d in tgt_dir_candidates if d and os.path.isdir(d)), None)
|
| 91 |
+
if src_dir is None or tgt_dir is None:
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
src_path = os.path.join(src_dir, f"{prefix}.png")
|
| 95 |
+
for idx, prompt in enumerate(local_edits, start=1):
|
| 96 |
+
tgt_path = os.path.join(tgt_dir, f"{prefix}_{idx}.png")
|
| 97 |
+
mask_path = os.path.join(args.masks_dir, group_str, f"{prefix}_{idx}.png")
|
| 98 |
+
if not (os.path.exists(src_path) and os.path.exists(tgt_path) and os.path.exists(mask_path)):
|
| 99 |
+
continue
|
| 100 |
+
samples.append({
|
| 101 |
+
"source_image": src_path,
|
| 102 |
+
"target_image": tgt_path,
|
| 103 |
+
"mask_image": mask_path,
|
| 104 |
+
"prompt": prompt,
|
| 105 |
+
})
|
| 106 |
+
|
| 107 |
+
size = args.cond_size
|
| 108 |
+
|
| 109 |
+
to_tensor_and_norm = transforms.Compose([
|
| 110 |
+
transforms.ToTensor(),
|
| 111 |
+
transforms.Normalize([0.5], [0.5]),
|
| 112 |
+
])
|
| 113 |
+
|
| 114 |
+
cond_train_transforms = transforms.Compose(
|
| 115 |
+
[
|
| 116 |
+
transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
|
| 117 |
+
transforms.ToTensor(),
|
| 118 |
+
transforms.Normalize([0.5], [0.5]),
|
| 119 |
+
]
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
tokenizer_clip = tokenizers[0]
|
| 123 |
+
tokenizer_t5 = tokenizers[1]
|
| 124 |
+
|
| 125 |
+
def tokenize_prompt_single(caption: str):
|
| 126 |
+
text_inputs_clip = tokenizer_clip(
|
| 127 |
+
[caption],
|
| 128 |
+
padding="max_length",
|
| 129 |
+
max_length=77,
|
| 130 |
+
truncation=True,
|
| 131 |
+
return_tensors="pt",
|
| 132 |
+
)
|
| 133 |
+
text_input_ids_1 = text_inputs_clip.input_ids[0]
|
| 134 |
+
|
| 135 |
+
text_inputs_t5 = tokenizer_t5(
|
| 136 |
+
[caption],
|
| 137 |
+
padding="max_length",
|
| 138 |
+
max_length=128,
|
| 139 |
+
truncation=True,
|
| 140 |
+
return_tensors="pt",
|
| 141 |
+
)
|
| 142 |
+
text_input_ids_2 = text_inputs_t5.input_ids[0]
|
| 143 |
+
return text_input_ids_1, text_input_ids_2
|
| 144 |
+
|
| 145 |
+
class LocalEditsDataset(torch.utils.data.Dataset):
|
| 146 |
+
def __init__(self, samples_ls):
|
| 147 |
+
self.samples = samples_ls
|
| 148 |
+
def __len__(self):
|
| 149 |
+
return len(self.samples)
|
| 150 |
+
def __getitem__(self, idx):
|
| 151 |
+
sample = self.samples[idx]
|
| 152 |
+
s_p = sample["source_image"]
|
| 153 |
+
t_p = sample["target_image"]
|
| 154 |
+
m_p = sample["mask_image"]
|
| 155 |
+
cap = sample["prompt"]
|
| 156 |
+
|
| 157 |
+
rr = random.randint(10, 20)
|
| 158 |
+
ri = random.randint(3, 5)
|
| 159 |
+
import cv2
|
| 160 |
+
mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
|
| 161 |
+
if mask_loaded is None:
|
| 162 |
+
raise ValueError("mask load failed")
|
| 163 |
+
mask = mask_loaded.copy()
|
| 164 |
+
|
| 165 |
+
# Pre-expand mask by a fixed number of pixels before any random expansion
|
| 166 |
+
# Uses a cross-shaped kernel when tapered_corners is True to emulate "tapered" growth
|
| 167 |
+
pre_expand_px = int(getattr(args, "pre_expand_mask_px", 50))
|
| 168 |
+
pre_expand_tapered = bool(getattr(args, "pre_expand_tapered_corners", True))
|
| 169 |
+
if pre_expand_px != 0:
|
| 170 |
+
c = 0 if pre_expand_tapered else 1
|
| 171 |
+
pre_kernel = np.array([[c, 1, c],
|
| 172 |
+
[1, 1, 1],
|
| 173 |
+
[c, 1, c]], dtype=np.uint8)
|
| 174 |
+
if pre_expand_px > 0:
|
| 175 |
+
mask = cv2.dilate(mask, pre_kernel, iterations=pre_expand_px)
|
| 176 |
+
else:
|
| 177 |
+
mask = cv2.erode(mask, pre_kernel, iterations=abs(pre_expand_px))
|
| 178 |
+
if rr > 0 and ri > 0:
|
| 179 |
+
ksize = max(1, 2 * int(rr) + 1)
|
| 180 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
|
| 181 |
+
for _ in range(max(1, ri)):
|
| 182 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
| 183 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 184 |
+
|
| 185 |
+
src_aligned, tgt_aligned = align_images(s_p, t_p)
|
| 186 |
+
|
| 187 |
+
best_w, best_h = choose_kontext_resolution_from_wh(tgt_aligned.width, tgt_aligned.height)
|
| 188 |
+
final_img_rs = tgt_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
|
| 189 |
+
raw_img_rs = src_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
|
| 190 |
+
|
| 191 |
+
target_tensor = to_tensor_and_norm(final_img_rs)
|
| 192 |
+
source_tensor = to_tensor_and_norm(raw_img_rs)
|
| 193 |
+
|
| 194 |
+
mask_img = Image.fromarray(mask.astype(np.uint8)).convert("L")
|
| 195 |
+
if mask_img.size != src_aligned.size:
|
| 196 |
+
mask_img = mask_img.resize(src_aligned.size, Image.NEAREST)
|
| 197 |
+
mask_np = np.array(mask_img)
|
| 198 |
+
|
| 199 |
+
mask_bin = (mask_np > 127).astype(np.uint8)
|
| 200 |
+
inv_mask = (1 - mask_bin).astype(np.uint8)
|
| 201 |
+
src_np = np.array(src_aligned)
|
| 202 |
+
masked_raw_np = src_np * inv_mask[..., None]
|
| 203 |
+
masked_raw_img = Image.fromarray(masked_raw_np.astype(np.uint8))
|
| 204 |
+
cond_tensor = cond_train_transforms(masked_raw_img)
|
| 205 |
+
|
| 206 |
+
# Prepare mask_values tensor at training resolution (best_w, best_h)
|
| 207 |
+
mask_img_rs = mask_img.resize((best_w, best_h), Image.NEAREST)
|
| 208 |
+
mask_np_rs = np.array(mask_img_rs)
|
| 209 |
+
mask_bin_rs = (mask_np_rs > 127).astype(np.float32)
|
| 210 |
+
mask_tensor = torch.from_numpy(mask_bin_rs).unsqueeze(0) # [1, H, W]
|
| 211 |
+
|
| 212 |
+
ids1, ids2 = tokenize_prompt_single(cap if isinstance(cap, str) else "")
|
| 213 |
+
|
| 214 |
+
# Optionally blend target and source using a blurred mask, controlled by args
|
| 215 |
+
if getattr(args, "blend_pixel_values", BLEND_PIXEL_VALUES):
|
| 216 |
+
blend_kernel = int(getattr(args, "blend_kernel", 21))
|
| 217 |
+
if blend_kernel % 2 == 0:
|
| 218 |
+
blend_kernel += 1
|
| 219 |
+
blend_sigma = float(getattr(args, "blend_sigma", 10.0))
|
| 220 |
+
gb = transforms.GaussianBlur(kernel_size=(blend_kernel, blend_kernel), sigma=(blend_sigma, blend_sigma))
|
| 221 |
+
# mask_tensor: [1, H, W] in [0,1]
|
| 222 |
+
blurred_mask = gb(mask_tensor) # [1, H, W]
|
| 223 |
+
# Expand to 3 channels to match image tensors
|
| 224 |
+
blurred_mask_3c = blurred_mask.expand(target_tensor.shape[0], -1, -1) # [3, H, W]
|
| 225 |
+
# Blend in normalized space (both tensors already normalized to [-1, 1])
|
| 226 |
+
target_tensor = (source_tensor * (1.0 - blurred_mask_3c)) + (target_tensor * blurred_mask_3c)
|
| 227 |
+
target_tensor = target_tensor.clamp(-1.0, 1.0)
|
| 228 |
+
|
| 229 |
+
return {
|
| 230 |
+
"source_pixel_values": source_tensor,
|
| 231 |
+
"pixel_values": target_tensor,
|
| 232 |
+
"cond_pixel_values": cond_tensor,
|
| 233 |
+
"token_ids_clip": ids1,
|
| 234 |
+
"token_ids_t5": ids2,
|
| 235 |
+
"mask_values": mask_tensor,
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
return LocalEditsDataset(samples)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class BalancedMixDataset(torch.utils.data.Dataset):
|
| 242 |
+
"""
|
| 243 |
+
A wrapper dataset that mixes two datasets with a configurable ratio.
|
| 244 |
+
|
| 245 |
+
ratio_b_per_a defines how many samples from dataset_b for each sample from dataset_a:
|
| 246 |
+
- 0 => only dataset_a (local edits)
|
| 247 |
+
- 1 => 1:1 mix (default)
|
| 248 |
+
- 2 => 1:2 mix (A:B)
|
| 249 |
+
- any float supported (e.g., 0.5 => 2:1 mix)
|
| 250 |
+
"""
|
| 251 |
+
def __init__(self, dataset_a, dataset_b, ratio_b_per_a: float = 1.0):
|
| 252 |
+
self.dataset_a = dataset_a
|
| 253 |
+
self.dataset_b = dataset_b
|
| 254 |
+
self.ratio_b_per_a = max(0.0, float(ratio_b_per_a))
|
| 255 |
+
|
| 256 |
+
len_a = len(dataset_a)
|
| 257 |
+
len_b = len(dataset_b)
|
| 258 |
+
|
| 259 |
+
# If ratio is 0, use all of dataset_a only
|
| 260 |
+
if self.ratio_b_per_a == 0 or len_b == 0:
|
| 261 |
+
a_indices = list(range(len_a))
|
| 262 |
+
random.shuffle(a_indices)
|
| 263 |
+
self.mapping = [(0, i) for i in a_indices]
|
| 264 |
+
return
|
| 265 |
+
|
| 266 |
+
# Determine how many we can draw without replacement
|
| 267 |
+
# n_a limited by A size and B availability according to ratio
|
| 268 |
+
n_a_by_ratio = int(len_b / self.ratio_b_per_a)
|
| 269 |
+
n_a = min(len_a, max(1, n_a_by_ratio))
|
| 270 |
+
n_b = min(len_b, max(1, int(round(n_a * self.ratio_b_per_a))))
|
| 271 |
+
|
| 272 |
+
a_indices = list(range(len_a))
|
| 273 |
+
b_indices = list(range(len_b))
|
| 274 |
+
random.shuffle(a_indices)
|
| 275 |
+
random.shuffle(b_indices)
|
| 276 |
+
a_indices = a_indices[: n_a]
|
| 277 |
+
b_indices = b_indices[: n_b]
|
| 278 |
+
|
| 279 |
+
mixed = [(0, i) for i in a_indices] + [(1, i) for i in b_indices]
|
| 280 |
+
random.shuffle(mixed)
|
| 281 |
+
self.mapping = mixed
|
| 282 |
+
|
| 283 |
+
def __len__(self):
|
| 284 |
+
return len(self.mapping)
|
| 285 |
+
|
| 286 |
+
def __getitem__(self, idx):
|
| 287 |
+
which, real_idx = self.mapping[idx]
|
| 288 |
+
if which == 0:
|
| 289 |
+
return self.dataset_a[real_idx]
|
| 290 |
+
else:
|
| 291 |
+
return self.dataset_b[real_idx]
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def make_train_dataset_mixed(args, tokenizers, accelerator=None):
|
| 295 |
+
"""
|
| 296 |
+
Create a mixed dataset from:
|
| 297 |
+
- Local edits dataset (this file)
|
| 298 |
+
- Inpaint-mask JSONL dataset (jsonl_datasets_kontext.make_train_dataset_inpaint_mask)
|
| 299 |
+
|
| 300 |
+
Ratio control via args.mix_ratio (float):
|
| 301 |
+
- 0 => only local edits dataset
|
| 302 |
+
- 1 => 1:1 mix (local:inpaint)
|
| 303 |
+
- 2 => 1:2 mix, etc.
|
| 304 |
+
|
| 305 |
+
Requirements:
|
| 306 |
+
- args.local_edits_json and related dirs must be set for local edits
|
| 307 |
+
- args.train_data_dir must be set for the JSONL inpaint dataset
|
| 308 |
+
"""
|
| 309 |
+
ds_local = make_train_dataset_local_edits(args, tokenizers, accelerator)
|
| 310 |
+
ds_inpaint = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
|
| 311 |
+
mix_ratio = getattr(args, "mix_ratio", 1.0)
|
| 312 |
+
return BalancedMixDataset(ds_local, ds_inpaint, ratio_b_per_a=mix_ratio)
|
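For reference, a minimal wiring sketch for the dataset code added above. It assumes this module's make_train_dataset_mixed, choose_kontext_resolution_from_wh, and collate_fn are importable, that args carries the fields read here plus whatever make_train_dataset_inpaint_mask expects, and that the tokenizer checkpoints shown are placeholders (the training scripts load their own elsewhere):

# Illustrative sketch only -- paths, tokenizer checkpoints, and the args fields are assumptions.
import torch
from types import SimpleNamespace
from transformers import CLIPTokenizer, T5TokenizerFast

tokenizer_clip = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")   # placeholder
tokenizer_t5 = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")               # placeholder

args = SimpleNamespace(
    local_edits_json="data/local_edits.json",        # hypothetical data layout
    source_frames_dir="data/source",
    target_frames_dir="data/target",
    masks_dir="data/masks",
    train_data_dir="data/inpaint_jsonl",             # consumed by make_train_dataset_inpaint_mask
    cond_size=512,
    mix_ratio=1.0,                                   # 1:1 local-edit / inpaint mix
)

# Snap an arbitrary frame size to the closest preferred Kontext resolution.
best_w, best_h = choose_kontext_resolution_from_wh(1280, 720)

dataset = make_train_dataset_mixed(args, [tokenizer_clip, tokenizer_t5])
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)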
train/src/layers.py
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from diffusers.models.attention_processor import Attention
|
| 10 |
+
|
| 11 |
+
class LoRALinearLayer(nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
in_features: int,
|
| 15 |
+
out_features: int,
|
| 16 |
+
rank: int = 4,
|
| 17 |
+
network_alpha: Optional[float] = None,
|
| 18 |
+
device: Optional[Union[torch.device, str]] = None,
|
| 19 |
+
dtype: Optional[torch.dtype] = None,
|
| 20 |
+
cond_width=512,
|
| 21 |
+
cond_height=512,
|
| 22 |
+
number=0,
|
| 23 |
+
n_loras=1
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 27 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 28 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
| 29 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
| 30 |
+
self.network_alpha = network_alpha
|
| 31 |
+
self.rank = rank
|
| 32 |
+
self.out_features = out_features
|
| 33 |
+
self.in_features = in_features
|
| 34 |
+
|
| 35 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 36 |
+
nn.init.zeros_(self.up.weight)
|
| 37 |
+
|
| 38 |
+
self.cond_height = cond_height
|
| 39 |
+
self.cond_width = cond_width
|
| 40 |
+
self.number = number
|
| 41 |
+
self.n_loras = n_loras
|
| 42 |
+
|
| 43 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 44 |
+
orig_dtype = hidden_states.dtype
|
| 45 |
+
dtype = self.down.weight.dtype
|
| 46 |
+
|
| 47 |
+
#### img condition
|
| 48 |
+
batch_size = hidden_states.shape[0]
|
| 49 |
+
cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
|
| 50 |
+
block_size = hidden_states.shape[1] - cond_size * self.n_loras
|
| 51 |
+
shape = (batch_size, hidden_states.shape[1], 3072)
|
| 52 |
+
mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
|
| 53 |
+
mask[:, :block_size+self.number*cond_size, :] = 0
|
| 54 |
+
mask[:, block_size+(self.number+1)*cond_size:, :] = 0
|
| 55 |
+
hidden_states = mask * hidden_states
|
| 56 |
+
####
|
| 57 |
+
|
| 58 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 59 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 60 |
+
|
| 61 |
+
if self.network_alpha is not None:
|
| 62 |
+
up_hidden_states *= self.network_alpha / self.rank
|
| 63 |
+
|
| 64 |
+
return up_hidden_states.to(orig_dtype)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MultiSingleStreamBlockLoraProcessor(nn.Module):
|
| 68 |
+
def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
|
| 69 |
+
super().__init__()
|
| 70 |
+
# Initialize a list to store the LoRA layers
|
| 71 |
+
self.n_loras = n_loras
|
| 72 |
+
self.cond_width = cond_width
|
| 73 |
+
self.cond_height = cond_height
|
| 74 |
+
|
| 75 |
+
self.q_loras = nn.ModuleList([
|
| 76 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 77 |
+
for i in range(n_loras)
|
| 78 |
+
])
|
| 79 |
+
self.k_loras = nn.ModuleList([
|
| 80 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 81 |
+
for i in range(n_loras)
|
| 82 |
+
])
|
| 83 |
+
self.v_loras = nn.ModuleList([
|
| 84 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 85 |
+
for i in range(n_loras)
|
| 86 |
+
])
|
| 87 |
+
self.lora_weights = lora_weights
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def __call__(self,
|
| 91 |
+
attn: Attention,
|
| 92 |
+
hidden_states: torch.FloatTensor,
|
| 93 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 94 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 95 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 96 |
+
use_cond = False,
|
| 97 |
+
) -> torch.FloatTensor:
|
| 98 |
+
|
| 99 |
+
batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 100 |
+
query = attn.to_q(hidden_states)
|
| 101 |
+
key = attn.to_k(hidden_states)
|
| 102 |
+
value = attn.to_v(hidden_states)
|
| 103 |
+
|
| 104 |
+
for i in range(self.n_loras):
|
| 105 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 106 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 107 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 108 |
+
|
| 109 |
+
inner_dim = key.shape[-1]
|
| 110 |
+
head_dim = inner_dim // attn.heads
|
| 111 |
+
|
| 112 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 113 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 114 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 115 |
+
|
| 116 |
+
if attn.norm_q is not None:
|
| 117 |
+
query = attn.norm_q(query)
|
| 118 |
+
if attn.norm_k is not None:
|
| 119 |
+
key = attn.norm_k(key)
|
| 120 |
+
|
| 121 |
+
if image_rotary_emb is not None:
|
| 122 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 123 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 124 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 125 |
+
|
| 126 |
+
cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
|
| 127 |
+
block_size = hidden_states.shape[1] - cond_size * self.n_loras
|
| 128 |
+
scaled_cond_size = cond_size
|
| 129 |
+
scaled_block_size = block_size
|
| 130 |
+
scaled_seq_len = query.shape[2]
|
| 131 |
+
|
| 132 |
+
num_cond_blocks = self.n_loras
|
| 133 |
+
# mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
|
| 134 |
+
# mask[ :scaled_block_size, :] = 0 # First block_size row
|
| 135 |
+
# for i in range(num_cond_blocks):
|
| 136 |
+
# start = i * scaled_cond_size + scaled_block_size
|
| 137 |
+
# end = (i + 1) * scaled_cond_size + scaled_block_size
|
| 138 |
+
# mask[start:end, start:end] = 0 # Diagonal blocks
|
| 139 |
+
# mask = mask * -1e20
|
| 140 |
+
# mask = mask.to(query.dtype)
|
| 141 |
+
|
| 142 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
|
| 143 |
+
|
| 144 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 145 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 146 |
+
|
| 147 |
+
cond_hidden_states = hidden_states[:, block_size:,:]
|
| 148 |
+
hidden_states = hidden_states[:, : block_size,:]
|
| 149 |
+
|
| 150 |
+
return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
|
| 154 |
+
def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
|
| 155 |
+
super().__init__()
|
| 156 |
+
|
| 157 |
+
# Initialize a list to store the LoRA layers
|
| 158 |
+
self.n_loras = n_loras
|
| 159 |
+
self.cond_width = cond_width
|
| 160 |
+
self.cond_height = cond_height
|
| 161 |
+
self.q_loras = nn.ModuleList([
|
| 162 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 163 |
+
for i in range(n_loras)
|
| 164 |
+
])
|
| 165 |
+
self.k_loras = nn.ModuleList([
|
| 166 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 167 |
+
for i in range(n_loras)
|
| 168 |
+
])
|
| 169 |
+
self.v_loras = nn.ModuleList([
|
| 170 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 171 |
+
for i in range(n_loras)
|
| 172 |
+
])
|
| 173 |
+
self.proj_loras = nn.ModuleList([
|
| 174 |
+
LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
|
| 175 |
+
for i in range(n_loras)
|
| 176 |
+
])
|
| 177 |
+
self.lora_weights = lora_weights
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def __call__(self,
|
| 181 |
+
attn: Attention,
|
| 182 |
+
hidden_states: torch.FloatTensor,
|
| 183 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 184 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 185 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 186 |
+
use_cond=False,
|
| 187 |
+
) -> torch.FloatTensor:
|
| 188 |
+
|
| 189 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 190 |
+
|
| 191 |
+
# `context` projections.
|
| 192 |
+
inner_dim = 3072
|
| 193 |
+
head_dim = inner_dim // attn.heads
|
| 194 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 195 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 196 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 197 |
+
|
| 198 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
| 199 |
+
batch_size, -1, attn.heads, head_dim
|
| 200 |
+
).transpose(1, 2)
|
| 201 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
| 202 |
+
batch_size, -1, attn.heads, head_dim
|
| 203 |
+
).transpose(1, 2)
|
| 204 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
| 205 |
+
batch_size, -1, attn.heads, head_dim
|
| 206 |
+
).transpose(1, 2)
|
| 207 |
+
|
| 208 |
+
if attn.norm_added_q is not None:
|
| 209 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 210 |
+
if attn.norm_added_k is not None:
|
| 211 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 212 |
+
|
| 213 |
+
query = attn.to_q(hidden_states)
|
| 214 |
+
key = attn.to_k(hidden_states)
|
| 215 |
+
value = attn.to_v(hidden_states)
|
| 216 |
+
for i in range(self.n_loras):
|
| 217 |
+
query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
|
| 218 |
+
key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
|
| 219 |
+
value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
|
| 220 |
+
|
| 221 |
+
inner_dim = key.shape[-1]
|
| 222 |
+
head_dim = inner_dim // attn.heads
|
| 223 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 224 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 225 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 226 |
+
|
| 227 |
+
if attn.norm_q is not None:
|
| 228 |
+
query = attn.norm_q(query)
|
| 229 |
+
if attn.norm_k is not None:
|
| 230 |
+
key = attn.norm_k(key)
|
| 231 |
+
|
| 232 |
+
# attention
|
| 233 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
| 234 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
| 235 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
| 236 |
+
|
| 237 |
+
if image_rotary_emb is not None:
|
| 238 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 239 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
| 240 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
| 241 |
+
|
| 242 |
+
cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
|
| 243 |
+
block_size = hidden_states.shape[1] - cond_size * self.n_loras
|
| 244 |
+
scaled_cond_size = cond_size
|
| 245 |
+
scaled_seq_len = query.shape[2]
|
| 246 |
+
scaled_block_size = scaled_seq_len - cond_size * self.n_loras
|
| 247 |
+
|
| 248 |
+
num_cond_blocks = self.n_loras
|
| 249 |
+
# mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
|
| 250 |
+
# mask[ :scaled_block_size, :] = 0 # First block_size row
|
| 251 |
+
# for i in range(num_cond_blocks):
|
| 252 |
+
# start = i * scaled_cond_size + scaled_block_size
|
| 253 |
+
# end = (i + 1) * scaled_cond_size + scaled_block_size
|
| 254 |
+
# mask[start:end, start:end] = 0 # Diagonal blocks
|
| 255 |
+
# mask = mask * -1e20
|
| 256 |
+
# mask = mask.to(query.dtype)
|
| 257 |
+
|
| 258 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
|
| 259 |
+
|
| 260 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 261 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 262 |
+
|
| 263 |
+
encoder_hidden_states, hidden_states = (
|
| 264 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
| 265 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Linear projection (with LoRA weight applied to each proj layer)
|
| 269 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 270 |
+
for i in range(self.n_loras):
|
| 271 |
+
hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
|
| 272 |
+
# dropout
|
| 273 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 274 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 275 |
+
|
| 276 |
+
cond_hidden_states = hidden_states[:, block_size:,:]
|
| 277 |
+
hidden_states = hidden_states[:, :block_size,:]
|
| 278 |
+
|
| 279 |
+
return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
|
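As a shape sanity check for the condition-token bookkeeping above, here is a small standalone sketch (illustrative values only; the 3072 hidden width and the 512x512 condition match the constants used in these processors):

# Illustrative only: run one LoRALinearLayer on a dummy token sequence.
import torch

cond_width = cond_height = 512
# Same count the processors compute: a 512x512 condition -> 64x64 latent cells,
# packed 2x2 into 32x32 = 1024 tokens.
cond_tokens = cond_width // 8 * cond_height // 8 * 16 // 64    # 1024
block_tokens = 4096                                            # hypothetical image+text block length

lora = LoRALinearLayer(3072, 3072, rank=4, network_alpha=4,
                       cond_width=cond_width, cond_height=cond_height,
                       number=0, n_loras=1)

hidden_states = torch.randn(1, block_tokens + cond_tokens, 3072)
delta = lora(hidden_states)
# The internal mask zeroes the input outside this layer's condition slice, so only
# those 1024 tokens feed the down/up projection; with the zero-initialised `up`
# weight the delta starts out all zeros.
print(delta.shape)   # torch.Size([1, 5120, 3072])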
train/src/lora_helper.py
ADDED
|
@@ -0,0 +1,196 @@
|
| 1 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 2 |
+
from safetensors import safe_open
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
|
| 6 |
+
|
| 7 |
+
device = "cuda"
|
| 8 |
+
|
| 9 |
+
def load_safetensors(path):
|
| 10 |
+
tensors = {}
|
| 11 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
| 12 |
+
for key in f.keys():
|
| 13 |
+
tensors[key] = f.get_tensor(key)
|
| 14 |
+
return tensors
|
| 15 |
+
|
| 16 |
+
def get_lora_rank(checkpoint):
|
| 17 |
+
for k in checkpoint.keys():
|
| 18 |
+
if k.endswith(".down.weight"):
|
| 19 |
+
return checkpoint[k].shape[0]
|
| 20 |
+
|
| 21 |
+
def load_checkpoint(local_path):
|
| 22 |
+
if local_path is not None:
|
| 23 |
+
if '.safetensors' in local_path:
|
| 24 |
+
print(f"Loading .safetensors checkpoint from {local_path}")
|
| 25 |
+
checkpoint = load_safetensors(local_path)
|
| 26 |
+
else:
|
| 27 |
+
print(f"Loading checkpoint from {local_path}")
|
| 28 |
+
checkpoint = torch.load(local_path, map_location='cpu')
|
| 29 |
+
return checkpoint
|
| 30 |
+
|
| 31 |
+
def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
|
| 32 |
+
number = len(lora_weights)
|
| 33 |
+
ranks = [get_lora_rank(checkpoint) for _ in range(number)]
|
| 34 |
+
lora_attn_procs = {}
|
| 35 |
+
double_blocks_idx = list(range(19))
|
| 36 |
+
single_blocks_idx = list(range(38))
|
| 37 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 38 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 39 |
+
if match:
|
| 40 |
+
layer_index = int(match.group(1))
|
| 41 |
+
|
| 42 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 43 |
+
|
| 44 |
+
lora_state_dicts = {}
|
| 45 |
+
for key, value in checkpoint.items():
|
| 46 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 47 |
+
if re.search(r'\.(\d+)\.', key):
|
| 48 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 49 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 50 |
+
lora_state_dicts[key] = value
|
| 51 |
+
|
| 52 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 53 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 57 |
+
for n in range(number):
|
| 58 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 59 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 60 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 61 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 62 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 63 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 64 |
+
lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 65 |
+
lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 66 |
+
lora_attn_procs[name].to(device)
|
| 67 |
+
|
| 68 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 69 |
+
|
| 70 |
+
lora_state_dicts = {}
|
| 71 |
+
for key, value in checkpoint.items():
|
| 72 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 73 |
+
if re.search(r'\.(\d+)\.', key):
|
| 74 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 75 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 76 |
+
lora_state_dicts[key] = value
|
| 77 |
+
|
| 78 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 79 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
|
| 80 |
+
)
|
| 81 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 82 |
+
for n in range(number):
|
| 83 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 84 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 85 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 86 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 87 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 88 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 89 |
+
lora_attn_procs[name].to(device)
|
| 90 |
+
else:
|
| 91 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 92 |
+
|
| 93 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
|
| 97 |
+
ck_number = len(checkpoints)
|
| 98 |
+
cond_lora_number = [len(ls) for ls in lora_weights]
|
| 99 |
+
cond_number = sum(cond_lora_number)
|
| 100 |
+
ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
|
| 101 |
+
multi_lora_weight = []
|
| 102 |
+
for ls in lora_weights:
|
| 103 |
+
for n in ls:
|
| 104 |
+
multi_lora_weight.append(n)
|
| 105 |
+
|
| 106 |
+
lora_attn_procs = {}
|
| 107 |
+
double_blocks_idx = list(range(19))
|
| 108 |
+
single_blocks_idx = list(range(38))
|
| 109 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 110 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 111 |
+
if match:
|
| 112 |
+
layer_index = int(match.group(1))
|
| 113 |
+
|
| 114 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 115 |
+
lora_state_dicts = [{} for _ in range(ck_number)]
|
| 116 |
+
for idx, checkpoint in enumerate(checkpoints):
|
| 117 |
+
for key, value in checkpoint.items():
|
| 118 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 119 |
+
if re.search(r'\.(\d+)\.', key):
|
| 120 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 121 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 122 |
+
lora_state_dicts[idx][key] = value
|
| 123 |
+
|
| 124 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 125 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 129 |
+
num = 0
|
| 130 |
+
for idx in range(ck_number):
|
| 131 |
+
for n in range(cond_lora_number[idx]):
|
| 132 |
+
lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
|
| 133 |
+
lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
|
| 134 |
+
lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
|
| 135 |
+
lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
|
| 136 |
+
lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
|
| 137 |
+
lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
|
| 138 |
+
lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 139 |
+
lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 140 |
+
lora_attn_procs[name].to(device)
|
| 141 |
+
num += 1
|
| 142 |
+
|
| 143 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 144 |
+
|
| 145 |
+
lora_state_dicts = [{} for _ in range(ck_number)]
|
| 146 |
+
for idx, checkpoint in enumerate(checkpoints):
|
| 147 |
+
for key, value in checkpoint.items():
|
| 148 |
+
# Match based on the layer index in the key (assuming the key contains layer index)
|
| 149 |
+
if re.search(r'\.(\d+)\.', key):
|
| 150 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 151 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 152 |
+
lora_state_dicts[idx][key] = value
|
| 153 |
+
|
| 154 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 155 |
+
dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
|
| 156 |
+
)
|
| 157 |
+
# Load the weights from the checkpoint dictionary into the corresponding layers
|
| 158 |
+
num = 0
|
| 159 |
+
for idx in range(ck_number):
|
| 160 |
+
for n in range(cond_lora_number[idx]):
|
| 161 |
+
lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
|
| 162 |
+
lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
|
| 163 |
+
lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
|
| 164 |
+
lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
|
| 165 |
+
lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
|
| 166 |
+
lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
|
| 167 |
+
lora_attn_procs[name].to(device)
|
| 168 |
+
num += 1
|
| 169 |
+
|
| 170 |
+
else:
|
| 171 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 172 |
+
|
| 173 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
|
| 177 |
+
checkpoint = load_checkpoint(local_path)
|
| 178 |
+
update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
|
| 179 |
+
|
| 180 |
+
def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
|
| 181 |
+
checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
|
| 182 |
+
update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
|
| 183 |
+
|
| 184 |
+
def unset_lora(transformer):
|
| 185 |
+
lora_attn_procs = {}
|
| 186 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 187 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 188 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
'''
|
| 192 |
+
unset_lora(pipe.transformer)
|
| 193 |
+
lora_path = "./lora.safetensors"
|
| 194 |
+
lora_weights = [1, 1]
|
| 195 |
+
set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
|
| 196 |
+
'''
|
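A hedged usage sketch for the helpers above. Here `pipe` stands for an already-built Flux Kontext pipeline whose `.transformer` exposes `attn_processors` (its construction is outside this file), and the checkpoint paths are placeholders:

# Illustrative only -- checkpoint paths are hypothetical, `pipe` is assumed to exist.
edge_lora = "./checkpoints/edge_lora.safetensors"
color_lora = "./checkpoints/color_lora.safetensors"

# Single condition LoRA, weight 1.0, condition images at 512x512.
set_single_lora(pipe.transformer, local_path=edge_lora, lora_weights=[1.0], cond_size=512)

# Or stack two checkpoints, each driving its own condition stream.
set_multi_lora(pipe.transformer,
               local_paths=[edge_lora, color_lora],
               lora_weights=[[1.0], [0.8]],
               cond_size=512)

# Restore the stock attention processors.
unset_lora(pipe.transformer)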
train/src/masks_integrated.py
ADDED
|
@@ -0,0 +1,322 @@
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import logging
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
LOGGER = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class LinearRamp:
|
| 13 |
+
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
|
| 14 |
+
self.start_value = start_value
|
| 15 |
+
self.end_value = end_value
|
| 16 |
+
self.start_iter = start_iter
|
| 17 |
+
self.end_iter = end_iter
|
| 18 |
+
|
| 19 |
+
def __call__(self, i):
|
| 20 |
+
if i < self.start_iter:
|
| 21 |
+
return self.start_value
|
| 22 |
+
if i >= self.end_iter:
|
| 23 |
+
return self.end_value
|
| 24 |
+
part = (i - self.start_iter) / (self.end_iter - self.start_iter)
|
| 25 |
+
return self.start_value * (1 - part) + self.end_value * part
|
| 26 |
+
|
| 27 |
+
class DrawMethod(Enum):
|
| 28 |
+
LINE = 'line'
|
| 29 |
+
CIRCLE = 'circle'
|
| 30 |
+
SQUARE = 'square'
|
| 31 |
+
|
| 32 |
+
def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
|
| 33 |
+
draw_method=DrawMethod.LINE):
|
| 34 |
+
"""生成不规则mask - 基于角度和长度的线条"""
|
| 35 |
+
draw_method = DrawMethod(draw_method)
|
| 36 |
+
|
| 37 |
+
height, width = shape
|
| 38 |
+
mask = np.zeros((height, width), np.float32)
|
| 39 |
+
times = np.random.randint(min_times, max_times + 1)
|
| 40 |
+
for i in range(times):
|
| 41 |
+
start_x = np.random.randint(width)
|
| 42 |
+
start_y = np.random.randint(height)
|
| 43 |
+
for j in range(1 + np.random.randint(5)):
|
| 44 |
+
angle = 0.01 + np.random.randint(max_angle)
|
| 45 |
+
if i % 2 == 0:
|
| 46 |
+
angle = 2 * 3.1415926 - angle
|
| 47 |
+
length = 10 + np.random.randint(max_len)
|
| 48 |
+
brush_w = 5 + np.random.randint(max_width)
|
| 49 |
+
end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
|
| 50 |
+
end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
|
| 51 |
+
if draw_method == DrawMethod.LINE:
|
| 52 |
+
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
|
| 53 |
+
elif draw_method == DrawMethod.CIRCLE:
|
| 54 |
+
cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
|
| 55 |
+
elif draw_method == DrawMethod.SQUARE:
|
| 56 |
+
radius = brush_w // 2
|
| 57 |
+
mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
|
| 58 |
+
start_x, start_y = end_x, end_y
|
| 59 |
+
return mask[None, ...]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
|
| 63 |
+
"""生成随机矩形mask"""
|
| 64 |
+
height, width = shape
|
| 65 |
+
mask = np.zeros((height, width), np.float32)
|
| 66 |
+
bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
|
| 67 |
+
times = np.random.randint(min_times, max_times + 1)
|
| 68 |
+
for i in range(times):
|
| 69 |
+
box_width = np.random.randint(bbox_min_size, bbox_max_size)
|
| 70 |
+
box_height = np.random.randint(bbox_min_size, bbox_max_size)
|
| 71 |
+
start_x = np.random.randint(margin, width - margin - box_width + 1)
|
| 72 |
+
start_y = np.random.randint(margin, height - margin - box_height + 1)
|
| 73 |
+
mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
|
| 74 |
+
return mask[None, ...]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
|
| 78 |
+
"""生成超分辨率风格的规则网格mask"""
|
| 79 |
+
height, width = shape
|
| 80 |
+
mask = np.zeros((height, width), np.float32)
|
| 81 |
+
step_x = np.random.randint(min_step, max_step + 1)
|
| 82 |
+
width_x = np.random.randint(min_width, min(step_x, max_width + 1))
|
| 83 |
+
offset_x = np.random.randint(0, step_x)
|
| 84 |
+
|
| 85 |
+
step_y = np.random.randint(min_step, max_step + 1)
|
| 86 |
+
width_y = np.random.randint(min_width, min(step_y, max_width + 1))
|
| 87 |
+
offset_y = np.random.randint(0, step_y)
|
| 88 |
+
|
| 89 |
+
for dy in range(width_y):
|
| 90 |
+
mask[offset_y + dy::step_y] = 1
|
| 91 |
+
for dx in range(width_x):
|
| 92 |
+
mask[:, offset_x + dx::step_x] = 1
|
| 93 |
+
return mask[None, ...]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def make_brush_stroke_mask(shape, num_strokes_range=(1, 5), stroke_width_range=(5, 30),
|
| 97 |
+
max_offset=50, num_points_range=(5, 15)):
|
| 98 |
+
"""生成笔刷描边样式的mask - 基于随机偏移的连续线条"""
|
| 99 |
+
num_strokes = random.randint(*num_strokes_range)
|
| 100 |
+
height, width = shape
|
| 101 |
+
mask = np.zeros((height, width), dtype=np.float32)
|
| 102 |
+
|
| 103 |
+
for _ in range(num_strokes):
|
| 104 |
+
# Random starting point
|
| 105 |
+
start_x = random.randint(0, width)
|
| 106 |
+
start_y = random.randint(0, height)
|
| 107 |
+
|
| 108 |
+
# Random stroke parameters
|
| 109 |
+
num_points = random.randint(*num_points_range)
|
| 110 |
+
stroke_width = random.randint(*stroke_width_range)
|
| 111 |
+
|
| 112 |
+
points = [(start_x, start_y)]
|
| 113 |
+
|
| 114 |
+
# Generate a chain of connected points
|
| 115 |
+
for i in range(num_points):
|
| 116 |
+
prev_x, prev_y = points[-1]
|
| 117 |
+
# Apply a random offset
|
| 118 |
+
dx = random.randint(-max_offset, max_offset)
|
| 119 |
+
dy = random.randint(-max_offset, max_offset)
|
| 120 |
+
new_x = max(0, min(width, prev_x + dx))
|
| 121 |
+
new_y = max(0, min(height, prev_y + dy))
|
| 122 |
+
points.append((new_x, new_y))
|
| 123 |
+
|
| 124 |
+
# Draw the stroke
|
| 125 |
+
for i in range(len(points) - 1):
|
| 126 |
+
cv2.line(mask, points[i], points[i+1], 1.0, stroke_width)
|
| 127 |
+
|
| 128 |
+
return mask[None, ...]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class RandomIrregularMaskGenerator:
|
| 132 |
+
"""不规则mask生成器"""
|
| 133 |
+
def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
|
| 134 |
+
draw_method=DrawMethod.LINE):
|
| 135 |
+
self.max_angle = max_angle
|
| 136 |
+
self.max_len = max_len
|
| 137 |
+
self.max_width = max_width
|
| 138 |
+
self.min_times = min_times
|
| 139 |
+
self.max_times = max_times
|
| 140 |
+
self.draw_method = draw_method
|
| 141 |
+
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
|
| 142 |
+
|
| 143 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
| 144 |
+
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
|
| 145 |
+
cur_max_len = int(max(1, self.max_len * coef))
|
| 146 |
+
cur_max_width = int(max(1, self.max_width * coef))
|
| 147 |
+
cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
|
| 148 |
+
return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
|
| 149 |
+
max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
|
| 150 |
+
draw_method=self.draw_method)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class RandomRectangleMaskGenerator:
|
| 154 |
+
"""矩形mask生成器"""
|
| 155 |
+
def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
|
| 156 |
+
self.margin = margin
|
| 157 |
+
self.bbox_min_size = bbox_min_size
|
| 158 |
+
self.bbox_max_size = bbox_max_size
|
| 159 |
+
self.min_times = min_times
|
| 160 |
+
self.max_times = max_times
|
| 161 |
+
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
|
| 162 |
+
|
| 163 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
| 164 |
+
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
|
| 165 |
+
cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
|
| 166 |
+
cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
|
| 167 |
+
return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
|
| 168 |
+
bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
|
| 169 |
+
max_times=cur_max_times)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class RandomSuperresMaskGenerator:
|
| 173 |
+
"""超分辨率mask生成器"""
|
| 174 |
+
def __init__(self, **kwargs):
|
| 175 |
+
self.kwargs = kwargs
|
| 176 |
+
|
| 177 |
+
def __call__(self, img, iter_i=None):
|
| 178 |
+
return make_random_superres_mask(img.shape[1:], **self.kwargs)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class BrushStrokeMaskGenerator:
|
| 182 |
+
"""笔刷描边mask生成器"""
|
| 183 |
+
def __init__(self, num_strokes_range=(1, 5), stroke_width_range=(5, 30),
|
| 184 |
+
max_offset=50, num_points_range=(5, 15), ramp_kwargs=None):
|
| 185 |
+
self.num_strokes_range = num_strokes_range
|
| 186 |
+
self.stroke_width_range = stroke_width_range
|
| 187 |
+
self.max_offset = max_offset
|
| 188 |
+
self.num_points_range = num_points_range
|
| 189 |
+
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
|
| 190 |
+
|
| 191 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
| 192 |
+
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
|
| 193 |
+
cur_num_strokes = int(max(1, self.num_strokes_range[1] * coef))
|
| 194 |
+
cur_stroke_width_range = (
|
| 195 |
+
int(max(1, self.stroke_width_range[0] * coef)),
|
| 196 |
+
int(max(1, self.stroke_width_range[1] * coef))
|
| 197 |
+
)
|
| 198 |
+
return make_brush_stroke_mask(
|
| 199 |
+
img.shape[1:],
|
| 200 |
+
num_strokes_range=(cur_num_strokes, cur_num_strokes),
|
| 201 |
+
stroke_width_range=cur_stroke_width_range,
|
| 202 |
+
max_offset=self.max_offset,
|
| 203 |
+
num_points_range=self.num_points_range
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class DumbAreaMaskGenerator:
|
| 208 |
+
"""简单区域mask生成器"""
|
| 209 |
+
min_ratio = 0.1
|
| 210 |
+
max_ratio = 0.35
|
| 211 |
+
default_ratio = 0.225
|
| 212 |
+
|
| 213 |
+
def __init__(self, is_training):
|
| 214 |
+
#Parameters:
|
| 215 |
+
# is_training(bool): If true - random rectangular mask, if false - central square mask
|
| 216 |
+
self.is_training = is_training
|
| 217 |
+
|
| 218 |
+
def _random_vector(self, dimension):
|
| 219 |
+
if self.is_training:
|
| 220 |
+
lower_limit = math.sqrt(self.min_ratio)
|
| 221 |
+
upper_limit = math.sqrt(self.max_ratio)
|
| 222 |
+
mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
|
| 223 |
+
u = random.randint(0, dimension-mask_side-1)
|
| 224 |
+
v = u+mask_side
|
| 225 |
+
else:
|
| 226 |
+
margin = (math.sqrt(self.default_ratio) / 2) * dimension
|
| 227 |
+
u = round(dimension/2 - margin)
|
| 228 |
+
v = round(dimension/2 + margin)
|
| 229 |
+
return u, v
|
| 230 |
+
|
| 231 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
| 232 |
+
c, height, width = img.shape
|
| 233 |
+
mask = np.zeros((height, width), np.float32)
|
| 234 |
+
x1, x2 = self._random_vector(width)
|
| 235 |
+
y1, y2 = self._random_vector(height)
|
| 236 |
+
mask[x1:x2, y1:y2] = 1
|
| 237 |
+
return mask[None, ...]
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class IntegratedMaskGenerator:
|
| 241 |
+
"""集成的mask生成器 - 支持多种mask类型混合"""
|
| 242 |
+
def __init__(self, irregular_proba=1/4, irregular_kwargs=None,
|
| 243 |
+
box_proba=1/4, box_kwargs=None,
|
| 244 |
+
segm_proba=1/4, segm_kwargs=None,
|
| 245 |
+
brush_stroke_proba=1/4, brush_stroke_kwargs=None,
|
| 246 |
+
superres_proba=0, superres_kwargs=None,
|
| 247 |
+
squares_proba=0, squares_kwargs=None,
|
| 248 |
+
invert_proba=0):
|
| 249 |
+
self.probas = []
|
| 250 |
+
self.gens = []
|
| 251 |
+
|
| 252 |
+
if irregular_proba > 0:
|
| 253 |
+
self.probas.append(irregular_proba)
|
| 254 |
+
if irregular_kwargs is None:
|
| 255 |
+
irregular_kwargs = {}
|
| 256 |
+
else:
|
| 257 |
+
irregular_kwargs = dict(irregular_kwargs)
|
| 258 |
+
irregular_kwargs['draw_method'] = DrawMethod.LINE
|
| 259 |
+
self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
|
| 260 |
+
|
| 261 |
+
if box_proba > 0:
|
| 262 |
+
self.probas.append(box_proba)
|
| 263 |
+
if box_kwargs is None:
|
| 264 |
+
box_kwargs = {}
|
| 265 |
+
self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
|
| 266 |
+
|
| 267 |
+
if brush_stroke_proba > 0:
|
| 268 |
+
self.probas.append(brush_stroke_proba)
|
| 269 |
+
if brush_stroke_kwargs is None:
|
| 270 |
+
brush_stroke_kwargs = {}
|
| 271 |
+
self.gens.append(BrushStrokeMaskGenerator(**brush_stroke_kwargs))
|
| 272 |
+
|
| 273 |
+
if superres_proba > 0:
|
| 274 |
+
self.probas.append(superres_proba)
|
| 275 |
+
if superres_kwargs is None:
|
| 276 |
+
superres_kwargs = {}
|
| 277 |
+
self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
|
| 278 |
+
|
| 279 |
+
if squares_proba > 0:
|
| 280 |
+
self.probas.append(squares_proba)
|
| 281 |
+
if squares_kwargs is None:
|
| 282 |
+
squares_kwargs = {}
|
| 283 |
+
else:
|
| 284 |
+
squares_kwargs = dict(squares_kwargs)
|
| 285 |
+
squares_kwargs['draw_method'] = DrawMethod.SQUARE
|
| 286 |
+
self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
|
| 287 |
+
|
| 288 |
+
self.probas = np.array(self.probas, dtype='float32')
|
| 289 |
+
self.probas /= self.probas.sum()
|
| 290 |
+
self.invert_proba = invert_proba
|
| 291 |
+
|
| 292 |
+
def __call__(self, img, iter_i=None, raw_image=None):
|
| 293 |
+
kind = np.random.choice(len(self.probas), p=self.probas)
|
| 294 |
+
gen = self.gens[kind]
|
| 295 |
+
result = gen(img, iter_i=iter_i, raw_image=raw_image)
|
| 296 |
+
if self.invert_proba > 0 and random.random() < self.invert_proba:
|
| 297 |
+
result = 1 - result
|
| 298 |
+
return result
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def get_mask_generator(kind, kwargs):
|
| 302 |
+
"""获取mask生成器的工厂函数"""
|
| 303 |
+
if kind is None:
|
| 304 |
+
kind = "integrated"
|
| 305 |
+
if kwargs is None:
|
| 306 |
+
kwargs = {}
|
| 307 |
+
|
| 308 |
+
if kind == "integrated":
|
| 309 |
+
cl = IntegratedMaskGenerator
|
| 310 |
+
elif kind == "irregular":
|
| 311 |
+
cl = RandomIrregularMaskGenerator
|
| 312 |
+
elif kind == "rectangle":
|
| 313 |
+
cl = RandomRectangleMaskGenerator
|
| 314 |
+
elif kind == "brush_stroke":
|
| 315 |
+
cl = BrushStrokeMaskGenerator
|
| 316 |
+
elif kind == "superres":
|
| 317 |
+
cl = RandomSuperresMaskGenerator
|
| 318 |
+
elif kind == "dumb":
|
| 319 |
+
cl = DumbAreaMaskGenerator
|
| 320 |
+
else:
|
| 321 |
+
raise NotImplementedError(f"No such generator kind = {kind}")
|
| 322 |
+
return cl(**kwargs)
|
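For readers wiring this file into a training script, here is a minimal usage sketch of the factory above. It is only an illustration, not part of the commit: the import path `train.src.masks_integrated` is an assumption about how the training package is laid out, the generator classes it mixes are the ones defined earlier in this file, and the dummy image is used only for its shape.

# Minimal sketch (assumed import path; adjust to your checkout layout).
import numpy as np
from train.src.masks_integrated import get_mask_generator

# Equal mix of irregular, box, and brush-stroke masks, with a 10% chance of inversion.
mask_gen = get_mask_generator(
    kind="integrated",
    kwargs={
        "irregular_proba": 1 / 3,
        "box_proba": 1 / 3,
        "brush_stroke_proba": 1 / 3,
        "invert_proba": 0.1,
    },
)

dummy_img = np.zeros((3, 512, 512), dtype=np.float32)  # (C, H, W); only the shape matters here
mask = mask_gen(dummy_img)                              # (1, 512, 512) float mask in [0, 1]
print(mask.shape, mask.mean())

Each call draws one generator according to the normalized probabilities, so a training loop can simply call `mask_gen(img)` per sample to get a varied mix of occlusion patterns.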
train/src/pipeline_flux_kontext_control.py
ADDED
@@ -0,0 +1,1009 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from .transformer_flux import FluxTransformer2DModel
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from torchvision.transforms.functional import pad


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def prepare_latent_image_ids_(height, width, device, dtype):
    latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]  # y
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]  # x
    return latent_image_ids


def prepare_latent_subject_ids(height, width, device, dtype):
    latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    return latent_image_ids.to(device=device, dtype=dtype)


def resize_position_encoding(
    batch_size, original_height, original_width, target_height, target_width, device, dtype
):
    latent_image_ids = prepare_latent_image_ids_(original_height // 2, original_width // 2, device, dtype)
    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    scale_h = original_height / target_height
    scale_w = original_width / target_width
    latent_image_ids_resized = torch.zeros(target_height // 2, target_width // 2, 3, device=device, dtype=dtype)
    latent_image_ids_resized[..., 1] = (
        latent_image_ids_resized[..., 1] + torch.arange(target_height // 2, device=device)[:, None] * scale_h
    )
    latent_image_ids_resized[..., 2] = (
        latent_image_ids_resized[..., 2] + torch.arange(target_width // 2, device=device)[None, :] * scale_w
    )

    cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = (
        latent_image_ids_resized.shape
    )
    cond_latent_image_ids = latent_image_ids_resized.reshape(
        cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
    )
    return latent_image_ids, cond_latent_image_ids


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class FluxKontextControlPipeline(
    DiffusionPipeline,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    The Flux Kontext pipeline for image-to-image and text-to-image generation with EasyControl.

    Reference: https://bfl.ai/announcements/flux-1-kontext-dev

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=None,
            feature_extractor=None,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are packed into 2x2 patches, so use VAE factor multiplied by patch size for image processing
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 128
        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
        # EasyControl: cache multiple control LoRA processors
        self.control_lora_processors: Dict[str, Dict[str, Any]] = {}
        self.control_lora_cond_sizes: Dict[str, Any] = {}
        self.current_control_type: Optional[str] = None

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

        return latents

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        return image_latents

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        image,
        subject_images,
        spatial_images,
        latents=None,
        cond_size=512,
    ):
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        height_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
        width_cond = 2 * (cond_size // (self.vae_scale_factor * 2))

        image_latents = image_ids = None
        # Prepare noise latents
        shape = (batch_size, num_channels_latents, height, width)
        if latents is None:
            noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            noise_latents = latents.to(device=device, dtype=dtype)

        noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
        # print(noise_latents.shape)
        noise_latent_image_ids, cond_latent_image_ids_resized = resize_position_encoding(
            batch_size, height, width, height_cond, width_cond, device, dtype
        )
        # noise IDs are marked with 0 in the first channel
        noise_latent_image_ids[..., 0] = 0

        cond_latents_to_concat = []
        latents_ids_to_concat = [noise_latent_image_ids]

        # 1. Prepare `image` (Kontext) latents
        if image is not None:
            image = image.to(device=device, dtype=dtype)
            if image.shape[1] != self.latent_channels:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            else:
                image_latents = image

            image_latent_h, image_latent_w = image_latents.shape[2:]
            image_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, image_latent_h, image_latent_w
            )
            image_ids = self._prepare_latent_image_ids(
                batch_size, image_latent_h // 2, image_latent_w // 2, device, dtype
            )
            image_ids[..., 0] = 1  # Mark as condition
            latents_ids_to_concat.append(image_ids)

        # 2. Prepare `subject_images` latents
        if subject_images is not None and len(subject_images) > 0:
            subject_images = subject_images.to(device=device, dtype=dtype)
            subject_image_latents = self._encode_vae_image(image=subject_images, generator=generator)
            subject_latents = self._pack_latents(
                subject_image_latents, batch_size, num_channels_latents, height_cond * len(subject_images), width_cond
            )

            latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, device, dtype)
            latent_subject_ids[..., 0] = 1
            latent_subject_ids[:, 1] += image_latent_h // 2
            subject_latent_image_ids = torch.cat([latent_subject_ids for _ in range(len(subject_images))], dim=0)

            cond_latents_to_concat.append(subject_latents)
            latents_ids_to_concat.append(subject_latent_image_ids)

        # 3. Prepare `spatial_images` latents
        if spatial_images is not None and len(spatial_images) > 0:
            spatial_images = spatial_images.to(device=device, dtype=dtype)
            spatial_image_latents = self._encode_vae_image(image=spatial_images, generator=generator)
            cond_latents = self._pack_latents(
                spatial_image_latents, batch_size, num_channels_latents, height_cond * len(spatial_images), width_cond
            )
            cond_latent_image_ids_resized[..., 0] = 2
            cond_latent_image_ids = torch.cat(
                [cond_latent_image_ids_resized for _ in range(len(spatial_images))], dim=0
            )

            cond_latents_to_concat.append(cond_latents)
            latents_ids_to_concat.append(cond_latent_image_ids)

        cond_latents = torch.cat(cond_latents_to_concat, dim=1) if cond_latents_to_concat else None
        latent_image_ids = torch.cat(latents_ids_to_concat, dim=0)

        return noise_latents, image_latents, cond_latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[PipelineImageInput] = None,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        cond_size: int = 512,
        control_dict: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly they are not encoded again.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages generating images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512):
                Maximum sequence length to use with the `prompt`.
            cond_size (`int`, *optional*, defaults to 512):
                The size for conditioning images.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        spatial_images = control_dict.get("spatial_images", [])
        subject_images = control_dict.get("subject_images", [])

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 3. Preprocess images
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
            img = image[0] if isinstance(image, list) else image
            image_height, image_width = self.image_processor.get_default_height_width(img)
            aspect_ratio = image_width / image_height
            # Kontext is trained on specific resolutions, using one of them is recommended
            _, image_width, image_height = min(
                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
            )
            multiple_of = self.vae_scale_factor * 2
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            image = self.image_processor.resize(image, image_height, image_width)
            image = self.image_processor.preprocess(image, image_height, image_width)
            image = image.to(dtype=self.vae.dtype)

        if len(subject_images) > 0:
            subject_image_ls = []
            for subject_image in subject_images:
                w, h = subject_image.size[:2]
                scale = cond_size / max(h, w)
                new_h, new_w = int(h * scale), int(w * scale)
                subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
                subject_image = subject_image.to(dtype=self.vae.dtype)
                pad_h = cond_size - subject_image.shape[-2]
                pad_w = cond_size - subject_image.shape[-1]
                subject_image = pad(
                    subject_image, padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)), fill=0
                )
                subject_image_ls.append(subject_image)
            subject_images = torch.cat(subject_image_ls, dim=-2)
        else:
            subject_images = None

        if len(spatial_images) > 0:
            condition_image_ls = []
            for img in spatial_images:
                condition_image = self.image_processor.preprocess(img, height=cond_size, width=cond_size)
                condition_image = condition_image.to(dtype=self.vae.dtype)
                condition_image_ls.append(condition_image)
            spatial_images = torch.cat(condition_image_ls, dim=-2)
        else:
            spatial_images = None

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, image_latents, cond_latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            image,
            subject_images,
            spatial_images,
            latents,
            cond_size,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = latents
                if image_latents is not None:
                    latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

                self._current_timestep = t
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    cond_hidden_states=cond_latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                noise_pred = noise_pred[:, : latents.size(1)]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
train/src/prompt_helper.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def load_text_encoders(args, class_one, class_two):
|
| 5 |
+
text_encoder_one = class_one.from_pretrained(
|
| 6 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 7 |
+
)
|
| 8 |
+
text_encoder_two = class_two.from_pretrained(
|
| 9 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
| 10 |
+
)
|
| 11 |
+
return text_encoder_one, text_encoder_two
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
|
| 15 |
+
text_inputs = tokenizer(
|
| 16 |
+
prompt,
|
| 17 |
+
padding="max_length",
|
| 18 |
+
max_length=max_sequence_length,
|
| 19 |
+
truncation=True,
|
| 20 |
+
return_length=False,
|
| 21 |
+
return_overflowing_tokens=False,
|
| 22 |
+
return_tensors="pt",
|
| 23 |
+
)
|
| 24 |
+
text_input_ids = text_inputs.input_ids
|
| 25 |
+
return text_input_ids
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def tokenize_prompt_clip(tokenizer, prompt):
|
| 29 |
+
text_inputs = tokenizer(
|
| 30 |
+
prompt,
|
| 31 |
+
padding="max_length",
|
| 32 |
+
max_length=77,
|
| 33 |
+
truncation=True,
|
| 34 |
+
return_length=False,
|
| 35 |
+
return_overflowing_tokens=False,
|
| 36 |
+
return_tensors="pt",
|
| 37 |
+
)
|
| 38 |
+
text_input_ids = text_inputs.input_ids
|
| 39 |
+
return text_input_ids
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def tokenize_prompt_t5(tokenizer, prompt):
|
| 43 |
+
text_inputs = tokenizer(
|
| 44 |
+
prompt,
|
| 45 |
+
padding="max_length",
|
| 46 |
+
max_length=512,
|
| 47 |
+
truncation=True,
|
| 48 |
+
return_length=False,
|
| 49 |
+
return_overflowing_tokens=False,
|
| 50 |
+
return_tensors="pt",
|
| 51 |
+
)
|
| 52 |
+
text_input_ids = text_inputs.input_ids
|
| 53 |
+
return text_input_ids
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _encode_prompt_with_t5(
|
| 57 |
+
text_encoder,
|
| 58 |
+
tokenizer,
|
| 59 |
+
max_sequence_length=512,
|
| 60 |
+
prompt=None,
|
| 61 |
+
num_images_per_prompt=1,
|
| 62 |
+
device=None,
|
| 63 |
+
text_input_ids=None,
|
| 64 |
+
):
|
| 65 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 66 |
+
batch_size = len(prompt)
|
| 67 |
+
|
| 68 |
+
if tokenizer is not None:
|
| 69 |
+
text_inputs = tokenizer(
|
| 70 |
+
prompt,
|
| 71 |
+
padding="max_length",
|
| 72 |
+
max_length=max_sequence_length,
|
| 73 |
+
truncation=True,
|
| 74 |
+
return_length=False,
|
| 75 |
+
return_overflowing_tokens=False,
|
| 76 |
+
return_tensors="pt",
|
| 77 |
+
)
|
| 78 |
+
text_input_ids = text_inputs.input_ids
|
| 79 |
+
else:
|
| 80 |
+
if text_input_ids is None:
|
| 81 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 82 |
+
|
| 83 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 84 |
+
|
| 85 |
+
dtype = text_encoder.dtype
|
| 86 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 87 |
+
|
| 88 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 89 |
+
|
| 90 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 91 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 92 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 93 |
+
|
| 94 |
+
return prompt_embeds
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _encode_prompt_with_clip(
|
| 98 |
+
text_encoder,
|
| 99 |
+
tokenizer,
|
| 100 |
+
prompt: str,
|
| 101 |
+
device=None,
|
| 102 |
+
text_input_ids=None,
|
| 103 |
+
num_images_per_prompt: int = 1,
|
| 104 |
+
):
|
| 105 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 106 |
+
batch_size = len(prompt)
|
| 107 |
+
|
| 108 |
+
if tokenizer is not None:
|
| 109 |
+
text_inputs = tokenizer(
|
| 110 |
+
prompt,
|
| 111 |
+
padding="max_length",
|
| 112 |
+
max_length=77,
|
| 113 |
+
truncation=True,
|
| 114 |
+
return_overflowing_tokens=False,
|
| 115 |
+
return_length=False,
|
| 116 |
+
return_tensors="pt",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
text_input_ids = text_inputs.input_ids
|
| 120 |
+
else:
|
| 121 |
+
if text_input_ids is None:
|
| 122 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 123 |
+
|
| 124 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 125 |
+
|
| 126 |
+
# Use pooled output of CLIPTextModel
|
| 127 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 128 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
|
| 129 |
+
|
| 130 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 131 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 132 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 133 |
+
|
| 134 |
+
return prompt_embeds
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def encode_prompt(
|
| 138 |
+
text_encoders,
|
| 139 |
+
tokenizers,
|
| 140 |
+
prompt: str,
|
| 141 |
+
max_sequence_length,
|
| 142 |
+
device=None,
|
| 143 |
+
num_images_per_prompt: int = 1,
|
| 144 |
+
text_input_ids_list=None,
|
| 145 |
+
):
|
| 146 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 147 |
+
dtype = text_encoders[0].dtype
|
| 148 |
+
|
| 149 |
+
pooled_prompt_embeds = _encode_prompt_with_clip(
|
| 150 |
+
text_encoder=text_encoders[0],
|
| 151 |
+
tokenizer=tokenizers[0],
|
| 152 |
+
prompt=prompt,
|
| 153 |
+
device=device if device is not None else text_encoders[0].device,
|
| 154 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 155 |
+
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
prompt_embeds = _encode_prompt_with_t5(
|
| 159 |
+
text_encoder=text_encoders[1],
|
| 160 |
+
tokenizer=tokenizers[1],
|
| 161 |
+
max_sequence_length=max_sequence_length,
|
| 162 |
+
prompt=prompt,
|
| 163 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 164 |
+
device=device if device is not None else text_encoders[1].device,
|
| 165 |
+
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 169 |
+
|
| 170 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
|
| 174 |
+
text_encoder_clip = text_encoders[0]
|
| 175 |
+
text_encoder_t5 = text_encoders[1]
|
| 176 |
+
tokens_clip, tokens_t5 = tokens[0], tokens[1]
|
| 177 |
+
batch_size = tokens_clip.shape[0]
|
| 178 |
+
|
| 179 |
+
if device == "cpu":
|
| 180 |
+
device = "cpu"
|
| 181 |
+
else:
|
| 182 |
+
device = accelerator.device
|
| 183 |
+
|
| 184 |
+
# clip
|
| 185 |
+
prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
|
| 186 |
+
# Use pooled output of CLIPTextModel
|
| 187 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 188 |
+
prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
|
| 189 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 190 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 191 |
+
pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 192 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
|
| 193 |
+
|
| 194 |
+
# t5
|
| 195 |
+
prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
|
| 196 |
+
dtype = text_encoder_t5.dtype
|
| 197 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
|
| 198 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 199 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 200 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 201 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 202 |
+
|
| 203 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
|
| 204 |
+
|
| 205 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
train/src/transformer_flux.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
| 10 |
+
from diffusers.models.attention import FeedForward
|
| 11 |
+
from diffusers.models.attention_processor import (
|
| 12 |
+
Attention,
|
| 13 |
+
AttentionProcessor,
|
| 14 |
+
FluxAttnProcessor2_0,
|
| 15 |
+
FluxAttnProcessor2_0_NPU,
|
| 16 |
+
FusedFluxAttnProcessor2_0,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 20 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
| 21 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
| 22 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 23 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
| 24 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 27 |
+
|
| 28 |
+
@maybe_allow_in_graph
|
| 29 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 30 |
+
|
| 31 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 34 |
+
|
| 35 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 36 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 37 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 38 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 39 |
+
|
| 40 |
+
if is_torch_npu_available():
|
| 41 |
+
processor = FluxAttnProcessor2_0_NPU()
|
| 42 |
+
else:
|
| 43 |
+
processor = FluxAttnProcessor2_0()
|
| 44 |
+
self.attn = Attention(
|
| 45 |
+
query_dim=dim,
|
| 46 |
+
cross_attention_dim=None,
|
| 47 |
+
dim_head=attention_head_dim,
|
| 48 |
+
heads=num_attention_heads,
|
| 49 |
+
out_dim=dim,
|
| 50 |
+
bias=True,
|
| 51 |
+
processor=processor,
|
| 52 |
+
qk_norm="rms_norm",
|
| 53 |
+
eps=1e-6,
|
| 54 |
+
pre_only=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(
|
| 58 |
+
self,
|
| 59 |
+
hidden_states: torch.Tensor,
|
| 60 |
+
cond_hidden_states: torch.Tensor,
|
| 61 |
+
temb: torch.Tensor,
|
| 62 |
+
cond_temb: torch.Tensor,
|
| 63 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 64 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 65 |
+
) -> torch.Tensor:
|
| 66 |
+
use_cond = cond_hidden_states is not None
|
| 67 |
+
|
| 68 |
+
residual = hidden_states
|
| 69 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 70 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 71 |
+
|
| 72 |
+
if use_cond:
|
| 73 |
+
residual_cond = cond_hidden_states
|
| 74 |
+
norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
|
| 75 |
+
mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
|
| 76 |
+
|
| 77 |
+
norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 78 |
+
|
| 79 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 80 |
+
attn_output = self.attn(
|
| 81 |
+
hidden_states=norm_hidden_states_concat,
|
| 82 |
+
image_rotary_emb=image_rotary_emb,
|
| 83 |
+
use_cond=use_cond,
|
| 84 |
+
**joint_attention_kwargs,
|
| 85 |
+
)
|
| 86 |
+
if use_cond:
|
| 87 |
+
attn_output, cond_attn_output = attn_output
|
| 88 |
+
|
| 89 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 90 |
+
gate = gate.unsqueeze(1)
|
| 91 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 92 |
+
hidden_states = residual + hidden_states
|
| 93 |
+
|
| 94 |
+
if use_cond:
|
| 95 |
+
condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
|
| 96 |
+
cond_gate = cond_gate.unsqueeze(1)
|
| 97 |
+
condition_latents = cond_gate * self.proj_out(condition_latents)
|
| 98 |
+
condition_latents = residual_cond + condition_latents
|
| 99 |
+
|
| 100 |
+
if hidden_states.dtype == torch.float16:
|
| 101 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 102 |
+
|
| 103 |
+
return hidden_states, condition_latents if use_cond else None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@maybe_allow_in_graph
|
| 107 |
+
class FluxTransformerBlock(nn.Module):
|
| 108 |
+
def __init__(
|
| 109 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
|
| 113 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 114 |
+
|
| 115 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 116 |
+
|
| 117 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 118 |
+
processor = FluxAttnProcessor2_0()
|
| 119 |
+
else:
|
| 120 |
+
raise ValueError(
|
| 121 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 122 |
+
)
|
| 123 |
+
self.attn = Attention(
|
| 124 |
+
query_dim=dim,
|
| 125 |
+
cross_attention_dim=None,
|
| 126 |
+
added_kv_proj_dim=dim,
|
| 127 |
+
dim_head=attention_head_dim,
|
| 128 |
+
heads=num_attention_heads,
|
| 129 |
+
out_dim=dim,
|
| 130 |
+
context_pre_only=False,
|
| 131 |
+
bias=True,
|
| 132 |
+
processor=processor,
|
| 133 |
+
qk_norm=qk_norm,
|
| 134 |
+
eps=eps,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 138 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 139 |
+
|
| 140 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 141 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 142 |
+
|
| 143 |
+
# let chunk size default to None
|
| 144 |
+
self._chunk_size = None
|
| 145 |
+
self._chunk_dim = 0
|
| 146 |
+
|
| 147 |
+
def forward(
|
| 148 |
+
self,
|
| 149 |
+
hidden_states: torch.Tensor,
|
| 150 |
+
cond_hidden_states: torch.Tensor,
|
| 151 |
+
encoder_hidden_states: torch.Tensor,
|
| 152 |
+
temb: torch.Tensor,
|
| 153 |
+
cond_temb: torch.Tensor,
|
| 154 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 155 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 156 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 157 |
+
use_cond = cond_hidden_states is not None
|
| 158 |
+
|
| 159 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 160 |
+
if use_cond:
|
| 161 |
+
(
|
| 162 |
+
norm_cond_hidden_states,
|
| 163 |
+
cond_gate_msa,
|
| 164 |
+
cond_shift_mlp,
|
| 165 |
+
cond_scale_mlp,
|
| 166 |
+
cond_gate_mlp,
|
| 167 |
+
) = self.norm1(cond_hidden_states, emb=cond_temb)
|
| 168 |
+
|
| 169 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 170 |
+
encoder_hidden_states, emb=temb
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
|
| 174 |
+
|
| 175 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 176 |
+
# Attention.
|
| 177 |
+
attention_outputs = self.attn(
|
| 178 |
+
hidden_states=norm_hidden_states,
|
| 179 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 180 |
+
image_rotary_emb=image_rotary_emb,
|
| 181 |
+
use_cond=use_cond,
|
| 182 |
+
**joint_attention_kwargs,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
attn_output, context_attn_output = attention_outputs[:2]
|
| 186 |
+
cond_attn_output = attention_outputs[2] if use_cond else None
|
| 187 |
+
|
| 188 |
+
# Process attention outputs for the `hidden_states`.
|
| 189 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 190 |
+
hidden_states = hidden_states + attn_output
|
| 191 |
+
|
| 192 |
+
if use_cond:
|
| 193 |
+
cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
|
| 194 |
+
cond_hidden_states = cond_hidden_states + cond_attn_output
|
| 195 |
+
|
| 196 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 197 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 198 |
+
|
| 199 |
+
if use_cond:
|
| 200 |
+
norm_cond_hidden_states = self.norm2(cond_hidden_states)
|
| 201 |
+
norm_cond_hidden_states = (
|
| 202 |
+
norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
|
| 203 |
+
+ cond_shift_mlp[:, None]
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
ff_output = self.ff(norm_hidden_states)
|
| 207 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 208 |
+
hidden_states = hidden_states + ff_output
|
| 209 |
+
|
| 210 |
+
if use_cond:
|
| 211 |
+
cond_ff_output = self.ff(norm_cond_hidden_states)
|
| 212 |
+
cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
|
| 213 |
+
cond_hidden_states = cond_hidden_states + cond_ff_output
|
| 214 |
+
|
| 215 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 216 |
+
|
| 217 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 218 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 219 |
+
|
| 220 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 221 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 222 |
+
|
| 223 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 224 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 225 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 226 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 227 |
+
|
| 228 |
+
return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class FluxTransformer2DModel(
|
| 232 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
| 233 |
+
):
|
| 234 |
+
_supports_gradient_checkpointing = True
|
| 235 |
+
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
| 236 |
+
|
| 237 |
+
@register_to_config
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
patch_size: int = 1,
|
| 241 |
+
in_channels: int = 64,
|
| 242 |
+
out_channels: Optional[int] = None,
|
| 243 |
+
num_layers: int = 19,
|
| 244 |
+
num_single_layers: int = 38,
|
| 245 |
+
attention_head_dim: int = 128,
|
| 246 |
+
num_attention_heads: int = 24,
|
| 247 |
+
joint_attention_dim: int = 4096,
|
| 248 |
+
pooled_projection_dim: int = 768,
|
| 249 |
+
guidance_embeds: bool = False,
|
| 250 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
| 251 |
+
):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.out_channels = out_channels or in_channels
|
| 254 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 255 |
+
|
| 256 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 257 |
+
|
| 258 |
+
text_time_guidance_cls = (
|
| 259 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
| 260 |
+
)
|
| 261 |
+
self.time_text_embed = text_time_guidance_cls(
|
| 262 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 266 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
| 267 |
+
|
| 268 |
+
self.transformer_blocks = nn.ModuleList(
|
| 269 |
+
[
|
| 270 |
+
FluxTransformerBlock(
|
| 271 |
+
dim=self.inner_dim,
|
| 272 |
+
num_attention_heads=num_attention_heads,
|
| 273 |
+
attention_head_dim=attention_head_dim,
|
| 274 |
+
)
|
| 275 |
+
for _ in range(num_layers)
|
| 276 |
+
]
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 280 |
+
[
|
| 281 |
+
FluxSingleTransformerBlock(
|
| 282 |
+
dim=self.inner_dim,
|
| 283 |
+
num_attention_heads=num_attention_heads,
|
| 284 |
+
attention_head_dim=attention_head_dim,
|
| 285 |
+
)
|
| 286 |
+
for _ in range(num_single_layers)
|
| 287 |
+
]
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 291 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 292 |
+
|
| 293 |
+
self.gradient_checkpointing = False
|
| 294 |
+
|
| 295 |
+
@property
|
| 296 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 297 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 298 |
+
r"""
|
| 299 |
+
Returns:
|
| 300 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 301 |
+
indexed by its weight name.
|
| 302 |
+
"""
|
| 303 |
+
# set recursively
|
| 304 |
+
processors = {}
|
| 305 |
+
|
| 306 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 307 |
+
if hasattr(module, "get_processor"):
|
| 308 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 309 |
+
|
| 310 |
+
for sub_name, child in module.named_children():
|
| 311 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 312 |
+
|
| 313 |
+
return processors
|
| 314 |
+
|
| 315 |
+
for name, module in self.named_children():
|
| 316 |
+
fn_recursive_add_processors(name, module, processors)
|
| 317 |
+
|
| 318 |
+
return processors
|
| 319 |
+
|
| 320 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 321 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 322 |
+
r"""
|
| 323 |
+
Sets the attention processor to use to compute attention.
|
| 324 |
+
|
| 325 |
+
Parameters:
|
| 326 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 327 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 328 |
+
for **all** `Attention` layers.
|
| 329 |
+
|
| 330 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 331 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 332 |
+
|
| 333 |
+
"""
|
| 334 |
+
count = len(self.attn_processors.keys())
|
| 335 |
+
|
| 336 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 337 |
+
raise ValueError(
|
| 338 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 339 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 343 |
+
if hasattr(module, "set_processor"):
|
| 344 |
+
if not isinstance(processor, dict):
|
| 345 |
+
module.set_processor(processor)
|
| 346 |
+
else:
|
| 347 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 348 |
+
|
| 349 |
+
for sub_name, child in module.named_children():
|
| 350 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 351 |
+
|
| 352 |
+
for name, module in self.named_children():
|
| 353 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 354 |
+
|
| 355 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
| 356 |
+
def fuse_qkv_projections(self):
|
| 357 |
+
"""
|
| 358 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 359 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 360 |
+
|
| 361 |
+
<Tip warning={true}>
|
| 362 |
+
|
| 363 |
+
This API is 🧪 experimental.
|
| 364 |
+
|
| 365 |
+
</Tip>
|
| 366 |
+
"""
|
| 367 |
+
self.original_attn_processors = None
|
| 368 |
+
|
| 369 |
+
for _, attn_processor in self.attn_processors.items():
|
| 370 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 371 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 372 |
+
|
| 373 |
+
self.original_attn_processors = self.attn_processors
|
| 374 |
+
|
| 375 |
+
for module in self.modules():
|
| 376 |
+
if isinstance(module, Attention):
|
| 377 |
+
module.fuse_projections(fuse=True)
|
| 378 |
+
|
| 379 |
+
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
| 380 |
+
|
| 381 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 382 |
+
def unfuse_qkv_projections(self):
|
| 383 |
+
"""Disables the fused QKV projection if enabled.
|
| 384 |
+
|
| 385 |
+
<Tip warning={true}>
|
| 386 |
+
|
| 387 |
+
This API is 🧪 experimental.
|
| 388 |
+
|
| 389 |
+
</Tip>
|
| 390 |
+
|
| 391 |
+
"""
|
| 392 |
+
if self.original_attn_processors is not None:
|
| 393 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 394 |
+
|
| 395 |
+
def _set_gradient_checkpointing(self, module=None, enable=False, gradient_checkpointing_func=None):
|
| 396 |
+
# Align with diffusers' enable_gradient_checkpointing API which may call
|
| 397 |
+
# without a `module` argument and pass only keyword args.
|
| 398 |
+
# Toggle on both the provided module (if any) and on self for safety.
|
| 399 |
+
if module is not None and hasattr(module, "gradient_checkpointing"):
|
| 400 |
+
module.gradient_checkpointing = enable
|
| 401 |
+
if hasattr(self, "gradient_checkpointing"):
|
| 402 |
+
self.gradient_checkpointing = enable
|
| 403 |
+
# Optionally store the provided function for future use.
|
| 404 |
+
if gradient_checkpointing_func is not None:
|
| 405 |
+
setattr(self, "_gradient_checkpointing_func", gradient_checkpointing_func)
|
| 406 |
+
|
| 407 |
+
def forward(
|
| 408 |
+
self,
|
| 409 |
+
hidden_states: torch.Tensor,
|
| 410 |
+
cond_hidden_states: torch.Tensor = None,
|
| 411 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 412 |
+
pooled_projections: torch.Tensor = None,
|
| 413 |
+
timestep: torch.LongTensor = None,
|
| 414 |
+
img_ids: torch.Tensor = None,
|
| 415 |
+
txt_ids: torch.Tensor = None,
|
| 416 |
+
guidance: torch.Tensor = None,
|
| 417 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 418 |
+
controlnet_block_samples=None,
|
| 419 |
+
controlnet_single_block_samples=None,
|
| 420 |
+
return_dict: bool = True,
|
| 421 |
+
controlnet_blocks_repeat: bool = False,
|
| 422 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 423 |
+
if cond_hidden_states is not None:
|
| 424 |
+
use_condition = True
|
| 425 |
+
else:
|
| 426 |
+
use_condition = False
|
| 427 |
+
|
| 428 |
+
if joint_attention_kwargs is not None:
|
| 429 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 430 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 431 |
+
else:
|
| 432 |
+
lora_scale = 1.0
|
| 433 |
+
|
| 434 |
+
if USE_PEFT_BACKEND:
|
| 435 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 436 |
+
scale_lora_layers(self, lora_scale)
|
| 437 |
+
else:
|
| 438 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
| 439 |
+
logger.warning(
|
| 440 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 444 |
+
if cond_hidden_states is not None:
|
| 445 |
+
if cond_hidden_states.shape[-1] == self.x_embedder.in_features:
|
| 446 |
+
cond_hidden_states = self.x_embedder(cond_hidden_states)
|
| 447 |
+
elif cond_hidden_states.shape[-1] == 64:
|
| 448 |
+
# 只用前64列权重和bias
|
| 449 |
+
weight = self.x_embedder.weight[:, :64] # [inner_dim, 64]
|
| 450 |
+
bias = self.x_embedder.bias
|
| 451 |
+
cond_hidden_states = torch.nn.functional.linear(cond_hidden_states, weight, bias)
|
| 452 |
+
|
| 453 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 454 |
+
if guidance is not None:
|
| 455 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 456 |
+
else:
|
| 457 |
+
guidance = None
|
| 458 |
+
|
| 459 |
+
temb = (
|
| 460 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 461 |
+
if guidance is None
|
| 462 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
cond_temb = (
|
| 466 |
+
self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
|
| 467 |
+
if guidance is None
|
| 468 |
+
else self.time_text_embed(
|
| 469 |
+
torch.ones_like(timestep) * 0, guidance, pooled_projections
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 474 |
+
|
| 475 |
+
if txt_ids.ndim == 3:
|
| 476 |
+
logger.warning(
|
| 477 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
| 478 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 479 |
+
)
|
| 480 |
+
txt_ids = txt_ids[0]
|
| 481 |
+
if img_ids.ndim == 3:
|
| 482 |
+
logger.warning(
|
| 483 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
| 484 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 485 |
+
)
|
| 486 |
+
img_ids = img_ids[0]
|
| 487 |
+
|
| 488 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 489 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 490 |
+
|
| 491 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 492 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 493 |
+
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
| 494 |
+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
| 495 |
+
|
| 496 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 497 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 498 |
+
|
| 499 |
+
def create_custom_forward(module, return_dict=None):
|
| 500 |
+
def custom_forward(*inputs):
|
| 501 |
+
if return_dict is not None:
|
| 502 |
+
return module(*inputs, return_dict=return_dict)
|
| 503 |
+
else:
|
| 504 |
+
return module(*inputs)
|
| 505 |
+
|
| 506 |
+
return custom_forward
|
| 507 |
+
|
| 508 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 509 |
+
if use_condition:
|
| 510 |
+
encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 511 |
+
create_custom_forward(block),
|
| 512 |
+
hidden_states,
|
| 513 |
+
cond_hidden_states,
|
| 514 |
+
encoder_hidden_states,
|
| 515 |
+
temb,
|
| 516 |
+
cond_temb,
|
| 517 |
+
image_rotary_emb,
|
| 518 |
+
joint_attention_kwargs,
|
| 519 |
+
**ckpt_kwargs,
|
| 520 |
+
)
|
| 521 |
+
else:
|
| 522 |
+
encoder_hidden_states, hidden_states, _ = torch.utils.checkpoint.checkpoint(
|
| 523 |
+
create_custom_forward(block),
|
| 524 |
+
hidden_states,
|
| 525 |
+
None,
|
| 526 |
+
encoder_hidden_states,
|
| 527 |
+
temb,
|
| 528 |
+
None,
|
| 529 |
+
image_rotary_emb,
|
| 530 |
+
joint_attention_kwargs,
|
| 531 |
+
**ckpt_kwargs,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
else:
|
| 535 |
+
encoder_hidden_states, hidden_states, cond_hidden_states = block(
|
| 536 |
+
hidden_states=hidden_states,
|
| 537 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 538 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 539 |
+
temb=temb,
|
| 540 |
+
cond_temb=cond_temb if use_condition else None,
|
| 541 |
+
image_rotary_emb=image_rotary_emb,
|
| 542 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
# controlnet residual
|
| 546 |
+
if controlnet_block_samples is not None:
|
| 547 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
| 548 |
+
interval_control = int(np.ceil(interval_control))
|
| 549 |
+
# For Xlabs ControlNet.
|
| 550 |
+
if controlnet_blocks_repeat:
|
| 551 |
+
hidden_states = (
|
| 552 |
+
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
| 553 |
+
)
|
| 554 |
+
else:
|
| 555 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
| 556 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 557 |
+
|
| 558 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 559 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 560 |
+
|
| 561 |
+
def create_custom_forward(module, return_dict=None):
|
| 562 |
+
def custom_forward(*inputs):
|
| 563 |
+
if return_dict is not None:
|
| 564 |
+
return module(*inputs, return_dict=return_dict)
|
| 565 |
+
else:
|
| 566 |
+
return module(*inputs)
|
| 567 |
+
|
| 568 |
+
return custom_forward
|
| 569 |
+
|
| 570 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 571 |
+
if use_condition:
|
| 572 |
+
hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 573 |
+
create_custom_forward(block),
|
| 574 |
+
hidden_states,
|
| 575 |
+
cond_hidden_states,
|
| 576 |
+
temb,
|
| 577 |
+
cond_temb,
|
| 578 |
+
image_rotary_emb,
|
| 579 |
+
joint_attention_kwargs,
|
| 580 |
+
**ckpt_kwargs,
|
| 581 |
+
)
|
| 582 |
+
else:
|
| 583 |
+
hidden_states, _ = torch.utils.checkpoint.checkpoint(
|
| 584 |
+
create_custom_forward(block),
|
| 585 |
+
hidden_states,
|
| 586 |
+
None,
|
| 587 |
+
temb,
|
| 588 |
+
None,
|
| 589 |
+
image_rotary_emb,
|
| 590 |
+
joint_attention_kwargs,
|
| 591 |
+
**ckpt_kwargs,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
else:
|
| 595 |
+
hidden_states, cond_hidden_states = block(
|
| 596 |
+
hidden_states=hidden_states,
|
| 597 |
+
cond_hidden_states=cond_hidden_states if use_condition else None,
|
| 598 |
+
temb=temb,
|
| 599 |
+
cond_temb=cond_temb if use_condition else None,
|
| 600 |
+
image_rotary_emb=image_rotary_emb,
|
| 601 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
# controlnet residual
|
| 605 |
+
if controlnet_single_block_samples is not None:
|
| 606 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
| 607 |
+
interval_control = int(np.ceil(interval_control))
|
| 608 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
| 609 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 610 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 614 |
+
|
| 615 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 616 |
+
output = self.proj_out(hidden_states)
|
| 617 |
+
|
| 618 |
+
if USE_PEFT_BACKEND:
|
| 619 |
+
# remove `lora_scale` from each PEFT layer
|
| 620 |
+
unscale_lora_layers(self, lora_scale)
|
| 621 |
+
|
| 622 |
+
if not return_dict:
|
| 623 |
+
return (output,)
|
| 624 |
+
|
| 625 |
+
return Transformer2DModelOutput(sample=output)
|
train/train_kontext_color.py
ADDED
|
@@ -0,0 +1,858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
from safetensors.torch import save_file
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.utils.checkpoint
|
| 16 |
+
import transformers
|
| 17 |
+
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
from accelerate.logging import get_logger
|
| 20 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 21 |
+
|
| 22 |
+
import diffusers
|
| 23 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 24 |
+
from diffusers.optimization import get_scheduler
|
| 25 |
+
from diffusers.training_utils import (
|
| 26 |
+
cast_training_params,
|
| 27 |
+
compute_density_for_timestep_sampling,
|
| 28 |
+
compute_loss_weighting_for_sd3,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 31 |
+
from diffusers.utils import (
|
| 32 |
+
check_min_version,
|
| 33 |
+
is_wandb_available,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from src.prompt_helper import *
|
| 37 |
+
from src.lora_helper import *
|
| 38 |
+
from src.jsonl_datasets_kontext_color import make_train_dataset_inpaint_mask, collate_fn
|
| 39 |
+
from src.pipeline_flux_kontext_control import (
|
| 40 |
+
FluxKontextControlPipeline,
|
| 41 |
+
resize_position_encoding,
|
| 42 |
+
prepare_latent_subject_ids,
|
| 43 |
+
PREFERRED_KONTEXT_RESOLUTIONS
|
| 44 |
+
)
|
| 45 |
+
from src.transformer_flux import FluxTransformer2DModel
|
| 46 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 47 |
+
from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
|
| 48 |
+
from tqdm.auto import tqdm
|
| 49 |
+
|
| 50 |
+
if is_wandb_available():
|
| 51 |
+
import wandb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 55 |
+
check_min_version("0.31.0.dev0")
|
| 56 |
+
|
| 57 |
+
logger = get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def log_validation(
|
| 61 |
+
pipeline,
|
| 62 |
+
args,
|
| 63 |
+
accelerator,
|
| 64 |
+
pipeline_args,
|
| 65 |
+
step,
|
| 66 |
+
torch_dtype,
|
| 67 |
+
is_final_validation=False,
|
| 68 |
+
):
|
| 69 |
+
logger.info(
|
| 70 |
+
f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
|
| 71 |
+
)
|
| 72 |
+
pipeline = pipeline.to(accelerator.device)
|
| 73 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 74 |
+
|
| 75 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 76 |
+
autocast_ctx = nullcontext()
|
| 77 |
+
|
| 78 |
+
# Build per-case evaluation: require equal lengths for image, spatial image, and prompt
|
| 79 |
+
if args.validation_images is None or args.validation_images == ['None']:
|
| 80 |
+
raise ValueError("validation_images must be provided and non-empty")
|
| 81 |
+
if args.validation_prompt is None:
|
| 82 |
+
raise ValueError("validation_prompt must be provided and non-empty")
|
| 83 |
+
|
| 84 |
+
control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
|
| 85 |
+
spatial_ls = control_dict_root.get("spatial_images", []) or []
|
| 86 |
+
|
| 87 |
+
val_imgs = args.validation_images
|
| 88 |
+
prompts = args.validation_prompt
|
| 89 |
+
|
| 90 |
+
if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
results = []
|
| 96 |
+
|
| 97 |
+
def _resize_to_preferred(img: Image.Image) -> Image.Image:
|
| 98 |
+
w, h = img.size
|
| 99 |
+
aspect_ratio = w / h if h != 0 else 1.0
|
| 100 |
+
_, target_w, target_h = min(
|
| 101 |
+
(abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
|
| 102 |
+
for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
|
| 103 |
+
)
|
| 104 |
+
return img.resize((target_w, target_h), Image.BICUBIC)
|
| 105 |
+
|
| 106 |
+
# Distributed per-rank assignment: each process handles its own slice of cases
|
| 107 |
+
num_cases = len(prompts)
|
| 108 |
+
logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
|
| 109 |
+
|
| 110 |
+
# Indices assigned to this rank
|
| 111 |
+
rank = accelerator.process_index
|
| 112 |
+
world_size = accelerator.num_processes
|
| 113 |
+
local_indices = list(range(rank, num_cases, world_size))
|
| 114 |
+
|
| 115 |
+
local_images = []
|
| 116 |
+
with autocast_ctx:
|
| 117 |
+
for idx in local_indices:
|
| 118 |
+
try:
|
| 119 |
+
base_img = Image.open(val_imgs[idx]).convert("RGB")
|
| 120 |
+
resized_img = _resize_to_preferred(base_img)
|
| 121 |
+
except Exception as e:
|
| 122 |
+
raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
|
| 123 |
+
|
| 124 |
+
case_args = dict(pipeline_args) if pipeline_args is not None else {}
|
| 125 |
+
case_args.pop("height", None)
|
| 126 |
+
case_args.pop("width", None)
|
| 127 |
+
if resized_img is not None:
|
| 128 |
+
tw, th = resized_img.size
|
| 129 |
+
case_args["height"] = th
|
| 130 |
+
case_args["width"] = tw
|
| 131 |
+
|
| 132 |
+
case_control = dict(case_args.get("control_dict", {}))
|
| 133 |
+
spatial_case = spatial_ls[idx]
|
| 134 |
+
|
| 135 |
+
# Load spatial image if it's a path; else assume it's already an image
|
| 136 |
+
if isinstance(spatial_case, str):
|
| 137 |
+
spatial_img = Image.open(spatial_case).convert("RGB")
|
| 138 |
+
else:
|
| 139 |
+
spatial_img = spatial_case
|
| 140 |
+
|
| 141 |
+
case_control["spatial_images"] = [spatial_img]
|
| 142 |
+
case_control["subject_images"] = []
|
| 143 |
+
case_args["control_dict"] = case_control
|
| 144 |
+
|
| 145 |
+
case_args["prompt"] = prompts[idx]
|
| 146 |
+
img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
|
| 147 |
+
local_images.append(img)
|
| 148 |
+
|
| 149 |
+
# Gather all images per rank (pad to equal count) to main process
|
| 150 |
+
fixed_size = (1024, 1024)
|
| 151 |
+
max_local = int(math.ceil(num_cases / world_size)) if world_size > 0 else len(local_images)
|
| 152 |
+
# Build per-rank batch tensors
|
| 153 |
+
imgs_rank = []
|
| 154 |
+
idx_rank = []
|
| 155 |
+
has_rank = []
|
| 156 |
+
for j in range(max_local):
|
| 157 |
+
if j < len(local_images):
|
| 158 |
+
resized = local_images[j].resize(fixed_size, Image.BICUBIC)
|
| 159 |
+
img_np = np.asarray(resized).astype(np.uint8)
|
| 160 |
+
imgs_rank.append(torch.from_numpy(img_np))
|
| 161 |
+
idx_rank.append(local_indices[j])
|
| 162 |
+
has_rank.append(1)
|
| 163 |
+
else:
|
| 164 |
+
imgs_rank.append(torch.from_numpy(np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)))
|
| 165 |
+
idx_rank.append(-1)
|
| 166 |
+
has_rank.append(0)
|
| 167 |
+
imgs_rank_tensor = torch.stack([t.to(device=accelerator.device) for t in imgs_rank], dim=0) # [max_local, H, W, C]
|
| 168 |
+
idx_rank_tensor = torch.tensor(idx_rank, device=accelerator.device, dtype=torch.long) # [max_local]
|
| 169 |
+
has_rank_tensor = torch.tensor(has_rank, device=accelerator.device, dtype=torch.int) # [max_local]
|
| 170 |
+
|
| 171 |
+
gathered_has = accelerator.gather(has_rank_tensor) # [world * max_local]
|
| 172 |
+
gathered_idx = accelerator.gather(idx_rank_tensor) # [world * max_local]
|
| 173 |
+
gathered_imgs = accelerator.gather(imgs_rank_tensor) # [world * max_local, H, W, C]
|
| 174 |
+
|
| 175 |
+
if accelerator.is_main_process:
|
| 176 |
+
world = int(world_size)
|
| 177 |
+
slots = int(max_local)
|
| 178 |
+
try:
|
| 179 |
+
gathered_has = gathered_has.view(world, slots)
|
| 180 |
+
gathered_idx = gathered_idx.view(world, slots)
|
| 181 |
+
gathered_imgs = gathered_imgs.view(world, slots, fixed_size[1], fixed_size[0], 3)
|
| 182 |
+
except Exception:
|
| 183 |
+
# Fallback: treat as flat if reshape fails
|
| 184 |
+
gathered_has = gathered_has.view(-1, 1)
|
| 185 |
+
gathered_idx = gathered_idx.view(-1, 1)
|
| 186 |
+
gathered_imgs = gathered_imgs.view(-1, 1, fixed_size[1], fixed_size[0], 3)
|
| 187 |
+
world = int(gathered_has.shape[0])
|
| 188 |
+
slots = 1
|
| 189 |
+
for i in range(world):
|
| 190 |
+
for j in range(slots):
|
| 191 |
+
if int(gathered_has[i, j].item()) == 1:
|
| 192 |
+
idx = int(gathered_idx[i, j].item())
|
| 193 |
+
arr = gathered_imgs[i, j].cpu().numpy()
|
| 194 |
+
pil_img = Image.fromarray(arr.astype(np.uint8))
|
| 195 |
+
# Resize back to original validation image size
|
| 196 |
+
try:
|
| 197 |
+
orig = Image.open(val_imgs[idx]).convert("RGB")
|
| 198 |
+
pil_img = pil_img.resize(orig.size, Image.BICUBIC)
|
| 199 |
+
except Exception:
|
| 200 |
+
pass
|
| 201 |
+
results.append(pil_img)
|
| 202 |
+
|
| 203 |
+
# Log results (resize to 1024x1024 for saving or external trackers). Skip TensorBoard per request.
|
| 204 |
+
resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
|
| 205 |
+
for tracker in accelerator.trackers:
|
| 206 |
+
phase_name = "test" if is_final_validation else "validation"
|
| 207 |
+
if tracker.name == "tensorboard":
|
| 208 |
+
continue
|
| 209 |
+
if tracker.name == "wandb":
|
| 210 |
+
tracker.log({
|
| 211 |
+
phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
|
| 212 |
+
})
|
| 213 |
+
|
| 214 |
+
del pipeline
|
| 215 |
+
if torch.cuda.is_available():
|
| 216 |
+
torch.cuda.empty_cache()
|
| 217 |
+
|
| 218 |
+
return results
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
|
| 222 |
+
text_encoder_config = transformers.PretrainedConfig.from_pretrained(
|
| 223 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 224 |
+
)
|
| 225 |
+
model_class = text_encoder_config.architectures[0]
|
| 226 |
+
if model_class == "CLIPTextModel":
|
| 227 |
+
from transformers import CLIPTextModel
|
| 228 |
+
|
| 229 |
+
return CLIPTextModel
|
| 230 |
+
elif model_class == "T5EncoderModel":
|
| 231 |
+
from transformers import T5EncoderModel
|
| 232 |
+
|
| 233 |
+
return T5EncoderModel
|
| 234 |
+
else:
|
| 235 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def parse_args(input_args=None):
|
| 239 |
+
parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
|
| 240 |
+
parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
|
| 241 |
+
parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
|
| 242 |
+
parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
|
| 243 |
+
|
| 244 |
+
parser.add_argument("--train_data_dir", type=str, default="", help="Path to JSONL dataset.")
|
| 245 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
|
| 246 |
+
parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
|
| 247 |
+
parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
|
| 248 |
+
parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
|
| 249 |
+
|
| 250 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
| 251 |
+
parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
|
| 252 |
+
parser.add_argument("--kontext", type=str, default="disable")
|
| 253 |
+
parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
|
| 254 |
+
parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
|
| 255 |
+
parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
|
| 256 |
+
parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
|
| 257 |
+
parser.add_argument("--num_validation_images", type=int, default=4)
|
| 258 |
+
parser.add_argument("--validation_steps", type=int, default=20)
|
| 259 |
+
|
| 260 |
+
parser.add_argument("--ranks", type=int, nargs="+", default=[128], help="LoRA ranks")
|
| 261 |
+
parser.add_argument("--network_alphas", type=int, nargs="+", default=[128], help="LoRA network alphas")
|
| 262 |
+
parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
|
| 263 |
+
parser.add_argument("--seed", type=int, default=None)
|
| 264 |
+
parser.add_argument("--train_batch_size", type=int, default=1)
|
| 265 |
+
parser.add_argument("--num_train_epochs", type=int, default=50)
|
| 266 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
| 267 |
+
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
| 268 |
+
parser.add_argument("--checkpoints_total_limit", type=int, default=None)
|
| 269 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
| 270 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
| 271 |
+
parser.add_argument("--gradient_checkpointing", action="store_true")
|
| 272 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
| 273 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
|
| 274 |
+
parser.add_argument("--scale_lr", action="store_true", default=False)
|
| 275 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant")
|
| 276 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
| 277 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1)
|
| 278 |
+
parser.add_argument("--lr_power", type=float, default=1.0)
|
| 279 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=1)
|
| 280 |
+
parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
|
| 281 |
+
parser.add_argument("--logit_mean", type=float, default=0.0)
|
| 282 |
+
parser.add_argument("--logit_std", type=float, default=1.0)
|
| 283 |
+
parser.add_argument("--mode_scale", type=float, default=1.29)
|
| 284 |
+
parser.add_argument("--optimizer", type=str, default="AdamW")
|
| 285 |
+
parser.add_argument("--use_8bit_adam", action="store_true")
|
| 286 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
| 287 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 288 |
+
parser.add_argument("--prodigy_beta3", type=float, default=None)
|
| 289 |
+
parser.add_argument("--prodigy_decouple", type=bool, default=True)
|
| 290 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
|
| 291 |
+
parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
|
| 292 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 293 |
+
parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
|
| 294 |
+
parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
|
| 295 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 296 |
+
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 297 |
+
parser.add_argument("--cache_latents", action="store_true", default=False)
|
| 298 |
+
parser.add_argument("--report_to", type=str, default="tensorboard")
|
| 299 |
+
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
|
| 300 |
+
parser.add_argument("--upcast_before_saving", action="store_true", default=False)
|
| 301 |
+
|
| 302 |
+
if input_args is not None:
|
| 303 |
+
args = parser.parse_args(input_args)
|
| 304 |
+
else:
|
| 305 |
+
args = parser.parse_args()
|
| 306 |
+
return args
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def main(args):
|
| 310 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 311 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 312 |
+
|
| 313 |
+
if args.output_dir is not None:
|
| 314 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 315 |
+
os.makedirs(args.logging_dir, exist_ok=True)
|
| 316 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 317 |
+
|
| 318 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 319 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 320 |
+
accelerator = Accelerator(
|
| 321 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 322 |
+
mixed_precision=args.mixed_precision,
|
| 323 |
+
log_with=args.report_to,
|
| 324 |
+
project_config=accelerator_project_config,
|
| 325 |
+
kwargs_handlers=[kwargs],
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if torch.backends.mps.is_available():
|
| 329 |
+
accelerator.native_amp = False
|
| 330 |
+
|
| 331 |
+
if args.report_to == "wandb":
|
| 332 |
+
if not is_wandb_available():
|
| 333 |
+
raise ImportError("Install wandb for logging during training.")
|
| 334 |
+
|
| 335 |
+
logging.basicConfig(
|
| 336 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 337 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 338 |
+
level=logging.INFO,
|
| 339 |
+
)
|
| 340 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 341 |
+
if accelerator.is_local_main_process:
|
| 342 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 343 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 344 |
+
else:
|
| 345 |
+
transformers.utils.logging.set_verbosity_error()
|
| 346 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 347 |
+
|
| 348 |
+
if args.seed is not None:
|
| 349 |
+
set_seed(args.seed)
|
| 350 |
+
|
| 351 |
+
if accelerator.is_main_process and args.output_dir is not None:
|
| 352 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 353 |
+
|
| 354 |
+
# Tokenizers
|
| 355 |
+
tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
|
| 356 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
| 357 |
+
)
|
| 358 |
+
tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
|
| 359 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Text encoders
|
| 363 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
|
| 364 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
|
| 365 |
+
|
| 366 |
+
# Scheduler and models
|
| 367 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 368 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 369 |
+
text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
|
| 370 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
|
| 371 |
+
transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
|
| 372 |
+
|
| 373 |
+
# Train only LoRA adapters
|
| 374 |
+
transformer.requires_grad_(True)
|
| 375 |
+
vae.requires_grad_(False)
|
| 376 |
+
text_encoder_one.requires_grad_(False)
|
| 377 |
+
text_encoder_two.requires_grad_(False)
|
| 378 |
+
|
| 379 |
+
weight_dtype = torch.float32
|
| 380 |
+
if accelerator.mixed_precision == "fp16":
|
| 381 |
+
weight_dtype = torch.float16
|
| 382 |
+
elif accelerator.mixed_precision == "bf16":
|
| 383 |
+
weight_dtype = torch.bfloat16
|
| 384 |
+
|
| 385 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 386 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 387 |
+
|
| 388 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 389 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 390 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 391 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 392 |
+
|
| 393 |
+
if args.gradient_checkpointing:
|
| 394 |
+
transformer.enable_gradient_checkpointing()
|
| 395 |
+
|
| 396 |
+
# Setup LoRA attention processors
|
| 397 |
+
if args.pretrained_lora_path is not None:
|
| 398 |
+
lora_path = args.pretrained_lora_path
|
| 399 |
+
checkpoint = load_checkpoint(lora_path)
|
| 400 |
+
lora_attn_procs = {}
|
| 401 |
+
double_blocks_idx = list(range(19))
|
| 402 |
+
single_blocks_idx = list(range(38))
|
| 403 |
+
number = 1
|
| 404 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 405 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 406 |
+
if match:
|
| 407 |
+
layer_index = int(match.group(1))
|
| 408 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 409 |
+
lora_state_dicts = {}
|
| 410 |
+
for key, value in checkpoint.items():
|
| 411 |
+
if re.search(r'\.(\d+)\.', key):
|
| 412 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 413 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 414 |
+
lora_state_dicts[key] = value
|
| 415 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 416 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 417 |
+
)
|
| 418 |
+
for n in range(number):
|
| 419 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 420 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 421 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 422 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 423 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 424 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 425 |
+
lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 426 |
+
lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 427 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 428 |
+
lora_state_dicts = {}
|
| 429 |
+
for key, value in checkpoint.items():
|
| 430 |
+
if re.search(r'\.(\d+)\.', key):
|
| 431 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 432 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 433 |
+
lora_state_dicts[key] = value
|
| 434 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 435 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 436 |
+
)
|
| 437 |
+
for n in range(number):
|
| 438 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 439 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 440 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 441 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 442 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 443 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 444 |
+
else:
|
| 445 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 446 |
+
else:
|
| 447 |
+
lora_attn_procs = {}
|
| 448 |
+
double_blocks_idx = list(range(19))
|
| 449 |
+
single_blocks_idx = list(range(38))
|
| 450 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 451 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 452 |
+
if match:
|
| 453 |
+
layer_index = int(match.group(1))
|
| 454 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 455 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 456 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 457 |
+
)
|
| 458 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 459 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 460 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
lora_attn_procs[name] = attn_processor
|
| 464 |
+
|
| 465 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 466 |
+
transformer.train()
|
| 467 |
+
for n, param in transformer.named_parameters():
|
| 468 |
+
if '_lora' not in n:
|
| 469 |
+
param.requires_grad = False
|
| 470 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
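# Only parameters whose names contain '_lora' remain trainable here, so the frozen Flux weights
# are untouched and the checkpoints written later contain just the LoRA processor weights.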
|
| 471 |
+
|
| 472 |
+
def unwrap_model(model):
|
| 473 |
+
model = accelerator.unwrap_model(model)
|
| 474 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 475 |
+
return model
|
| 476 |
+
|
| 477 |
+
if args.resume_from_checkpoint:
|
| 478 |
+
path = args.resume_from_checkpoint
|
| 479 |
+
global_step = int(path.split("-")[-1])
|
| 480 |
+
initial_global_step = global_step
|
| 481 |
+
else:
|
| 482 |
+
initial_global_step = 0
|
| 483 |
+
global_step = 0
|
| 484 |
+
first_epoch = 0
|
| 485 |
+
|
| 486 |
+
if args.scale_lr:
|
| 487 |
+
args.learning_rate = (
|
| 488 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if args.mixed_precision == "fp16":
|
| 492 |
+
models = [transformer]
|
| 493 |
+
cast_training_params(models, dtype=torch.float32)
|
| 494 |
+
|
| 495 |
+
params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
|
| 496 |
+
transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
|
| 497 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
|
| 498 |
+
|
| 499 |
+
optimizer_class = torch.optim.AdamW
|
| 500 |
+
optimizer = optimizer_class(
|
| 501 |
+
[transformer_parameters_with_lr],
|
| 502 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 503 |
+
weight_decay=args.adam_weight_decay,
|
| 504 |
+
eps=args.adam_epsilon,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
tokenizers = [tokenizer_one, tokenizer_two]
|
| 508 |
+
text_encoders = [text_encoder_one, text_encoder_two]
|
| 509 |
+
|
| 510 |
+
train_dataset = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
|
| 511 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 512 |
+
train_dataset,
|
| 513 |
+
batch_size=args.train_batch_size,
|
| 514 |
+
shuffle=True,
|
| 515 |
+
collate_fn=collate_fn,
|
| 516 |
+
num_workers=args.dataloader_num_workers,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 520 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 521 |
+
|
| 522 |
+
overrode_max_train_steps = False
|
| 523 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 524 |
+
if args.resume_from_checkpoint:
|
| 525 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 526 |
+
if args.max_train_steps is None:
|
| 527 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 528 |
+
overrode_max_train_steps = True
|
| 529 |
+
|
| 530 |
+
lr_scheduler = get_scheduler(
|
| 531 |
+
args.lr_scheduler,
|
| 532 |
+
optimizer=optimizer,
|
| 533 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 534 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 535 |
+
num_cycles=args.lr_num_cycles,
|
| 536 |
+
power=args.lr_power,
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 540 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 544 |
+
if overrode_max_train_steps:
|
| 545 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 546 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 547 |
+
|
| 548 |
+
# Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
|
| 549 |
+
def _sanitize_hparams(config_dict):
|
| 550 |
+
sanitized = {}
|
| 551 |
+
for key, value in dict(config_dict).items():
|
| 552 |
+
try:
|
| 553 |
+
if value is None:
|
| 554 |
+
continue
|
| 555 |
+
# numpy scalar types
|
| 556 |
+
if isinstance(value, (np.integer,)):
|
| 557 |
+
sanitized[key] = int(value)
|
| 558 |
+
elif isinstance(value, (np.floating,)):
|
| 559 |
+
sanitized[key] = float(value)
|
| 560 |
+
elif isinstance(value, (int, float, bool, str)):
|
| 561 |
+
sanitized[key] = value
|
| 562 |
+
elif isinstance(value, Path):
|
| 563 |
+
sanitized[key] = str(value)
|
| 564 |
+
elif isinstance(value, (list, tuple)):
|
| 565 |
+
# stringify simple sequences; skip if fails
|
| 566 |
+
sanitized[key] = str(value)
|
| 567 |
+
else:
|
| 568 |
+
# best-effort stringify
|
| 569 |
+
sanitized[key] = str(value)
|
| 570 |
+
except Exception:
|
| 571 |
+
# skip unconvertible entries
|
| 572 |
+
continue
|
| 573 |
+
return sanitized
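# Illustrative effect of the sanitizer (hypothetical values): {"ranks": [128], "seed": None,
# "logging_dir": Path("logs")} -> {"ranks": "[128]", "logging_dir": "logs"}; None entries are
# dropped and non-scalar values are stringified so TensorBoard's hparams logging accepts them.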
|
| 574 |
+
|
| 575 |
+
if accelerator.is_main_process:
|
| 576 |
+
tracker_name = "Easy_Control_Kontext"
|
| 577 |
+
accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
|
| 578 |
+
|
| 579 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 580 |
+
logger.info("***** Running training *****")
|
| 581 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 582 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 583 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 584 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 585 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 586 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 587 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 588 |
+
|
| 589 |
+
progress_bar = tqdm(
|
| 590 |
+
range(0, args.max_train_steps),
|
| 591 |
+
initial=initial_global_step,
|
| 592 |
+
desc="Steps",
|
| 593 |
+
disable=not accelerator.is_local_main_process,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 597 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 598 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 599 |
+
timesteps = timesteps.to(accelerator.device)
|
| 600 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 601 |
+
sigma = sigmas[step_indices].flatten()
|
| 602 |
+
while len(sigma.shape) < n_dim:
|
| 603 |
+
sigma = sigma.unsqueeze(-1)
|
| 604 |
+
return sigma
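# Usage note: with latents of shape [bsz, C, H, W], get_sigmas(timesteps, n_dim=4) returns a
# [bsz, 1, 1, 1] tensor, so the sigma values broadcast elementwise when mixing noise below.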
|
| 605 |
+
|
| 606 |
+
# Kontext specifics
|
| 607 |
+
vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
|
| 608 |
+
# Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
|
| 609 |
+
height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 610 |
+
width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 611 |
+
offset = 64
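# With the defaults (cond_size=512, vae_scale_factor=8): height_cond = width_cond = 2 * (512 // 16) = 64
# latent rows/cols, i.e. a 32x32 grid of packed 2x2 patches (1024 condition tokens) per control image.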
|
| 612 |
+
|
| 613 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 614 |
+
transformer.train()
|
| 615 |
+
for step, batch in enumerate(train_dataloader):
|
| 616 |
+
models_to_accumulate = [transformer]
|
| 617 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 618 |
+
tokens = [batch["text_ids_1"], batch["text_ids_2"]]
|
| 619 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
|
| 620 |
+
prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 621 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 622 |
+
text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
|
| 623 |
+
|
| 624 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 625 |
+
height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
|
| 626 |
+
width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
|
| 627 |
+
|
| 628 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 629 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 630 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 631 |
+
|
| 632 |
+
latent_image_ids, cond_latent_image_ids = resize_position_encoding(
|
| 633 |
+
model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
noise = torch.randn_like(model_input)
|
| 637 |
+
bsz = model_input.shape[0]
|
| 638 |
+
|
| 639 |
+
u = compute_density_for_timestep_sampling(
|
| 640 |
+
weighting_scheme=args.weighting_scheme,
|
| 641 |
+
batch_size=bsz,
|
| 642 |
+
logit_mean=args.logit_mean,
|
| 643 |
+
logit_std=args.logit_std,
|
| 644 |
+
mode_scale=args.mode_scale,
|
| 645 |
+
)
|
| 646 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 647 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 648 |
+
|
| 649 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 650 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
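# Flow-matching interpolation: at sigma=0 this is the clean latent, at sigma=1 pure noise; the
# regression target computed further down is the constant velocity (noise - model_input) along this path.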
|
| 651 |
+
|
| 652 |
+
packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
|
| 653 |
+
noisy_model_input,
|
| 654 |
+
batch_size=model_input.shape[0],
|
| 655 |
+
num_channels_latents=model_input.shape[1],
|
| 656 |
+
height=model_input.shape[2],
|
| 657 |
+
width=model_input.shape[3],
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
latent_image_ids_to_concat = [latent_image_ids]
|
| 661 |
+
packed_cond_model_input_to_concat = []
|
| 662 |
+
|
| 663 |
+
if args.kontext == "enable":
|
| 664 |
+
source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
|
| 665 |
+
source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
|
| 666 |
+
source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
|
| 667 |
+
image_latent_h, image_latent_w = source_image_latents.shape[2:]
|
| 668 |
+
packed_image_latents = FluxKontextControlPipeline._pack_latents(
|
| 669 |
+
source_image_latents,
|
| 670 |
+
batch_size=source_image_latents.shape[0],
|
| 671 |
+
num_channels_latents=source_image_latents.shape[1],
|
| 672 |
+
height=image_latent_h,
|
| 673 |
+
width=image_latent_w,
|
| 674 |
+
)
|
| 675 |
+
source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
|
| 676 |
+
batch_size=source_image_latents.shape[0],
|
| 677 |
+
height=image_latent_h // 2,
|
| 678 |
+
width=image_latent_w // 2,
|
| 679 |
+
device=accelerator.device,
|
| 680 |
+
dtype=weight_dtype,
|
| 681 |
+
)
|
| 682 |
+
source_image_ids[..., 0] = 1 # Mark as condition
|
| 683 |
+
latent_image_ids_to_concat.append(source_image_ids)
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
subject_pixel_values = batch.get("subject_pixel_values")
|
| 687 |
+
if subject_pixel_values is not None:
|
| 688 |
+
subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
|
| 689 |
+
subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
|
| 690 |
+
subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 691 |
+
subject_input = subject_input.to(dtype=weight_dtype)
|
| 692 |
+
sub_number = subject_pixel_values.shape[-2] // args.cond_size
|
| 693 |
+
latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
|
| 694 |
+
latent_subject_ids[..., 0] = 2
|
| 695 |
+
latent_subject_ids[:, 1] += offset
|
| 696 |
+
sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
|
| 697 |
+
latent_image_ids_to_concat.append(sub_latent_image_ids)
|
| 698 |
+
|
| 699 |
+
packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
|
| 700 |
+
subject_input,
|
| 701 |
+
batch_size=subject_input.shape[0],
|
| 702 |
+
num_channels_latents=subject_input.shape[1],
|
| 703 |
+
height=subject_input.shape[2],
|
| 704 |
+
width=subject_input.shape[3],
|
| 705 |
+
)
|
| 706 |
+
packed_cond_model_input_to_concat.append(packed_subject_model_input)
|
| 707 |
+
|
| 708 |
+
cond_pixel_values = batch.get("cond_pixel_values")
|
| 709 |
+
if cond_pixel_values is not None:
|
| 710 |
+
cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
|
| 711 |
+
cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
|
| 712 |
+
cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 713 |
+
cond_input = cond_input.to(dtype=weight_dtype)
|
| 714 |
+
cond_number = cond_pixel_values.shape[-2] // args.cond_size
|
| 715 |
+
cond_latent_image_ids[..., 0] = 2
|
| 716 |
+
cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
|
| 717 |
+
latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
|
| 718 |
+
|
| 719 |
+
packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
|
| 720 |
+
cond_input,
|
| 721 |
+
batch_size=cond_input.shape[0],
|
| 722 |
+
num_channels_latents=cond_input.shape[1],
|
| 723 |
+
height=cond_input.shape[2],
|
| 724 |
+
width=cond_input.shape[3],
|
| 725 |
+
)
|
| 726 |
+
packed_cond_model_input_to_concat.append(packed_cond_model_input)
|
| 727 |
+
|
| 728 |
+
latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
|
| 729 |
+
cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
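# At this point the first channel of the position ids distinguishes the token groups: the noisy
# target latents keep their default value, the Kontext source image is marked with 1, and the
# subject/spatial condition latents with 2 (subject ids are additionally shifted by `offset`).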
|
| 730 |
+
|
| 731 |
+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
| 732 |
+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
| 733 |
+
guidance = guidance.expand(model_input.shape[0])
|
| 734 |
+
else:
|
| 735 |
+
guidance = None
|
| 736 |
+
|
| 737 |
+
latent_model_input = packed_noisy_model_input
|
| 738 |
+
if args.kontext == "enable":
|
| 739 |
+
latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
|
| 740 |
+
model_pred = transformer(
|
| 741 |
+
hidden_states=latent_model_input,
|
| 742 |
+
cond_hidden_states=cond_packed_noisy_model_input,
|
| 743 |
+
timestep=timesteps / 1000,
|
| 744 |
+
guidance=guidance,
|
| 745 |
+
pooled_projections=pooled_prompt_embeds,
|
| 746 |
+
encoder_hidden_states=prompt_embeds,
|
| 747 |
+
txt_ids=text_ids,
|
| 748 |
+
img_ids=latent_image_ids,
|
| 749 |
+
return_dict=False,
|
| 750 |
+
)[0]
|
| 751 |
+
|
| 752 |
+
model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
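# The transformer predicts over the full token sequence (noisy latents plus the appended Kontext
# image tokens when enabled); only the leading noisy-latent tokens are kept for the loss.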
|
| 753 |
+
|
| 754 |
+
model_pred = FluxKontextControlPipeline._unpack_latents(
|
| 755 |
+
model_pred,
|
| 756 |
+
height=int(pixel_values.shape[-2]),
|
| 757 |
+
width=int(pixel_values.shape[-1]),
|
| 758 |
+
vae_scale_factor=vae_scale_factor,
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 762 |
+
target = noise - model_input
|
| 763 |
+
|
| 764 |
+
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
|
| 765 |
+
loss = loss.mean()
|
| 766 |
+
accelerator.backward(loss)
|
| 767 |
+
if accelerator.sync_gradients:
|
| 768 |
+
params_to_clip = (transformer.parameters())
|
| 769 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 770 |
+
|
| 771 |
+
optimizer.step()
|
| 772 |
+
lr_scheduler.step()
|
| 773 |
+
optimizer.zero_grad()
|
| 774 |
+
|
| 775 |
+
if accelerator.sync_gradients:
|
| 776 |
+
progress_bar.update(1)
|
| 777 |
+
global_step += 1
|
| 778 |
+
|
| 779 |
+
if accelerator.is_main_process:
|
| 780 |
+
if global_step % args.checkpointing_steps == 0:
|
| 781 |
+
if args.checkpoints_total_limit is not None:
|
| 782 |
+
checkpoints = os.listdir(args.output_dir)
|
| 783 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 784 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 785 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 786 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 787 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 788 |
+
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
|
| 789 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 790 |
+
for removing_checkpoint in removing_checkpoints:
|
| 791 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 792 |
+
shutil.rmtree(removing_checkpoint)
|
| 793 |
+
|
| 794 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 795 |
+
os.makedirs(save_path, exist_ok=True)
|
| 796 |
+
unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
|
| 797 |
+
lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
|
| 798 |
+
save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
|
| 799 |
+
logger.info(f"Saved state to {save_path}")
|
| 800 |
+
|
| 801 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 802 |
+
progress_bar.set_postfix(**logs)
|
| 803 |
+
accelerator.log(logs, step=global_step)
|
| 804 |
+
|
| 805 |
+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
| 806 |
+
# Create pipeline on every rank to run validation in parallel
|
| 807 |
+
pipeline = FluxKontextControlPipeline.from_pretrained(
|
| 808 |
+
args.pretrained_model_name_or_path,
|
| 809 |
+
vae=vae,
|
| 810 |
+
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
| 811 |
+
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
| 812 |
+
transformer=accelerator.unwrap_model(transformer),
|
| 813 |
+
revision=args.revision,
|
| 814 |
+
variant=args.variant,
|
| 815 |
+
torch_dtype=weight_dtype,
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
|
| 819 |
+
spatial_paths = args.spatial_test_images
|
| 820 |
+
else:
|
| 821 |
+
spatial_paths = []
|
| 822 |
+
|
| 823 |
+
pipeline_args = {
|
| 824 |
+
"prompt": args.validation_prompt,
|
| 825 |
+
"cond_size": args.cond_size,
|
| 826 |
+
"guidance_scale": 3.5,
|
| 827 |
+
"num_inference_steps": 20,
|
| 828 |
+
"max_sequence_length": 128,
|
| 829 |
+
"control_dict": {"spatial_images": spatial_paths},
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
images = log_validation(
|
| 833 |
+
pipeline=pipeline,
|
| 834 |
+
args=args,
|
| 835 |
+
accelerator=accelerator,
|
| 836 |
+
pipeline_args=pipeline_args,
|
| 837 |
+
step=global_step,
|
| 838 |
+
torch_dtype=weight_dtype,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
# Only main process saves/logs
|
| 842 |
+
if accelerator.is_main_process:
|
| 843 |
+
save_path = os.path.join(args.output_dir, "validation")
|
| 844 |
+
os.makedirs(save_path, exist_ok=True)
|
| 845 |
+
save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
|
| 846 |
+
os.makedirs(save_folder, exist_ok=True)
|
| 847 |
+
for idx, img in enumerate(images):
|
| 848 |
+
img.save(os.path.join(save_folder, f"{idx}.jpg"))
|
| 849 |
+
del pipeline
|
| 850 |
+
|
| 851 |
+
accelerator.wait_for_everyone()
|
| 852 |
+
accelerator.end_training()
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
if __name__ == "__main__":
|
| 856 |
+
args = parse_args()
|
| 857 |
+
main(args)
|
| 858 |
+
|
train/train_kontext_color.sh
ADDED
|
@@ -0,0 +1,25 @@
export MODEL_DIR="" # your flux path
export OUTPUT_DIR="" # your save path
export CONFIG="./default_config.yaml"
export TRAIN_DATA="" # your data jsonl file
export LOG_PATH="$OUTPUT_DIR/log"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_color.py \
    --pretrained_model_name_or_path $MODEL_DIR \
    --lora_num=1 \
    --cond_size=512 \
    --ranks 128 \
    --network_alphas 128 \
    --output_dir=$OUTPUT_DIR \
    --logging_dir=$LOG_PATH \
    --mixed_precision="bf16" \
    --train_data_dir=$TRAIN_DATA \
    --learning_rate=1e-4 \
    --train_batch_size=1 \
    --num_train_epochs=1 \
    --validation_steps=100 \
    --checkpointing_steps=1000 \
    --validation_images "./kontext_color_test/img_1.png" \
    --spatial_test_images "./kontext_color_test/color_1.png" \
    --validation_prompt "Let this woman have red purple and blue hair" \
    --num_validation_images=1
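A hypothetical invocation, assuming the empty MODEL_DIR, OUTPUT_DIR and TRAIN_DATA values above are filled in and the host exposes the eight GPUs listed in CUDA_VISIBLE_DEVICES: run `bash train_kontext_color.sh` from the train/ directory.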
train/train_kontext_complete_lora.sh
ADDED
|
@@ -0,0 +1,20 @@
export MODEL_DIR="" # your flux path
export OUTPUT_DIR="" # your save path
export CONFIG="./default_config.yaml"
export LOG_PATH="$OUTPUT_DIR/log"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_lora.py \
    --train_data_jsonl "" \
    --pretrained_model_name_or_path $MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --logging_dir=$LOG_PATH \
    --mixed_precision="bf16" \
    --learning_rate=1e-4 \
    --train_batch_size=1 \
    --num_train_epochs=5 \
    --validation_steps=100 \
    --checkpointing_steps=500 \
    --validation_images "./kontext_complete_test/img_1.png" \
    --validation_prompt "" \
    --gradient_checkpointing \
    --num_validation_images=1
train/train_kontext_edge.py
ADDED
|
@@ -0,0 +1,814 @@
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
from safetensors.torch import save_file
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.utils.checkpoint
|
| 16 |
+
import transformers
|
| 17 |
+
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
from accelerate.logging import get_logger
|
| 20 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 21 |
+
|
| 22 |
+
import diffusers
|
| 23 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 24 |
+
from diffusers.optimization import get_scheduler
|
| 25 |
+
from diffusers.training_utils import (
|
| 26 |
+
cast_training_params,
|
| 27 |
+
compute_density_for_timestep_sampling,
|
| 28 |
+
compute_loss_weighting_for_sd3,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 31 |
+
from diffusers.utils import (
|
| 32 |
+
check_min_version,
|
| 33 |
+
is_wandb_available,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from src.prompt_helper import *
|
| 37 |
+
from src.lora_helper import *
|
| 38 |
+
from src.jsonl_datasets_kontext_edge import make_train_dataset_inpaint_mask, collate_fn
|
| 39 |
+
from src.pipeline_flux_kontext_control import (
|
| 40 |
+
FluxKontextControlPipeline,
|
| 41 |
+
resize_position_encoding,
|
| 42 |
+
prepare_latent_subject_ids,
|
| 43 |
+
PREFERRED_KONTEXT_RESOLUTIONS
|
| 44 |
+
)
|
| 45 |
+
from src.transformer_flux import FluxTransformer2DModel
|
| 46 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 47 |
+
from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
|
| 48 |
+
from tqdm.auto import tqdm
|
| 49 |
+
|
| 50 |
+
if is_wandb_available():
|
| 51 |
+
import wandb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 55 |
+
check_min_version("0.31.0.dev0")
|
| 56 |
+
|
| 57 |
+
logger = get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def log_validation(
|
| 61 |
+
pipeline,
|
| 62 |
+
args,
|
| 63 |
+
accelerator,
|
| 64 |
+
pipeline_args,
|
| 65 |
+
step,
|
| 66 |
+
torch_dtype,
|
| 67 |
+
is_final_validation=False,
|
| 68 |
+
):
|
| 69 |
+
logger.info(
|
| 70 |
+
f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
|
| 71 |
+
)
|
| 72 |
+
pipeline = pipeline.to(accelerator.device)
|
| 73 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 74 |
+
|
| 75 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 76 |
+
autocast_ctx = nullcontext()
|
| 77 |
+
|
| 78 |
+
# Build per-case evaluation: require equal lengths for image, spatial image, and prompt
|
| 79 |
+
if args.validation_images is None or args.validation_images == ['None']:
|
| 80 |
+
raise ValueError("validation_images must be provided and non-empty")
|
| 81 |
+
if args.validation_prompt is None:
|
| 82 |
+
raise ValueError("validation_prompt must be provided and non-empty")
|
| 83 |
+
|
| 84 |
+
control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
|
| 85 |
+
spatial_ls = control_dict_root.get("spatial_images", []) or []
|
| 86 |
+
|
| 87 |
+
val_imgs = args.validation_images
|
| 88 |
+
prompts = args.validation_prompt
|
| 89 |
+
|
| 90 |
+
if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
results = []
|
| 96 |
+
|
| 97 |
+
def _resize_to_preferred(img: Image.Image) -> Image.Image:
|
| 98 |
+
w, h = img.size
|
| 99 |
+
aspect_ratio = w / h if h != 0 else 1.0
|
| 100 |
+
_, target_w, target_h = min(
|
| 101 |
+
(abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
|
| 102 |
+
for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
|
| 103 |
+
)
|
| 104 |
+
return img.resize((target_w, target_h), Image.BICUBIC)
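# Selection minimizes the aspect-ratio gap: e.g. a 1000x750 input (ratio 1.33) is mapped to whichever
# (width, height) pair in PREFERRED_KONTEXT_RESOLUTIONS has the closest width/height ratio, then
# resized to that size with bicubic filtering.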
|
| 105 |
+
|
| 106 |
+
# Strict per-case loop
|
| 107 |
+
num_cases = len(prompts)
|
| 108 |
+
logger.info(f"Paired validation: {num_cases} (image, spatial, prompt) cases")
|
| 109 |
+
with autocast_ctx:
|
| 110 |
+
for idx in range(num_cases):
|
| 111 |
+
resized_img = None
|
| 112 |
+
# If validation image path is a non-empty string, load and resize; otherwise, skip passing image
|
| 113 |
+
if isinstance(val_imgs[idx], str) and val_imgs[idx] != "":
|
| 114 |
+
try:
|
| 115 |
+
base_img = Image.open(val_imgs[idx]).convert("RGB")
|
| 116 |
+
resized_img = _resize_to_preferred(base_img)
|
| 117 |
+
except Exception as e:
|
| 118 |
+
raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
|
| 119 |
+
|
| 120 |
+
case_args = dict(pipeline_args) if pipeline_args is not None else {}
|
| 121 |
+
case_args.pop("height", None)
|
| 122 |
+
case_args.pop("width", None)
|
| 123 |
+
if resized_img is not None:
|
| 124 |
+
tw, th = resized_img.size
|
| 125 |
+
case_args["height"] = th
|
| 126 |
+
case_args["width"] = tw
|
| 127 |
+
else:
|
| 128 |
+
# When no image is provided, default to 1024x1024
|
| 129 |
+
case_args["height"] = 1024
|
| 130 |
+
case_args["width"] = 1024
|
| 131 |
+
|
| 132 |
+
# Bind single spatial control image per case; pass it directly (no masking)
|
| 133 |
+
case_control = dict(case_args.get("control_dict", {}))
|
| 134 |
+
spatial_case = spatial_ls[idx]
|
| 135 |
+
|
| 136 |
+
# Load spatial image if it's a path; else assume it's already an image
|
| 137 |
+
try:
|
| 138 |
+
spatial_img = Image.open(spatial_case).convert("RGB") if isinstance(spatial_case, str) else spatial_case
|
| 139 |
+
except Exception:
|
| 140 |
+
spatial_img = spatial_case
|
| 141 |
+
|
| 142 |
+
case_control["spatial_images"] = [spatial_img]
|
| 143 |
+
case_control["subject_images"] = []
|
| 144 |
+
case_args["control_dict"] = case_control
|
| 145 |
+
|
| 146 |
+
# Override prompt per case
|
| 147 |
+
case_args["prompt"] = prompts[idx]
|
| 148 |
+
|
| 149 |
+
if resized_img is not None:
|
| 150 |
+
img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
|
| 151 |
+
else:
|
| 152 |
+
img = pipeline(**case_args, generator=generator).images[0]
|
| 153 |
+
results.append(img)
|
| 154 |
+
|
| 155 |
+
# Log results (resize to 1024x1024 for logging only)
|
| 156 |
+
resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
|
| 157 |
+
for tracker in accelerator.trackers:
|
| 158 |
+
phase_name = "test" if is_final_validation else "validation"
|
| 159 |
+
if tracker.name == "tensorboard":
|
| 160 |
+
np_images = np.stack([np.asarray(img) for img in resized_for_log])
|
| 161 |
+
tracker.writer.add_images(phase_name, np_images, step, dataformats="NHWC")
|
| 162 |
+
if tracker.name == "wandb":
|
| 163 |
+
tracker.log({
|
| 164 |
+
phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
|
| 165 |
+
})
|
| 166 |
+
|
| 167 |
+
del pipeline
|
| 168 |
+
if torch.cuda.is_available():
|
| 169 |
+
torch.cuda.empty_cache()
|
| 170 |
+
|
| 171 |
+
return results
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
|
| 175 |
+
text_encoder_config = transformers.PretrainedConfig.from_pretrained(
|
| 176 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 177 |
+
)
|
| 178 |
+
model_class = text_encoder_config.architectures[0]
|
| 179 |
+
if model_class == "CLIPTextModel":
|
| 180 |
+
from transformers import CLIPTextModel
|
| 181 |
+
|
| 182 |
+
return CLIPTextModel
|
| 183 |
+
elif model_class == "T5EncoderModel":
|
| 184 |
+
from transformers import T5EncoderModel
|
| 185 |
+
|
| 186 |
+
return T5EncoderModel
|
| 187 |
+
else:
|
| 188 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def parse_args(input_args=None):
|
| 192 |
+
parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
|
| 193 |
+
parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
|
| 194 |
+
parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
|
| 195 |
+
parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
|
| 196 |
+
|
| 197 |
+
parser.add_argument("--train_data_dir", type=str, default="", help="Path to JSONL dataset.")
|
| 198 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
|
| 199 |
+
parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
|
| 200 |
+
parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
|
| 201 |
+
parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
|
| 202 |
+
|
| 203 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
| 204 |
+
parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
|
| 205 |
+
parser.add_argument("--kontext", type=str, default="disable")
|
| 206 |
+
parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
|
| 207 |
+
parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
|
| 208 |
+
parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
|
| 209 |
+
parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
|
| 210 |
+
parser.add_argument("--num_validation_images", type=int, default=4)
|
| 211 |
+
parser.add_argument("--validation_steps", type=int, default=20)
|
| 212 |
+
|
| 213 |
+
parser.add_argument("--ranks", type=int, nargs="+", default=[128], help="LoRA ranks")
|
| 214 |
+
parser.add_argument("--network_alphas", type=int, nargs="+", default=[128], help="LoRA network alphas")
|
| 215 |
+
parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
|
| 216 |
+
parser.add_argument("--seed", type=int, default=None)
|
| 217 |
+
parser.add_argument("--train_batch_size", type=int, default=1)
|
| 218 |
+
parser.add_argument("--num_train_epochs", type=int, default=50)
|
| 219 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
| 220 |
+
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
| 221 |
+
parser.add_argument("--checkpoints_total_limit", type=int, default=None)
|
| 222 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
| 223 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
| 224 |
+
parser.add_argument("--gradient_checkpointing", action="store_true")
|
| 225 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
| 226 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
|
| 227 |
+
parser.add_argument("--scale_lr", action="store_true", default=False)
|
| 228 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant")
|
| 229 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
| 230 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1)
|
| 231 |
+
parser.add_argument("--lr_power", type=float, default=1.0)
|
| 232 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=1)
|
| 233 |
+
parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
|
| 234 |
+
parser.add_argument("--logit_mean", type=float, default=0.0)
|
| 235 |
+
parser.add_argument("--logit_std", type=float, default=1.0)
|
| 236 |
+
parser.add_argument("--mode_scale", type=float, default=1.29)
|
| 237 |
+
parser.add_argument("--optimizer", type=str, default="AdamW")
|
| 238 |
+
parser.add_argument("--use_8bit_adam", action="store_true")
|
| 239 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
| 240 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 241 |
+
parser.add_argument("--prodigy_beta3", type=float, default=None)
|
| 242 |
+
parser.add_argument("--prodigy_decouple", type=bool, default=True)
|
| 243 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
|
| 244 |
+
parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
|
| 245 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 246 |
+
parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
|
| 247 |
+
parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
|
| 248 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 249 |
+
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 250 |
+
parser.add_argument("--cache_latents", action="store_true", default=False)
|
| 251 |
+
parser.add_argument("--report_to", type=str, default="tensorboard")
|
| 252 |
+
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
|
| 253 |
+
parser.add_argument("--upcast_before_saving", action="store_true", default=False)
|
| 254 |
+
|
| 255 |
+
if input_args is not None:
|
| 256 |
+
args = parser.parse_args(input_args)
|
| 257 |
+
else:
|
| 258 |
+
args = parser.parse_args()
|
| 259 |
+
return args
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def main(args):
|
| 263 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 264 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 265 |
+
|
| 266 |
+
if args.output_dir is not None:
|
| 267 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 268 |
+
os.makedirs(args.logging_dir, exist_ok=True)
|
| 269 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 270 |
+
|
| 271 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 272 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 273 |
+
accelerator = Accelerator(
|
| 274 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 275 |
+
mixed_precision=args.mixed_precision,
|
| 276 |
+
log_with=args.report_to,
|
| 277 |
+
project_config=accelerator_project_config,
|
| 278 |
+
kwargs_handlers=[kwargs],
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
if torch.backends.mps.is_available():
|
| 282 |
+
accelerator.native_amp = False
|
| 283 |
+
|
| 284 |
+
if args.report_to == "wandb":
|
| 285 |
+
if not is_wandb_available():
|
| 286 |
+
raise ImportError("Install wandb for logging during training.")
|
| 287 |
+
|
| 288 |
+
logging.basicConfig(
|
| 289 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 290 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 291 |
+
level=logging.INFO,
|
| 292 |
+
)
|
| 293 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 294 |
+
if accelerator.is_local_main_process:
|
| 295 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 296 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 297 |
+
else:
|
| 298 |
+
transformers.utils.logging.set_verbosity_error()
|
| 299 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 300 |
+
|
| 301 |
+
if args.seed is not None:
|
| 302 |
+
set_seed(args.seed)
|
| 303 |
+
|
| 304 |
+
if accelerator.is_main_process and args.output_dir is not None:
|
| 305 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 306 |
+
|
| 307 |
+
# Tokenizers
|
| 308 |
+
tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
|
| 309 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
| 310 |
+
)
|
| 311 |
+
tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
|
| 312 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# Text encoders
|
| 316 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
|
| 317 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
|
| 318 |
+
|
| 319 |
+
# Scheduler and models
|
| 320 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 321 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 322 |
+
text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
|
| 323 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
|
| 324 |
+
transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
|
| 325 |
+
|
| 326 |
+
# Train only LoRA adapters
|
| 327 |
+
transformer.requires_grad_(True)
|
| 328 |
+
vae.requires_grad_(False)
|
| 329 |
+
text_encoder_one.requires_grad_(False)
|
| 330 |
+
text_encoder_two.requires_grad_(False)
|
| 331 |
+
|
| 332 |
+
weight_dtype = torch.float32
|
| 333 |
+
if accelerator.mixed_precision == "fp16":
|
| 334 |
+
weight_dtype = torch.float16
|
| 335 |
+
elif accelerator.mixed_precision == "bf16":
|
| 336 |
+
weight_dtype = torch.bfloat16
|
| 337 |
+
|
| 338 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 339 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 340 |
+
|
| 341 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 342 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 343 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 344 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 345 |
+
|
| 346 |
+
if args.gradient_checkpointing:
|
| 347 |
+
transformer.enable_gradient_checkpointing()
|
| 348 |
+
|
| 349 |
+
# Setup LoRA attention processors
|
| 350 |
+
if args.pretrained_lora_path is not None:
|
| 351 |
+
lora_path = args.pretrained_lora_path
|
| 352 |
+
checkpoint = load_checkpoint(lora_path)
|
| 353 |
+
lora_attn_procs = {}
|
| 354 |
+
double_blocks_idx = list(range(19))
|
| 355 |
+
single_blocks_idx = list(range(38))
|
| 356 |
+
number = 1  # only the first LoRA slot is initialized from the checkpoint weights below
|
| 357 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 358 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 359 |
+
if match:
|
| 360 |
+
layer_index = int(match.group(1))
|
| 361 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 362 |
+
lora_state_dicts = {}
|
| 363 |
+
for key, value in checkpoint.items():
|
| 364 |
+
if re.search(r'\.(\d+)\.', key):
|
| 365 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 366 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 367 |
+
lora_state_dicts[key] = value
|
| 368 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 369 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 370 |
+
)
|
| 371 |
+
for n in range(number):
|
| 372 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 373 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 374 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 375 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 376 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 377 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 378 |
+
lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 379 |
+
lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 380 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 381 |
+
lora_state_dicts = {}
|
| 382 |
+
for key, value in checkpoint.items():
|
| 383 |
+
if re.search(r'\.(\d+)\.', key):
|
| 384 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 385 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 386 |
+
lora_state_dicts[key] = value
|
| 387 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 388 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 389 |
+
)
|
| 390 |
+
for n in range(number):
|
| 391 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 392 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 393 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 394 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 395 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 396 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 397 |
+
else:
|
| 398 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 399 |
+
else:
|
| 400 |
+
lora_attn_procs = {}
|
| 401 |
+
double_blocks_idx = list(range(19))
|
| 402 |
+
single_blocks_idx = list(range(38))
|
| 403 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 404 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 405 |
+
if match:
|
| 406 |
+
layer_index = int(match.group(1))
|
| 407 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 408 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 409 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 410 |
+
)
|
| 411 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 412 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 413 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
lora_attn_procs[name] = attn_processor
|
| 417 |
+
|
| 418 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 419 |
+
transformer.train()
|
| 420 |
+
for n, param in transformer.named_parameters():
|
| 421 |
+
if '_lora' not in n:
|
| 422 |
+
param.requires_grad = False
|
| 423 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
|
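Note: a minimal, self-contained sketch of the freeze-all-but-LoRA pattern used above, shown on a toy module; the "_lora" naming stands in for the q_loras/k_loras/v_loras/proj_loras parameters of the processors. Illustrative only, not part of the training script.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim=8, rank=2):
        super().__init__()
        self.proj = nn.Linear(dim, dim)                          # frozen base weight
        self.proj_lora_down = nn.Linear(dim, rank, bias=False)   # trainable LoRA factor
        self.proj_lora_up = nn.Linear(rank, dim, bias=False)     # trainable LoRA factor

block = ToyBlock()
for name, param in block.named_parameters():
    param.requires_grad = "_lora" in name    # same substring test as the loop above

trainable = sum(p.numel() for p in block.parameters() if p.requires_grad)
print(trainable / 1e6, "M trainable parameters")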
| 424 |
+
|
| 425 |
+
def unwrap_model(model):
|
| 426 |
+
model = accelerator.unwrap_model(model)
|
| 427 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 428 |
+
return model
|
| 429 |
+
|
| 430 |
+
if args.resume_from_checkpoint:
|
| 431 |
+
path = args.resume_from_checkpoint
|
| 432 |
+
global_step = int(path.split("-")[-1])
|
| 433 |
+
initial_global_step = global_step
|
| 434 |
+
else:
|
| 435 |
+
initial_global_step = 0
|
| 436 |
+
global_step = 0
|
| 437 |
+
first_epoch = 0
|
| 438 |
+
|
| 439 |
+
if args.scale_lr:
|
| 440 |
+
args.learning_rate = (
|
| 441 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if args.mixed_precision == "fp16":
|
| 445 |
+
models = [transformer]
|
| 446 |
+
cast_training_params(models, dtype=torch.float32)
|
| 447 |
+
|
| 448 |
+
params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
|
| 449 |
+
transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
|
| 450 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
|
| 451 |
+
|
| 452 |
+
optimizer_class = torch.optim.AdamW
|
| 453 |
+
optimizer = optimizer_class(
|
| 454 |
+
[transformer_parameters_with_lr],
|
| 455 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 456 |
+
weight_decay=args.adam_weight_decay,
|
| 457 |
+
eps=args.adam_epsilon,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
tokenizers = [tokenizer_one, tokenizer_two]
|
| 461 |
+
text_encoders = [text_encoder_one, text_encoder_two]
|
| 462 |
+
|
| 463 |
+
train_dataset = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
|
| 464 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 465 |
+
train_dataset,
|
| 466 |
+
batch_size=args.train_batch_size,
|
| 467 |
+
shuffle=True,
|
| 468 |
+
collate_fn=collate_fn,
|
| 469 |
+
num_workers=args.dataloader_num_workers,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 473 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 474 |
+
|
| 475 |
+
overrode_max_train_steps = False
|
| 476 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 477 |
+
if args.resume_from_checkpoint:
|
| 478 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 479 |
+
if args.max_train_steps is None:
|
| 480 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 481 |
+
overrode_max_train_steps = True
|
| 482 |
+
|
| 483 |
+
lr_scheduler = get_scheduler(
|
| 484 |
+
args.lr_scheduler,
|
| 485 |
+
optimizer=optimizer,
|
| 486 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 487 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 488 |
+
num_cycles=args.lr_num_cycles,
|
| 489 |
+
power=args.lr_power,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 493 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 497 |
+
if overrode_max_train_steps:
|
| 498 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 499 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 500 |
+
|
| 501 |
+
# Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
|
| 502 |
+
def _sanitize_hparams(config_dict):
|
| 503 |
+
sanitized = {}
|
| 504 |
+
for key, value in dict(config_dict).items():
|
| 505 |
+
try:
|
| 506 |
+
if value is None:
|
| 507 |
+
continue
|
| 508 |
+
# numpy scalar types
|
| 509 |
+
if isinstance(value, (np.integer,)):
|
| 510 |
+
sanitized[key] = int(value)
|
| 511 |
+
elif isinstance(value, (np.floating,)):
|
| 512 |
+
sanitized[key] = float(value)
|
| 513 |
+
elif isinstance(value, (int, float, bool, str)):
|
| 514 |
+
sanitized[key] = value
|
| 515 |
+
elif isinstance(value, Path):
|
| 516 |
+
sanitized[key] = str(value)
|
| 517 |
+
elif isinstance(value, (list, tuple)):
|
| 518 |
+
# stringify simple sequences; skip if fails
|
| 519 |
+
sanitized[key] = str(value)
|
| 520 |
+
else:
|
| 521 |
+
# best-effort stringify
|
| 522 |
+
sanitized[key] = str(value)
|
| 523 |
+
except Exception:
|
| 524 |
+
# skip unconvertible entries
|
| 525 |
+
continue
|
| 526 |
+
return sanitized
|
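Note: an illustrative call to the sanitizer above (the expected output follows directly from its branches: numbers pass through, Path and list values are stringified, None entries are dropped).

example = {"learning_rate": 1e-4, "ranks": [128], "output_dir": Path("/tmp/run"), "resume_from_checkpoint": None}
print(_sanitize_hparams(example))
# -> {'learning_rate': 0.0001, 'ranks': '[128]', 'output_dir': '/tmp/run'}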
| 527 |
+
|
| 528 |
+
if accelerator.is_main_process:
|
| 529 |
+
tracker_name = "Easy_Control_Kontext"
|
| 530 |
+
accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
|
| 531 |
+
|
| 532 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 533 |
+
logger.info("***** Running training *****")
|
| 534 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 535 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 536 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 537 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 538 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 539 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 540 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 541 |
+
|
| 542 |
+
progress_bar = tqdm(
|
| 543 |
+
range(0, args.max_train_steps),
|
| 544 |
+
initial=initial_global_step,
|
| 545 |
+
desc="Steps",
|
| 546 |
+
disable=not accelerator.is_local_main_process,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 550 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 551 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 552 |
+
timesteps = timesteps.to(accelerator.device)
|
| 553 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 554 |
+
sigma = sigmas[step_indices].flatten()
|
| 555 |
+
while len(sigma.shape) < n_dim:
|
| 556 |
+
sigma = sigma.unsqueeze(-1)
|
| 557 |
+
return sigma
|
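Note: a standalone sketch of what get_sigmas does, with made-up scheduler tables; for each sampled timestep it looks up the matching sigma and reshapes it to [B, 1, 1, 1] so it broadcasts over the latent tensor.

import torch
sigmas_table = torch.linspace(1.0, 0.0, 1000)   # stand-in for noise_scheduler_copy.sigmas
timestep_table = torch.arange(1000)             # stand-in for noise_scheduler_copy.timesteps
timesteps = torch.tensor([10, 500])
idx = [(timestep_table == t).nonzero().item() for t in timesteps]
sigma = sigmas_table[idx].reshape(-1, 1, 1, 1)  # broadcastable to [B, C, H, W]
print(sigma.shape)                              # torch.Size([2, 1, 1, 1])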
| 558 |
+
|
| 559 |
+
# Kontext specifics
|
| 560 |
+
vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
|
| 561 |
+
# Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
|
| 562 |
+
height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 563 |
+
width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 564 |
+
offset = 64
|
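Note: with the defaults above, the condition-branch sizes work out as follows (plain arithmetic on the expressions already shown).

cond_size = 512
vae_scale_factor = 8
height_cond = 2 * (cond_size // (vae_scale_factor * 2))   # 2 * (512 // 16) = 64 latent rows
width_cond = 2 * (cond_size // (vae_scale_factor * 2))    # 64 latent columns
packed_tokens = (height_cond // 2) * (width_cond // 2)    # 1024 tokens per condition image after packing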
| 565 |
+
|
| 566 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 567 |
+
transformer.train()
|
| 568 |
+
for step, batch in enumerate(train_dataloader):
|
| 569 |
+
models_to_accumulate = [transformer]
|
| 570 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 571 |
+
tokens = [batch["text_ids_1"], batch["text_ids_2"]]
|
| 572 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
|
| 573 |
+
prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 574 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 575 |
+
text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
|
| 576 |
+
|
| 577 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 578 |
+
height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
|
| 579 |
+
width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
|
| 580 |
+
|
| 581 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 582 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 583 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 584 |
+
|
| 585 |
+
latent_image_ids, cond_latent_image_ids = resize_position_encoding(
|
| 586 |
+
model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
noise = torch.randn_like(model_input)
|
| 590 |
+
bsz = model_input.shape[0]
|
| 591 |
+
|
| 592 |
+
u = compute_density_for_timestep_sampling(
|
| 593 |
+
weighting_scheme=args.weighting_scheme,
|
| 594 |
+
batch_size=bsz,
|
| 595 |
+
logit_mean=args.logit_mean,
|
| 596 |
+
logit_std=args.logit_std,
|
| 597 |
+
mode_scale=args.mode_scale,
|
| 598 |
+
)
|
| 599 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 600 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 601 |
+
|
| 602 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 603 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 604 |
+
|
| 605 |
+
packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
|
| 606 |
+
noisy_model_input,
|
| 607 |
+
batch_size=model_input.shape[0],
|
| 608 |
+
num_channels_latents=model_input.shape[1],
|
| 609 |
+
height=model_input.shape[2],
|
| 610 |
+
width=model_input.shape[3],
|
| 611 |
+
)
|
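Note: for intuition, _pack_latents folds every 2x2 latent patch into the channel dimension, turning [B, C, H, W] into a token sequence [B, (H/2)*(W/2), 4*C]. The sketch below is an independent approximation of that behaviour (an assumption about the pipeline internals, not a copy of them).

import torch

def pack_latents_sketch(latents):
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)                    # B, H/2, W/2, C, 2, 2
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

print(pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape)   # torch.Size([1, 1024, 64])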
| 612 |
+
|
| 613 |
+
latent_image_ids_to_concat = [latent_image_ids]
|
| 614 |
+
packed_cond_model_input_to_concat = []
|
| 615 |
+
|
| 616 |
+
if args.kontext == "enable":
|
| 617 |
+
source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
|
| 618 |
+
source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
|
| 619 |
+
source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
|
| 620 |
+
image_latent_h, image_latent_w = source_image_latents.shape[2:]
|
| 621 |
+
packed_image_latents = FluxKontextControlPipeline._pack_latents(
|
| 622 |
+
source_image_latents,
|
| 623 |
+
batch_size=source_image_latents.shape[0],
|
| 624 |
+
num_channels_latents=source_image_latents.shape[1],
|
| 625 |
+
height=image_latent_h,
|
| 626 |
+
width=image_latent_w,
|
| 627 |
+
)
|
| 628 |
+
source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
|
| 629 |
+
batch_size=source_image_latents.shape[0],
|
| 630 |
+
height=image_latent_h // 2,
|
| 631 |
+
width=image_latent_w // 2,
|
| 632 |
+
device=accelerator.device,
|
| 633 |
+
dtype=weight_dtype,
|
| 634 |
+
)
|
| 635 |
+
source_image_ids[..., 0] = 1 # Mark as condition
|
| 636 |
+
latent_image_ids_to_concat.append(source_image_ids)
|
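Note: recap of the position-id convention, inferred from the assignments in this loop (treat the exact semantics as an assumption rather than documented API).

# img_ids[..., 0] tags which stream a packed token belongs to:
#   0 -> noisy target latents (default from resize_position_encoding)
#   1 -> Kontext source-image latents (set just above)
#   2 -> subject / spatial condition latents (set further below)
# The remaining channels carry the token's 2D position within its grid.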
| 637 |
+
|
| 638 |
+
|
| 639 |
+
subject_pixel_values = batch.get("subject_pixel_values")
|
| 640 |
+
if subject_pixel_values is not None:
|
| 641 |
+
subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
|
| 642 |
+
subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
|
| 643 |
+
subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 644 |
+
subject_input = subject_input.to(dtype=weight_dtype)
|
| 645 |
+
sub_number = subject_pixel_values.shape[-2] // args.cond_size
|
| 646 |
+
latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
|
| 647 |
+
latent_subject_ids[..., 0] = 2
|
| 648 |
+
latent_subject_ids[:, 1] += offset
|
| 649 |
+
sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
|
| 650 |
+
latent_image_ids_to_concat.append(sub_latent_image_ids)
|
| 651 |
+
|
| 652 |
+
packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
|
| 653 |
+
subject_input,
|
| 654 |
+
batch_size=subject_input.shape[0],
|
| 655 |
+
num_channels_latents=subject_input.shape[1],
|
| 656 |
+
height=subject_input.shape[2],
|
| 657 |
+
width=subject_input.shape[3],
|
| 658 |
+
)
|
| 659 |
+
packed_cond_model_input_to_concat.append(packed_subject_model_input)
|
| 660 |
+
|
| 661 |
+
cond_pixel_values = batch.get("cond_pixel_values")
|
| 662 |
+
if cond_pixel_values is not None:
|
| 663 |
+
cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
|
| 664 |
+
cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
|
| 665 |
+
cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 666 |
+
cond_input = cond_input.to(dtype=weight_dtype)
|
| 667 |
+
cond_number = cond_pixel_values.shape[-2] // args.cond_size
|
| 668 |
+
cond_latent_image_ids[..., 0] = 2
|
| 669 |
+
cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
|
| 670 |
+
latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
|
| 671 |
+
|
| 672 |
+
packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
|
| 673 |
+
cond_input,
|
| 674 |
+
batch_size=cond_input.shape[0],
|
| 675 |
+
num_channels_latents=cond_input.shape[1],
|
| 676 |
+
height=cond_input.shape[2],
|
| 677 |
+
width=cond_input.shape[3],
|
| 678 |
+
)
|
| 679 |
+
packed_cond_model_input_to_concat.append(packed_cond_model_input)
|
| 680 |
+
|
| 681 |
+
latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
|
| 682 |
+
cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
|
| 683 |
+
|
| 684 |
+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
| 685 |
+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
| 686 |
+
guidance = guidance.expand(model_input.shape[0])
|
| 687 |
+
else:
|
| 688 |
+
guidance = None
|
| 689 |
+
|
| 690 |
+
latent_model_input = packed_noisy_model_input
|
| 691 |
+
if args.kontext == "enable":
|
| 692 |
+
latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
|
| 693 |
+
model_pred = transformer(
|
| 694 |
+
hidden_states=latent_model_input,
|
| 695 |
+
cond_hidden_states=cond_packed_noisy_model_input,
|
| 696 |
+
timestep=timesteps / 1000,
|
| 697 |
+
guidance=guidance,
|
| 698 |
+
pooled_projections=pooled_prompt_embeds,
|
| 699 |
+
encoder_hidden_states=prompt_embeds,
|
| 700 |
+
txt_ids=text_ids,
|
| 701 |
+
img_ids=latent_image_ids,
|
| 702 |
+
return_dict=False,
|
| 703 |
+
)[0]
|
| 704 |
+
|
| 705 |
+
model_pred = model_pred[:, : packed_noisy_model_input.size(1)]  # keep only the noisy-target tokens; drop predictions for the appended Kontext image tokens
|
| 706 |
+
|
| 707 |
+
model_pred = FluxKontextControlPipeline._unpack_latents(
|
| 708 |
+
model_pred,
|
| 709 |
+
height=int(pixel_values.shape[-2]),
|
| 710 |
+
width=int(pixel_values.shape[-1]),
|
| 711 |
+
vae_scale_factor=vae_scale_factor,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 715 |
+
target = noise - model_input
|
| 716 |
+
|
| 717 |
+
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
|
| 718 |
+
loss = loss.mean()
|
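Note on the objective above (a paraphrase of the standard rectified-flow setup, not new behaviour): since the noisy input is x_t = (1 - sigma) * x_0 + sigma * noise, its derivative with respect to sigma is noise - x_0, which is exactly the target the transformer is regressed onto. A quick finite-difference check:

import torch
x0, noise = torch.randn(4), torch.randn(4)
xt = lambda s: (1.0 - s) * x0 + s * noise
velocity = (xt(0.31) - xt(0.30)) / 0.01          # numerical d x_t / d sigma
assert torch.allclose(velocity, noise - x0, atol=1e-4)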
| 719 |
+
accelerator.backward(loss)
|
| 720 |
+
if accelerator.sync_gradients:
|
| 721 |
+
params_to_clip = transformer.parameters()
|
| 722 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 723 |
+
|
| 724 |
+
optimizer.step()
|
| 725 |
+
lr_scheduler.step()
|
| 726 |
+
optimizer.zero_grad()
|
| 727 |
+
|
| 728 |
+
if accelerator.sync_gradients:
|
| 729 |
+
progress_bar.update(1)
|
| 730 |
+
global_step += 1
|
| 731 |
+
|
| 732 |
+
if accelerator.is_main_process:
|
| 733 |
+
if global_step % args.checkpointing_steps == 0:
|
| 734 |
+
if args.checkpoints_total_limit is not None:
|
| 735 |
+
checkpoints = os.listdir(args.output_dir)
|
| 736 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 737 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 738 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 739 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 740 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 741 |
+
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
|
| 742 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 743 |
+
for removing_checkpoint in removing_checkpoints:
|
| 744 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 745 |
+
shutil.rmtree(removing_checkpoint)
|
| 746 |
+
|
| 747 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 748 |
+
os.makedirs(save_path, exist_ok=True)
|
| 749 |
+
unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
|
| 750 |
+
lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
|
| 751 |
+
save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
|
| 752 |
+
logger.info(f"Saved state to {save_path}")
|
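Note: a hedged sketch (not part of this script) of reloading the LoRA-only state dict saved above for inspection; load_file comes from the safetensors package and the checkpoint path is hypothetical.

from safetensors.torch import load_file
lora_sd = load_file("checkpoint-1000/lora.safetensors")
print(len(lora_sd), "LoRA tensors,", sum(v.numel() for v in lora_sd.values()) / 1e6, "M params")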
| 753 |
+
|
| 754 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 755 |
+
progress_bar.set_postfix(**logs)
|
| 756 |
+
accelerator.log(logs, step=global_step)
|
| 757 |
+
|
| 758 |
+
if accelerator.is_main_process:
|
| 759 |
+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
| 760 |
+
pipeline = FluxKontextControlPipeline.from_pretrained(
|
| 761 |
+
args.pretrained_model_name_or_path,
|
| 762 |
+
vae=vae,
|
| 763 |
+
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
| 764 |
+
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
| 765 |
+
transformer=accelerator.unwrap_model(transformer),
|
| 766 |
+
revision=args.revision,
|
| 767 |
+
variant=args.variant,
|
| 768 |
+
torch_dtype=weight_dtype,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
if args.subject_test_images is not None and len(args.subject_test_images) != 0 and args.subject_test_images != ['None']:
|
| 772 |
+
subject_paths = args.subject_test_images
|
| 773 |
+
subject_ls = [Image.open(image_path).convert("RGB") for image_path in subject_paths]
|
| 774 |
+
else:
|
| 775 |
+
subject_ls = []
|
| 776 |
+
if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
|
| 777 |
+
spatial_paths = args.spatial_test_images
|
| 778 |
+
spatial_ls = [Image.open(image_path).convert("RGB") for image_path in spatial_paths]
|
| 779 |
+
else:
|
| 780 |
+
spatial_ls = []
|
| 781 |
+
|
| 782 |
+
pipeline_args = {
|
| 783 |
+
"prompt": args.validation_prompt,
|
| 784 |
+
"cond_size": args.cond_size,
|
| 785 |
+
"guidance_scale": 3.5,
|
| 786 |
+
"num_inference_steps": 20,
|
| 787 |
+
"max_sequence_length": 128,
|
| 788 |
+
"control_dict": {"spatial_images": spatial_ls, "subject_images": subject_ls},
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
images = log_validation(
|
| 792 |
+
pipeline=pipeline,
|
| 793 |
+
args=args,
|
| 794 |
+
accelerator=accelerator,
|
| 795 |
+
pipeline_args=pipeline_args,
|
| 796 |
+
step=global_step,
|
| 797 |
+
torch_dtype=weight_dtype,
|
| 798 |
+
)
|
| 799 |
+
save_path = os.path.join(args.output_dir, "validation")
|
| 800 |
+
os.makedirs(save_path, exist_ok=True)
|
| 801 |
+
save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
|
| 802 |
+
os.makedirs(save_folder, exist_ok=True)
|
| 803 |
+
for idx, img in enumerate(images):
|
| 804 |
+
img.save(os.path.join(save_folder, f"{idx}.jpg"))
|
| 805 |
+
del pipeline
|
| 806 |
+
|
| 807 |
+
accelerator.wait_for_everyone()
|
| 808 |
+
accelerator.end_training()
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
if __name__ == "__main__":
|
| 812 |
+
args = parse_args()
|
| 813 |
+
main(args)
|
| 814 |
+
|
train/train_kontext_edge.sh
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
export MODEL_DIR="/robby/share/Editing/lzc/FLUX.1-Kontext-dev" # your flux path
|
| 2 |
+
export OUTPUT_DIR="/robby/share/Editing/lzc/EasyControl_kontext_edge_test_hed" # your save path
|
| 3 |
+
export CONFIG="./default_config.yaml"
|
| 4 |
+
export TRAIN_DATA="/robby/share/MM/zkc/data/i2i_csv/pexel_Qwen2_5VL7BInstruct.csv" # your training data file (csv/jsonl)
|
| 5 |
+
export LOG_PATH="$OUTPUT_DIR/log"
|
| 6 |
+
|
| 7 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_edge.py \
|
| 8 |
+
--pretrained_model_name_or_path $MODEL_DIR \
|
| 9 |
+
--lora_num=1 \
|
| 10 |
+
--cond_size=512 \
|
| 11 |
+
--ranks 128 \
|
| 12 |
+
--network_alphas 128 \
|
| 13 |
+
--output_dir=$OUTPUT_DIR \
|
| 14 |
+
--logging_dir=$LOG_PATH \
|
| 15 |
+
--mixed_precision="bf16" \
|
| 16 |
+
--train_data_dir=$TRAIN_DATA \
|
| 17 |
+
--learning_rate=1e-4 \
|
| 18 |
+
--train_batch_size=1 \
|
| 19 |
+
--num_train_epochs=1 \
|
| 20 |
+
--validation_steps=500 \
|
| 21 |
+
--checkpointing_steps=1000 \
|
| 22 |
+
--validation_images "./kontext_edge_test/img_1.png" "./kontext_edge_test/img_2.png" "" "" "./kontext_edge_test/img_3.png" \
|
| 23 |
+
--spatial_test_images "./kontext_edge_test/edge_1.png" "./kontext_edge_test/edge_2.png" "./kontext_edge_test/edge_1.png" "./kontext_edge_test/edge_2.png" "./kontext_edge_test/edge_3.png" \
|
| 24 |
+
--validation_prompt "The cake was cut off a piece" "Let this black woman wearing a transparent sunglasses" "This image shows a beautifully decorated cake with golden-orange sides and white frosting on top, and a piece of cake is being cut. The cake is displayed on a rustic wooden slice that serves as a cake stand." "This is a striking portrait photograph featuring a person wearing an ornate golden crown and a heart-shape sunglasses. The subject has dramatic golden metallic eyeshadow that extends across their eyelids, complementing the warm tones of the crown." "move the cup to the left" \
|
| 25 |
+
--num_validation_images=1
|
train/train_kontext_interactive_lora.sh
ADDED
|
@@ -0,0 +1,18 @@
| 1 |
+
export MODEL_DIR="" # your flux path
|
| 2 |
+
export OUTPUT_DIR="" # your save path
|
| 3 |
+
export CONFIG="./default_config.yaml"
|
| 4 |
+
export LOG_PATH="$OUTPUT_DIR/log"
|
| 5 |
+
|
| 6 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_lora.py \
|
| 7 |
+
--pretrained_model_name_or_path $MODEL_DIR \
|
| 8 |
+
--output_dir=$OUTPUT_DIR \
|
| 9 |
+
--logging_dir=$LOG_PATH \
|
| 10 |
+
--mixed_precision="bf16" \
|
| 11 |
+
--learning_rate=1e-4 \
|
| 12 |
+
--train_batch_size=1 \
|
| 13 |
+
--num_train_epochs=10 \
|
| 14 |
+
--validation_steps=100 \
|
| 15 |
+
--checkpointing_steps=500 \
|
| 16 |
+
--validation_images "./kontext_interactive_test/img_1.png" \
|
| 17 |
+
--validation_prompt "Let the man hold the AK47 using both hands." \
|
| 18 |
+
--num_validation_images=1
|
train/train_kontext_local.py
ADDED
|
@@ -0,0 +1,876 @@
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
from safetensors.torch import save_file
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.utils.checkpoint
|
| 16 |
+
import transformers
|
| 17 |
+
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
from accelerate.logging import get_logger
|
| 20 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 21 |
+
|
| 22 |
+
import diffusers
|
| 23 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 24 |
+
from diffusers.optimization import get_scheduler
|
| 25 |
+
from diffusers.training_utils import (
|
| 26 |
+
cast_training_params,
|
| 27 |
+
compute_density_for_timestep_sampling,
|
| 28 |
+
compute_loss_weighting_for_sd3,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 31 |
+
from diffusers.utils import (
|
| 32 |
+
check_min_version,
|
| 33 |
+
is_wandb_available,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from src.prompt_helper import *
|
| 37 |
+
from src.lora_helper import *
|
| 38 |
+
from src.jsonl_datasets_kontext_local import make_train_dataset_mixed, collate_fn
|
| 39 |
+
from src.pipeline_flux_kontext_control import (
|
| 40 |
+
FluxKontextControlPipeline,
|
| 41 |
+
resize_position_encoding,
|
| 42 |
+
prepare_latent_subject_ids,
|
| 43 |
+
PREFERRED_KONTEXT_RESOLUTIONS
|
| 44 |
+
)
|
| 45 |
+
from src.transformer_flux import FluxTransformer2DModel
|
| 46 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
| 47 |
+
from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
|
| 48 |
+
from tqdm.auto import tqdm
|
| 49 |
+
|
| 50 |
+
if is_wandb_available():
|
| 51 |
+
import wandb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 55 |
+
check_min_version("0.31.0.dev0")
|
| 56 |
+
|
| 57 |
+
logger = get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def compute_background_preserving_loss(model_pred, target, mask_values, weighting, background_weight: float = 3.0):
|
| 61 |
+
"""
|
| 62 |
+
Compute loss with higher penalty on background (non-masked) regions to preserve them.
|
| 63 |
+
model_pred/target: [B, C, H, W]
|
| 64 |
+
mask_values: [B, 1, H_img, W_img] with values in {0,1} at image resolution
|
| 65 |
+
weighting: broadcastable to [B, C, H, W]
|
| 66 |
+
Returns per-pixel loss map [B, C, H, W]
|
| 67 |
+
"""
|
| 68 |
+
base_loss = (weighting.float() * (model_pred.float() - target.float()) ** 2)
|
| 69 |
+
mask_latent = torch.nn.functional.interpolate(
|
| 70 |
+
mask_values,
|
| 71 |
+
size=(model_pred.shape[2], model_pred.shape[3]),
|
| 72 |
+
mode='bilinear',
|
| 73 |
+
align_corners=False,
|
| 74 |
+
)
|
| 75 |
+
foreground_mask = mask_latent
|
| 76 |
+
background_mask = 1.0 - mask_latent
|
| 77 |
+
foreground_mask = foreground_mask.expand_as(base_loss)
|
| 78 |
+
background_mask = background_mask.expand_as(base_loss)
|
| 79 |
+
foreground_loss = base_loss * foreground_mask
|
| 80 |
+
background_loss = base_loss * background_mask * float(background_weight)
|
| 81 |
+
total_loss = foreground_loss + background_loss
|
| 82 |
+
return total_loss
|
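Note: a minimal usage sketch for the helper above, with made-up shapes (the real call sites pass latent-space predictions and an image-resolution mask).

import torch
B, C, H, W = 2, 16, 32, 32
model_pred = torch.randn(B, C, H, W)
target = torch.randn(B, C, H, W)
mask_values = (torch.rand(B, 1, 256, 256) > 0.5).float()   # 1 = edited region, 0 = background
weighting = torch.ones(B, 1, 1, 1)
loss_map = compute_background_preserving_loss(model_pred, target, mask_values, weighting, background_weight=3.0)
loss = loss_map.reshape(B, -1).mean(dim=1).mean()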
| 83 |
+
|
| 84 |
+
def log_validation(
|
| 85 |
+
pipeline,
|
| 86 |
+
args,
|
| 87 |
+
accelerator,
|
| 88 |
+
pipeline_args,
|
| 89 |
+
step,
|
| 90 |
+
torch_dtype,
|
| 91 |
+
is_final_validation=False,
|
| 92 |
+
):
|
| 93 |
+
logger.info(
|
| 94 |
+
f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
|
| 95 |
+
)
|
| 96 |
+
pipeline = pipeline.to(accelerator.device)
|
| 97 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 98 |
+
|
| 99 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 100 |
+
autocast_ctx = nullcontext()
|
| 101 |
+
|
| 102 |
+
# Build per-case evaluation: require equal lengths for image, spatial image, and prompt
|
| 103 |
+
if args.validation_images is None or args.validation_images == ['None']:
|
| 104 |
+
raise ValueError("validation_images must be provided and non-empty")
|
| 105 |
+
if args.validation_prompt is None:
|
| 106 |
+
raise ValueError("validation_prompt must be provided and non-empty")
|
| 107 |
+
|
| 108 |
+
control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
|
| 109 |
+
spatial_ls = control_dict_root.get("spatial_images", []) or []
|
| 110 |
+
|
| 111 |
+
val_imgs = args.validation_images
|
| 112 |
+
prompts = args.validation_prompt
|
| 113 |
+
|
| 114 |
+
if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
results = []
|
| 120 |
+
|
| 121 |
+
def _resize_to_preferred(img: Image.Image) -> Image.Image:
|
| 122 |
+
w, h = img.size
|
| 123 |
+
aspect_ratio = w / h if h != 0 else 1.0
|
| 124 |
+
_, target_w, target_h = min(
|
| 125 |
+
(abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
|
| 126 |
+
for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
|
| 127 |
+
)
|
| 128 |
+
return img.resize((target_w, target_h), Image.BICUBIC)
|
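Note: an illustrative call for the helper above; PREFERRED_KONTEXT_RESOLUTIONS is imported from the pipeline, and the input size is made up.

probe = Image.new("RGB", (1920, 1080))
print(_resize_to_preferred(probe).size)   # snaps to the preferred Kontext resolution closest to 16:9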
| 129 |
+
|
| 130 |
+
# Distributed per-rank assignment: each process handles its own slice of cases
|
| 131 |
+
num_cases = len(prompts)
|
| 132 |
+
logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
|
| 133 |
+
|
| 134 |
+
rank = accelerator.process_index
|
| 135 |
+
world_size = accelerator.num_processes
|
| 136 |
+
local_indices = list(range(rank, num_cases, world_size))
|
| 137 |
+
|
| 138 |
+
local_images = []
|
| 139 |
+
with autocast_ctx:
|
| 140 |
+
for idx in local_indices:
|
| 141 |
+
try:
|
| 142 |
+
base_img = Image.open(val_imgs[idx]).convert("RGB")
|
| 143 |
+
resized_img = _resize_to_preferred(base_img)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
|
| 146 |
+
|
| 147 |
+
case_args = dict(pipeline_args) if pipeline_args is not None else {}
|
| 148 |
+
case_args.pop("height", None)
|
| 149 |
+
case_args.pop("width", None)
|
| 150 |
+
if resized_img is not None:
|
| 151 |
+
tw, th = resized_img.size
|
| 152 |
+
case_args["height"] = th
|
| 153 |
+
case_args["width"] = tw
|
| 154 |
+
|
| 155 |
+
case_control = dict(case_args.get("control_dict", {}))
|
| 156 |
+
spatial_case = spatial_ls[idx]
|
| 157 |
+
|
| 158 |
+
# Compose masked image cond: resized_img * (1 - binary_mask)
|
| 159 |
+
try:
|
| 160 |
+
mask_img = Image.open(spatial_case).convert("L") if isinstance(spatial_case, str) else spatial_case.convert("L")
|
| 161 |
+
except Exception:
|
| 162 |
+
mask_img = spatial_case.convert("L")
|
| 163 |
+
mask_img = mask_img.resize(resized_img.size, Image.NEAREST)
|
| 164 |
+
mask_np = np.array(mask_img)
|
| 165 |
+
mask_bin = (mask_np > 127).astype(np.uint8)
|
| 166 |
+
inv_mask = (1 - mask_bin).astype(np.uint8)
|
| 167 |
+
base_np = np.array(resized_img)
|
| 168 |
+
masked_np = base_np * inv_mask[..., None]
|
| 169 |
+
masked_img = Image.fromarray(masked_np.astype(np.uint8))
|
| 170 |
+
|
| 171 |
+
case_control["spatial_images"] = [masked_img]
|
| 172 |
+
case_args["control_dict"] = case_control
|
| 173 |
+
|
| 174 |
+
case_args["prompt"] = prompts[idx]
|
| 175 |
+
img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
|
| 176 |
+
local_images.append(img)
|
| 177 |
+
|
| 178 |
+
# Gather one image per rank (pad missing ranks with black images) to main process
|
| 179 |
+
fixed_size = (1024, 1024)
|
| 180 |
+
has_sample = torch.tensor([1 if len(local_images) > 0 else 0], device=accelerator.device, dtype=torch.int)
|
| 181 |
+
local_idx = torch.tensor([local_indices[0] if len(local_indices) > 0 else -1], device=accelerator.device, dtype=torch.long)
|
| 182 |
+
if len(local_images) > 0:
|
| 183 |
+
gathered_img = local_images[0].resize(fixed_size, Image.BICUBIC)
|
| 184 |
+
img_np = np.asarray(gathered_img).astype(np.uint8)
|
| 185 |
+
else:
|
| 186 |
+
img_np = np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
|
| 187 |
+
img_tensor = torch.from_numpy(img_np).to(device=accelerator.device)
|
| 188 |
+
if img_tensor.ndim == 3:
|
| 189 |
+
img_tensor = img_tensor.unsqueeze(0)
|
| 190 |
+
|
| 191 |
+
gathered_has = accelerator.gather(has_sample)
|
| 192 |
+
gathered_idx = accelerator.gather(local_idx)
|
| 193 |
+
gathered_imgs = accelerator.gather(img_tensor)
|
| 194 |
+
|
| 195 |
+
if accelerator.is_main_process:
|
| 196 |
+
for i in range(int(gathered_has.shape[0])):
|
| 197 |
+
if int(gathered_has[i].item()) == 1:
|
| 198 |
+
idx = int(gathered_idx[i].item())
|
| 199 |
+
arr = gathered_imgs[i].cpu().numpy()
|
| 200 |
+
pil_img = Image.fromarray(arr.astype(np.uint8))
|
| 201 |
+
# Resize back to original validation image size
|
| 202 |
+
try:
|
| 203 |
+
orig = Image.open(val_imgs[idx]).convert("RGB")
|
| 204 |
+
pil_img = pil_img.resize(orig.size, Image.BICUBIC)
|
| 205 |
+
except Exception:
|
| 206 |
+
pass
|
| 207 |
+
results.append(pil_img)
|
| 208 |
+
|
| 209 |
+
del pipeline
|
| 210 |
+
if torch.cuda.is_available():
|
| 211 |
+
torch.cuda.empty_cache()
|
| 212 |
+
|
| 213 |
+
return results
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
|
| 217 |
+
text_encoder_config = transformers.PretrainedConfig.from_pretrained(
|
| 218 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 219 |
+
)
|
| 220 |
+
model_class = text_encoder_config.architectures[0]
|
| 221 |
+
if model_class == "CLIPTextModel":
|
| 222 |
+
from transformers import CLIPTextModel
|
| 223 |
+
|
| 224 |
+
return CLIPTextModel
|
| 225 |
+
elif model_class == "T5EncoderModel":
|
| 226 |
+
from transformers import T5EncoderModel
|
| 227 |
+
|
| 228 |
+
return T5EncoderModel
|
| 229 |
+
else:
|
| 230 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def parse_args(input_args=None):
|
| 234 |
+
parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
|
| 235 |
+
parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
|
| 236 |
+
parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
|
| 237 |
+
parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
|
| 238 |
+
|
| 239 |
+
# New dataset (local edits + inpaint JSONL) mixed 1:1
|
| 240 |
+
parser.add_argument("--local_edits_json", type=str, default="/robby/share/Editing/qingyan/InstructV2V/Qwen2_5_72B_instructs_10W.json", help="Path to local edits JSON")
|
| 241 |
+
parser.add_argument("--train_data_dir", type=str, default="/robby/share/Editing/lzc/data/pexel_final/inpaint_edit_outputs_merged.jsonl", help="Path to inpaint JSONL file for mixing 1:1")
|
| 242 |
+
parser.add_argument("--source_frames_dir", type=str, default="/robby/share/Editing/qingyan/InstructV2V/pexel-video-merged-1frame", help="Root dir containing group folders like 0139")
|
| 243 |
+
parser.add_argument("--target_frames_dir", type=str, default="/robby/share/Editing/qingyan/InstructV2V/pexel-video-1frame-kontext-edit/local", help="Root dir containing group folders like 0139")
|
| 244 |
+
parser.add_argument("--masks_dir", type=str, default="/robby/share/Editing/lzc/InstructV2V/diff_masks", help="Root dir of precomputed masks organized as <group>/<prefix>_{i}.png")
|
| 245 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
|
| 246 |
+
parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
|
| 247 |
+
parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
|
| 248 |
+
parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
|
| 249 |
+
|
| 250 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
| 251 |
+
parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
|
| 252 |
+
parser.add_argument("--kontext", type=str, default="enable")
|
| 253 |
+
parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
|
| 254 |
+
parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
|
| 255 |
+
parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
|
| 256 |
+
parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
|
| 257 |
+
parser.add_argument("--num_validation_images", type=int, default=4)
|
| 258 |
+
parser.add_argument("--validation_steps", type=int, default=20)
|
| 259 |
+
|
| 260 |
+
parser.add_argument("--ranks", type=int, nargs="+", default=[256], help="LoRA ranks")
|
| 261 |
+
parser.add_argument("--network_alphas", type=int, nargs="+", default=[256], help="LoRA network alphas")
|
| 262 |
+
parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
|
| 263 |
+
parser.add_argument("--seed", type=int, default=None)
|
| 264 |
+
parser.add_argument("--train_batch_size", type=int, default=1)
|
| 265 |
+
parser.add_argument("--num_train_epochs", type=int, default=50)
|
| 266 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
| 267 |
+
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
| 268 |
+
parser.add_argument("--checkpoints_total_limit", type=int, default=None)
|
| 269 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
| 270 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
| 271 |
+
parser.add_argument("--gradient_checkpointing", action="store_true")
|
| 272 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
| 273 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
|
| 274 |
+
parser.add_argument("--scale_lr", action="store_true", default=False)
|
| 275 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant")
|
| 276 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
| 277 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1)
|
| 278 |
+
parser.add_argument("--lr_power", type=float, default=1.0)
|
| 279 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=8)
|
| 280 |
+
parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
|
| 281 |
+
parser.add_argument("--logit_mean", type=float, default=0.0)
|
| 282 |
+
parser.add_argument("--logit_std", type=float, default=1.0)
|
| 283 |
+
parser.add_argument("--mode_scale", type=float, default=1.29)
|
| 284 |
+
parser.add_argument("--optimizer", type=str, default="AdamW")
|
| 285 |
+
parser.add_argument("--use_8bit_adam", action="store_true")
|
| 286 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
| 287 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 288 |
+
parser.add_argument("--prodigy_beta3", type=float, default=None)
|
| 289 |
+
parser.add_argument("--prodigy_decouple", type=bool, default=True)
|
| 290 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
|
| 291 |
+
parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
|
| 292 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 293 |
+
parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
|
| 294 |
+
parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
|
| 295 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 296 |
+
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 297 |
+
parser.add_argument("--cache_latents", action="store_true", default=False)
|
| 298 |
+
parser.add_argument("--report_to", type=str, default="tensorboard")
|
| 299 |
+
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
|
| 300 |
+
parser.add_argument("--upcast_before_saving", action="store_true", default=False)
|
| 301 |
+
parser.add_argument("--mix_ratio", type=float, default=0, help="Ratio of inpaint to local edits (B per A). 0=only local edits, 1=1:1, 2=1:2")
|
| 302 |
+
parser.add_argument("--background_weight", type=float, default=1.0, help="Background preserving loss weight multiplier")
|
| 303 |
+
|
| 304 |
+
# Blending options for dataset pixel_values
|
| 305 |
+
parser.add_argument("--blend_pixel_values", action="store_true", help="Blend target/source into pixel_values using mask")
|
| 306 |
+
parser.add_argument("--blend_kernel", type=int, default=21, help="Gaussian blur kernel size (must be odd)")
|
| 307 |
+
parser.add_argument("--blend_sigma", type=float, default=10.0, help="Gaussian blur sigma")
|
| 308 |
+
|
| 309 |
+
if input_args is not None:
|
| 310 |
+
args = parser.parse_args(input_args)
|
| 311 |
+
else:
|
| 312 |
+
args = parser.parse_args()
|
| 313 |
+
return args
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def main(args):
|
| 317 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 318 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 319 |
+
|
| 320 |
+
if args.output_dir is not None:
|
| 321 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 322 |
+
os.makedirs(args.logging_dir, exist_ok=True)
|
| 323 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 324 |
+
|
| 325 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 326 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 327 |
+
accelerator = Accelerator(
|
| 328 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 329 |
+
mixed_precision=args.mixed_precision,
|
| 330 |
+
log_with=args.report_to,
|
| 331 |
+
project_config=accelerator_project_config,
|
| 332 |
+
kwargs_handlers=[kwargs],
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if torch.backends.mps.is_available():
|
| 336 |
+
accelerator.native_amp = False
|
| 337 |
+
|
| 338 |
+
if args.report_to == "wandb":
|
| 339 |
+
if not is_wandb_available():
|
| 340 |
+
raise ImportError("Install wandb for logging during training.")
|
| 341 |
+
|
| 342 |
+
logging.basicConfig(
|
| 343 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 344 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 345 |
+
level=logging.INFO,
|
| 346 |
+
)
|
| 347 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 348 |
+
if accelerator.is_local_main_process:
|
| 349 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 350 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 351 |
+
else:
|
| 352 |
+
transformers.utils.logging.set_verbosity_error()
|
| 353 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 354 |
+
|
| 355 |
+
if args.seed is not None:
|
| 356 |
+
set_seed(args.seed)
|
| 357 |
+
|
| 358 |
+
if accelerator.is_main_process and args.output_dir is not None:
|
| 359 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 360 |
+
|
| 361 |
+
# Tokenizers
|
| 362 |
+
tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
|
| 363 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
| 364 |
+
)
|
| 365 |
+
tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
|
| 366 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Text encoders
|
| 370 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
|
| 371 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
|
| 372 |
+
|
| 373 |
+
# Scheduler and models
|
| 374 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 375 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 376 |
+
text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
|
| 377 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
|
| 378 |
+
transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
|
| 379 |
+
|
| 380 |
+
# Train only LoRA adapters
|
| 381 |
+
transformer.requires_grad_(True)
|
| 382 |
+
vae.requires_grad_(False)
|
| 383 |
+
text_encoder_one.requires_grad_(False)
|
| 384 |
+
text_encoder_two.requires_grad_(False)
|
| 385 |
+
|
| 386 |
+
weight_dtype = torch.float32
|
| 387 |
+
if accelerator.mixed_precision == "fp16":
|
| 388 |
+
weight_dtype = torch.float16
|
| 389 |
+
elif accelerator.mixed_precision == "bf16":
|
| 390 |
+
weight_dtype = torch.bfloat16
|
| 391 |
+
|
| 392 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 393 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 394 |
+
|
| 395 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 396 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 397 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 398 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 399 |
+
|
| 400 |
+
if args.gradient_checkpointing:
|
| 401 |
+
transformer.enable_gradient_checkpointing()
|
| 402 |
+
|
| 403 |
+
# Setup LoRA attention processors
|
| 404 |
+
if args.pretrained_lora_path is not None:
|
| 405 |
+
lora_path = args.pretrained_lora_path
|
| 406 |
+
checkpoint = load_checkpoint(lora_path)
|
| 407 |
+
lora_attn_procs = {}
|
| 408 |
+
double_blocks_idx = list(range(19))
|
| 409 |
+
single_blocks_idx = list(range(38))
|
| 410 |
+
number = 1
|
| 411 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 412 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 413 |
+
if match:
|
| 414 |
+
layer_index = int(match.group(1))
|
| 415 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 416 |
+
lora_state_dicts = {}
|
| 417 |
+
for key, value in checkpoint.items():
|
| 418 |
+
if re.search(r'\.(\d+)\.', key):
|
| 419 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 420 |
+
if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
|
| 421 |
+
lora_state_dicts[key] = value
|
| 422 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 423 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 424 |
+
)
|
| 425 |
+
for n in range(number):
|
| 426 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 427 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 428 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 429 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 430 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 431 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 432 |
+
lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
|
| 433 |
+
lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
|
| 434 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 435 |
+
lora_state_dicts = {}
|
| 436 |
+
for key, value in checkpoint.items():
|
| 437 |
+
if re.search(r'\.(\d+)\.', key):
|
| 438 |
+
checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
|
| 439 |
+
if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
|
| 440 |
+
lora_state_dicts[key] = value
|
| 441 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 442 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 443 |
+
)
|
| 444 |
+
for n in range(number):
|
| 445 |
+
lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
|
| 446 |
+
lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
|
| 447 |
+
lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
|
| 448 |
+
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
|
| 449 |
+
lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
|
| 450 |
+
lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
|
| 451 |
+
else:
|
| 452 |
+
lora_attn_procs[name] = FluxAttnProcessor2_0()
|
| 453 |
+
else:
|
| 454 |
+
lora_attn_procs = {}
|
| 455 |
+
double_blocks_idx = list(range(19))
|
| 456 |
+
single_blocks_idx = list(range(38))
|
| 457 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 458 |
+
match = re.search(r'\.(\d+)\.', name)
|
| 459 |
+
if match:
|
| 460 |
+
layer_index = int(match.group(1))
|
| 461 |
+
if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
|
| 462 |
+
lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
|
| 463 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 464 |
+
)
|
| 465 |
+
elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
|
| 466 |
+
lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
|
| 467 |
+
dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
lora_attn_procs[name] = attn_processor
|
| 471 |
+
|
| 472 |
+
transformer.set_attn_processor(lora_attn_procs)
|
| 473 |
+
transformer.train()
|
| 474 |
+
for n, param in transformer.named_parameters():
|
| 475 |
+
if '_lora' not in n:
|
| 476 |
+
param.requires_grad = False
|
| 477 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
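# Editor's note (illustrative, not part of the original file): after the loop above, only
# parameters whose names contain '_lora' remain trainable. A quick sanity check one could run here:
#     trainable = [n for n, p in transformer.named_parameters() if p.requires_grad]
#     assert all('_lora' in n for n in trainable), "unexpected non-LoRA trainable params"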
|
| 478 |
+
|
| 479 |
+
def unwrap_model(model):
|
| 480 |
+
model = accelerator.unwrap_model(model)
|
| 481 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 482 |
+
return model
|
| 483 |
+
|
| 484 |
+
if args.resume_from_checkpoint:
|
| 485 |
+
path = args.resume_from_checkpoint
|
| 486 |
+
global_step = int(path.split("-")[-1])
|
| 487 |
+
initial_global_step = global_step
|
| 488 |
+
else:
|
| 489 |
+
initial_global_step = 0
|
| 490 |
+
global_step = 0
|
| 491 |
+
first_epoch = 0
|
| 492 |
+
|
| 493 |
+
if args.scale_lr:
|
| 494 |
+
args.learning_rate = (
|
| 495 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
if args.mixed_precision == "fp16":
|
| 499 |
+
models = [transformer]
|
| 500 |
+
cast_training_params(models, dtype=torch.float32)
|
| 501 |
+
|
| 502 |
+
params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
|
| 503 |
+
transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
|
| 504 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
|
| 505 |
+
|
| 506 |
+
optimizer_class = torch.optim.AdamW
|
| 507 |
+
optimizer = optimizer_class(
|
| 508 |
+
[transformer_parameters_with_lr],
|
| 509 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 510 |
+
weight_decay=args.adam_weight_decay,
|
| 511 |
+
eps=args.adam_epsilon,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
tokenizers = [tokenizer_one, tokenizer_two]
|
| 515 |
+
text_encoders = [text_encoder_one, text_encoder_two]
|
| 516 |
+
|
| 517 |
+
train_dataset = make_train_dataset_mixed(args, tokenizers, accelerator)
|
| 518 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 519 |
+
train_dataset,
|
| 520 |
+
batch_size=args.train_batch_size,
|
| 521 |
+
shuffle=True,
|
| 522 |
+
collate_fn=collate_fn,
|
| 523 |
+
num_workers=args.dataloader_num_workers,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 527 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 528 |
+
|
| 529 |
+
overrode_max_train_steps = False
|
| 530 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 531 |
+
if args.resume_from_checkpoint:
|
| 532 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 533 |
+
if args.max_train_steps is None:
|
| 534 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 535 |
+
overrode_max_train_steps = True
|
| 536 |
+
|
| 537 |
+
lr_scheduler = get_scheduler(
|
| 538 |
+
args.lr_scheduler,
|
| 539 |
+
optimizer=optimizer,
|
| 540 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 541 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 542 |
+
num_cycles=args.lr_num_cycles,
|
| 543 |
+
power=args.lr_power,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 547 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 551 |
+
if overrode_max_train_steps:
|
| 552 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 553 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 554 |
+
|
| 555 |
+
# Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
|
| 556 |
+
def _sanitize_hparams(config_dict):
|
| 557 |
+
sanitized = {}
|
| 558 |
+
for key, value in dict(config_dict).items():
|
| 559 |
+
try:
|
| 560 |
+
if value is None:
|
| 561 |
+
continue
|
| 562 |
+
# numpy scalar types
|
| 563 |
+
if isinstance(value, (np.integer,)):
|
| 564 |
+
sanitized[key] = int(value)
|
| 565 |
+
elif isinstance(value, (np.floating,)):
|
| 566 |
+
sanitized[key] = float(value)
|
| 567 |
+
elif isinstance(value, (int, float, bool, str)):
|
| 568 |
+
sanitized[key] = value
|
| 569 |
+
elif isinstance(value, Path):
|
| 570 |
+
sanitized[key] = str(value)
|
| 571 |
+
elif isinstance(value, (list, tuple)):
|
| 572 |
+
# stringify simple sequences; skip if fails
|
| 573 |
+
sanitized[key] = str(value)
|
| 574 |
+
else:
|
| 575 |
+
# best-effort stringify
|
| 576 |
+
sanitized[key] = str(value)
|
| 577 |
+
except Exception:
|
| 578 |
+
# skip unconvertible entries
|
| 579 |
+
continue
|
| 580 |
+
return sanitized
|
| 581 |
+
|
| 582 |
+
if accelerator.is_main_process:
|
| 583 |
+
tracker_name = "Easy_Control_Kontext"
|
| 584 |
+
accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
|
| 585 |
+
|
| 586 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 587 |
+
logger.info("***** Running training *****")
|
| 588 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 589 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 590 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 591 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 592 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 593 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 594 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 595 |
+
|
| 596 |
+
progress_bar = tqdm(
|
| 597 |
+
range(0, args.max_train_steps),
|
| 598 |
+
initial=initial_global_step,
|
| 599 |
+
desc="Steps",
|
| 600 |
+
disable=not accelerator.is_local_main_process,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 604 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 605 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 606 |
+
timesteps = timesteps.to(accelerator.device)
|
| 607 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 608 |
+
sigma = sigmas[step_indices].flatten()
|
| 609 |
+
while len(sigma.shape) < n_dim:
|
| 610 |
+
sigma = sigma.unsqueeze(-1)
|
| 611 |
+
return sigma
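# Editor's note (illustrative, not part of the original file): with flow matching, the noisy
# sample and the regression target used further below follow the linear interpolation
#     x_t = (1 - sigma) * x_0 + sigma * eps,    target = eps - x_0,
# i.e. the network regresses the constant velocity d x_t / d sigma. A tiny numeric check:
#     x0, eps, s = torch.tensor(2.0), torch.tensor(-1.0), torch.tensor(0.25)
#     x_t = (1 - s) * x0 + s * eps                          # 1.25
#     assert torch.isclose(x_t + 0.75 * (eps - x0), eps)    # integrating the velocity to sigma=1 recovers eps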
|
| 612 |
+
|
| 613 |
+
# Kontext specifics
|
| 614 |
+
vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
|
| 615 |
+
# Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
|
| 616 |
+
height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 617 |
+
width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
|
| 618 |
+
offset = 64
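# Editor's note (illustrative, not part of the original file): with the cond_size=512 used in
# the launch script and vae_scale_factor=8, the condition branch works on
#     height_cond = width_cond = 2 * (512 // 16) = 64
# latent rows/columns; after the 2x2 packing in _pack_latents this becomes a 32 x 32 grid,
# i.e. 1024 condition tokens per condition image.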
|
| 619 |
+
|
| 620 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 621 |
+
transformer.train()
|
| 622 |
+
for step, batch in enumerate(train_dataloader):
|
| 623 |
+
models_to_accumulate = [transformer]
|
| 624 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 625 |
+
tokens = [batch["text_ids_1"], batch["text_ids_2"]]
|
| 626 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
|
| 627 |
+
prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 628 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 629 |
+
text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
|
| 630 |
+
|
| 631 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 632 |
+
height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
|
| 633 |
+
width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
|
| 634 |
+
|
| 635 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 636 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 637 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 638 |
+
|
| 639 |
+
latent_image_ids, cond_latent_image_ids = resize_position_encoding(
|
| 640 |
+
model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
noise = torch.randn_like(model_input)
|
| 644 |
+
bsz = model_input.shape[0]
|
| 645 |
+
|
| 646 |
+
u = compute_density_for_timestep_sampling(
|
| 647 |
+
weighting_scheme=args.weighting_scheme,
|
| 648 |
+
batch_size=bsz,
|
| 649 |
+
logit_mean=args.logit_mean,
|
| 650 |
+
logit_std=args.logit_std,
|
| 651 |
+
mode_scale=args.mode_scale,
|
| 652 |
+
)
|
| 653 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 654 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 655 |
+
|
| 656 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 657 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 658 |
+
|
| 659 |
+
packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
|
| 660 |
+
noisy_model_input,
|
| 661 |
+
batch_size=model_input.shape[0],
|
| 662 |
+
num_channels_latents=model_input.shape[1],
|
| 663 |
+
height=model_input.shape[2],
|
| 664 |
+
width=model_input.shape[3],
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
latent_image_ids_to_concat = [latent_image_ids]
|
| 668 |
+
packed_cond_model_input_to_concat = []
|
| 669 |
+
|
| 670 |
+
if args.kontext == "enable":
|
| 671 |
+
source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
|
| 672 |
+
source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
|
| 673 |
+
source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
|
| 674 |
+
image_latent_h, image_latent_w = source_image_latents.shape[2:]
|
| 675 |
+
packed_image_latents = FluxKontextControlPipeline._pack_latents(
|
| 676 |
+
source_image_latents,
|
| 677 |
+
batch_size=source_image_latents.shape[0],
|
| 678 |
+
num_channels_latents=source_image_latents.shape[1],
|
| 679 |
+
height=image_latent_h,
|
| 680 |
+
width=image_latent_w,
|
| 681 |
+
)
|
| 682 |
+
source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
|
| 683 |
+
batch_size=source_image_latents.shape[0],
|
| 684 |
+
height=image_latent_h // 2,
|
| 685 |
+
width=image_latent_w // 2,
|
| 686 |
+
device=accelerator.device,
|
| 687 |
+
dtype=weight_dtype,
|
| 688 |
+
)
|
| 689 |
+
source_image_ids[..., 0] = 1 # Mark as condition
|
| 690 |
+
latent_image_ids_to_concat.append(source_image_ids)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
subject_pixel_values = batch.get("subject_pixel_values")
|
| 694 |
+
if subject_pixel_values is not None:
|
| 695 |
+
subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
|
| 696 |
+
subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
|
| 697 |
+
subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 698 |
+
subject_input = subject_input.to(dtype=weight_dtype)
|
| 699 |
+
sub_number = subject_pixel_values.shape[-2] // args.cond_size
|
| 700 |
+
latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
|
| 701 |
+
latent_subject_ids[..., 0] = 2
|
| 702 |
+
latent_subject_ids[:, 1] += offset
|
| 703 |
+
sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
|
| 704 |
+
latent_image_ids_to_concat.append(sub_latent_image_ids)
|
| 705 |
+
|
| 706 |
+
packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
|
| 707 |
+
subject_input,
|
| 708 |
+
batch_size=subject_input.shape[0],
|
| 709 |
+
num_channels_latents=subject_input.shape[1],
|
| 710 |
+
height=subject_input.shape[2],
|
| 711 |
+
width=subject_input.shape[3],
|
| 712 |
+
)
|
| 713 |
+
packed_cond_model_input_to_concat.append(packed_subject_model_input)
|
| 714 |
+
|
| 715 |
+
cond_pixel_values = batch.get("cond_pixel_values")
|
| 716 |
+
if cond_pixel_values is not None:
|
| 717 |
+
cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
|
| 718 |
+
cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
|
| 719 |
+
cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 720 |
+
cond_input = cond_input.to(dtype=weight_dtype)
|
| 721 |
+
cond_number = cond_pixel_values.shape[-2] // args.cond_size
|
| 722 |
+
cond_latent_image_ids[..., 0] = 2
|
| 723 |
+
cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
|
| 724 |
+
latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
|
| 725 |
+
|
| 726 |
+
packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
|
| 727 |
+
cond_input,
|
| 728 |
+
batch_size=cond_input.shape[0],
|
| 729 |
+
num_channels_latents=cond_input.shape[1],
|
| 730 |
+
height=cond_input.shape[2],
|
| 731 |
+
width=cond_input.shape[3],
|
| 732 |
+
)
|
| 733 |
+
packed_cond_model_input_to_concat.append(packed_cond_model_input)
|
| 734 |
+
|
| 735 |
+
latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
|
| 736 |
+
cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
|
| 737 |
+
|
| 738 |
+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
| 739 |
+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
| 740 |
+
guidance = guidance.expand(model_input.shape[0])
|
| 741 |
+
else:
|
| 742 |
+
guidance = None
|
| 743 |
+
|
| 744 |
+
latent_model_input=packed_noisy_model_input
|
| 745 |
+
if args.kontext == "enable":
|
| 746 |
+
latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
|
| 747 |
+
model_pred = transformer(
|
| 748 |
+
hidden_states=latent_model_input,
|
| 749 |
+
cond_hidden_states=cond_packed_noisy_model_input,
|
| 750 |
+
timestep=timesteps / 1000,
|
| 751 |
+
guidance=guidance,
|
| 752 |
+
pooled_projections=pooled_prompt_embeds,
|
| 753 |
+
encoder_hidden_states=prompt_embeds,
|
| 754 |
+
txt_ids=text_ids,
|
| 755 |
+
img_ids=latent_image_ids,
|
| 756 |
+
return_dict=False,
|
| 757 |
+
)[0]
|
| 758 |
+
|
| 759 |
+
model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
|
| 760 |
+
|
| 761 |
+
model_pred = FluxKontextControlPipeline._unpack_latents(
|
| 762 |
+
model_pred,
|
| 763 |
+
height=int(pixel_values.shape[-2]),
|
| 764 |
+
width=int(pixel_values.shape[-1]),
|
| 765 |
+
vae_scale_factor=vae_scale_factor,
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 769 |
+
target = noise - model_input
|
| 770 |
+
|
| 771 |
+
# mask_values = batch.get("mask_values")
|
| 772 |
+
# if mask_values is not None:
|
| 773 |
+
# mask_values = mask_values.to(device=accelerator.device, dtype=model_pred.dtype)
|
| 774 |
+
# loss_map = compute_background_preserving_loss(
|
| 775 |
+
# model_pred=model_pred,
|
| 776 |
+
# target=target,
|
| 777 |
+
# mask_values=mask_values,
|
| 778 |
+
# weighting=weighting,
|
| 779 |
+
# background_weight=args.background_weight,
|
| 780 |
+
# )
|
| 781 |
+
# loss = torch.mean(loss_map.reshape(target.shape[0], -1), 1)
|
| 782 |
+
# loss = loss.mean()
|
| 783 |
+
# else:
|
| 784 |
+
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
|
| 785 |
+
loss = loss.mean()
|
| 786 |
+
accelerator.backward(loss)
|
| 787 |
+
if accelerator.sync_gradients:
|
| 788 |
+
params_to_clip = (transformer.parameters())
|
| 789 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 790 |
+
|
| 791 |
+
optimizer.step()
|
| 792 |
+
lr_scheduler.step()
|
| 793 |
+
optimizer.zero_grad()
|
| 794 |
+
|
| 795 |
+
if accelerator.sync_gradients:
|
| 796 |
+
progress_bar.update(1)
|
| 797 |
+
global_step += 1
|
| 798 |
+
|
| 799 |
+
if accelerator.is_main_process:
|
| 800 |
+
if global_step % args.checkpointing_steps == 0:
|
| 801 |
+
if args.checkpoints_total_limit is not None:
|
| 802 |
+
checkpoints = os.listdir(args.output_dir)
|
| 803 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 804 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 805 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 806 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 807 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 808 |
+
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
|
| 809 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 810 |
+
for removing_checkpoint in removing_checkpoints:
|
| 811 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 812 |
+
shutil.rmtree(removing_checkpoint)
|
| 813 |
+
|
| 814 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 815 |
+
os.makedirs(save_path, exist_ok=True)
|
| 816 |
+
unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
|
| 817 |
+
lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
|
| 818 |
+
save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
|
| 819 |
+
logger.info(f"Saved state to {save_path}")
|
| 820 |
+
|
| 821 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 822 |
+
progress_bar.set_postfix(**logs)
|
| 823 |
+
accelerator.log(logs, step=global_step)
|
| 824 |
+
|
| 825 |
+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
| 826 |
+
pipeline = FluxKontextControlPipeline.from_pretrained(
|
| 827 |
+
args.pretrained_model_name_or_path,
|
| 828 |
+
vae=vae,
|
| 829 |
+
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
| 830 |
+
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
| 831 |
+
transformer=accelerator.unwrap_model(transformer),
|
| 832 |
+
revision=args.revision,
|
| 833 |
+
variant=args.variant,
|
| 834 |
+
torch_dtype=weight_dtype,
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
|
| 838 |
+
spatial_paths = args.spatial_test_images
|
| 839 |
+
spatial_ls = [Image.open(image_path).convert("RGB") for image_path in spatial_paths]
|
| 840 |
+
else:
|
| 841 |
+
spatial_ls = []
|
| 842 |
+
|
| 843 |
+
pipeline_args = {
|
| 844 |
+
"prompt": args.validation_prompt,
|
| 845 |
+
"cond_size": args.cond_size,
|
| 846 |
+
"guidance_scale": 3.5,
|
| 847 |
+
"num_inference_steps": 20,
|
| 848 |
+
"max_sequence_length": 128,
|
| 849 |
+
"control_dict": {"spatial_images": spatial_ls},
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
images = log_validation(
|
| 853 |
+
pipeline=pipeline,
|
| 854 |
+
args=args,
|
| 855 |
+
accelerator=accelerator,
|
| 856 |
+
pipeline_args=pipeline_args,
|
| 857 |
+
step=global_step,
|
| 858 |
+
torch_dtype=weight_dtype,
|
| 859 |
+
)
|
| 860 |
+
if accelerator.is_main_process:
|
| 861 |
+
save_path = os.path.join(args.output_dir, "validation")
|
| 862 |
+
os.makedirs(save_path, exist_ok=True)
|
| 863 |
+
save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
|
| 864 |
+
os.makedirs(save_folder, exist_ok=True)
|
| 865 |
+
for idx, img in enumerate(images):
|
| 866 |
+
img.save(os.path.join(save_folder, f"{idx}.jpg"))
|
| 867 |
+
del pipeline
|
| 868 |
+
|
| 869 |
+
accelerator.wait_for_everyone()
|
| 870 |
+
accelerator.end_training()
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
if __name__ == "__main__":
|
| 874 |
+
args = parse_args()
|
| 875 |
+
main(args)
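# --- Editor's note: illustrative sketch, not part of the original file. ---
# Checkpoints above store only the LoRA tensors (keys containing '_lora') in
# checkpoint-<step>/lora.safetensors. A minimal way to load them back into a transformer whose
# attention processors were built the same way as in main() (the path is a placeholder):
#     from safetensors.torch import load_file
#     lora_sd = load_file("output/checkpoint-1000/lora.safetensors")
#     missing, unexpected = transformer.load_state_dict(lora_sd, strict=False)
#     # strict=False leaves the frozen base weights untouched and only fills the LoRA slots.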
|
| 876 |
+
|
train/train_kontext_local.sh
ADDED
|
@@ -0,0 +1,26 @@
export MODEL_DIR="" # your flux path
export OUTPUT_DIR="" # your save path
export CONFIG="./default_config.yaml"
export LOG_PATH="$OUTPUT_DIR/log"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_local.py \
--pretrained_model_name_or_path $MODEL_DIR \
--pretrained_lora_path "" \
--lora_num=1 \
--cond_size=512 \
--ranks 128 \
--network_alphas 128 \
--output_dir=$OUTPUT_DIR \
--logging_dir=$LOG_PATH \
--mixed_precision="bf16" \
--learning_rate=1e-4 \
--train_batch_size=1 \
--num_train_epochs=1 \
--validation_steps=250 \
--checkpointing_steps=1000 \
--validation_images "./kontext_local_test/img_1.png" \
--spatial_test_images "./kontext_local_test/mask_1.png" \
--validation_prompt "convert the dinosaur into blue color" \
--gradient_checkpointing \
--blend_pixel_values \
--num_validation_images=1
|
train/train_kontext_lora.py
ADDED
|
@@ -0,0 +1,871 @@
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
from safetensors.torch import save_file
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.utils.checkpoint
|
| 17 |
+
import transformers
|
| 18 |
+
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
from accelerate.logging import get_logger
|
| 21 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
| 22 |
+
|
| 23 |
+
import diffusers
|
| 24 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline
|
| 25 |
+
from diffusers.optimization import get_scheduler
|
| 26 |
+
from diffusers.training_utils import (
|
| 27 |
+
cast_training_params,
|
| 28 |
+
compute_density_for_timestep_sampling,
|
| 29 |
+
compute_loss_weighting_for_sd3,
|
| 30 |
+
)
|
| 31 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
| 32 |
+
from diffusers.utils import (
|
| 33 |
+
check_min_version,
|
| 34 |
+
is_wandb_available,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
from src.prompt_helper import *
|
| 38 |
+
from src.lora_helper import *
|
| 39 |
+
from src.jsonl_datasets_kontext_interactive_lora import make_interactive_dataset_subjects, make_placement_dataset_subjects, make_pexels_dataset_subjects, make_mixed_dataset, collate_fn
|
| 40 |
+
from diffusers import FluxKontextPipeline
|
| 41 |
+
from diffusers.models import FluxTransformer2DModel
|
| 42 |
+
from tqdm.auto import tqdm
|
| 43 |
+
from peft import LoraConfig
|
| 44 |
+
from peft.utils import get_peft_model_state_dict
|
| 45 |
+
from diffusers.utils import convert_state_dict_to_diffusers
|
| 46 |
+
|
| 47 |
+
if is_wandb_available():
|
| 48 |
+
import wandb
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 52 |
+
check_min_version("0.31.0.dev0")
|
| 53 |
+
|
| 54 |
+
logger = get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
PREFERRED_KONTEXT_RESOLUTIONS = [
|
| 58 |
+
(672, 1568),
|
| 59 |
+
(688, 1504),
|
| 60 |
+
(720, 1456),
|
| 61 |
+
(752, 1392),
|
| 62 |
+
(832, 1248),
|
| 63 |
+
(880, 1184),
|
| 64 |
+
(944, 1104),
|
| 65 |
+
(1024, 1024),
|
| 66 |
+
(1104, 944),
|
| 67 |
+
(1184, 880),
|
| 68 |
+
(1248, 832),
|
| 69 |
+
(1392, 752),
|
| 70 |
+
(1456, 720),
|
| 71 |
+
(1504, 688),
|
| 72 |
+
(1568, 672),
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def log_validation(
|
| 77 |
+
pipeline,
|
| 78 |
+
args,
|
| 79 |
+
accelerator,
|
| 80 |
+
pipeline_args,
|
| 81 |
+
step,
|
| 82 |
+
torch_dtype,
|
| 83 |
+
is_final_validation=False,
|
| 84 |
+
):
|
| 85 |
+
logger.info(
|
| 86 |
+
f"Running validation... Paired evaluation for image and prompt."
|
| 87 |
+
)
|
| 88 |
+
pipeline = pipeline.to(device=accelerator.device, dtype=torch_dtype)
|
| 89 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 90 |
+
|
| 91 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 92 |
+
# Match compute dtype for validation to avoid dtype mismatches (e.g., VAE bf16 vs float latents)
|
| 93 |
+
if torch_dtype in (torch.float16, torch.bfloat16):
|
| 94 |
+
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 95 |
+
autocast_ctx = torch.autocast(device_type=device_type, dtype=torch_dtype)
|
| 96 |
+
else:
|
| 97 |
+
autocast_ctx = nullcontext()
|
| 98 |
+
|
| 99 |
+
# Build per-case evaluation
|
| 100 |
+
if args.validation_images is None or args.validation_images == ['None']:
|
| 101 |
+
raise ValueError("validation_images must be provided and non-empty")
|
| 102 |
+
if args.validation_prompt is None:
|
| 103 |
+
raise ValueError("validation_prompt must be provided and non-empty")
|
| 104 |
+
|
| 105 |
+
val_imgs = args.validation_images
|
| 106 |
+
prompts = args.validation_prompt
|
| 107 |
+
# Prepend instruction to each prompt (same as dataset/test requirement)
|
| 108 |
+
instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
|
| 109 |
+
try:
|
| 110 |
+
prompts = [f"{instruction} {p}".strip() if isinstance(p, str) and len(p.strip()) > 0 else instruction for p in prompts]
|
| 111 |
+
except Exception:
|
| 112 |
+
# Fallback: keep original prompts if unexpected
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
if not (len(val_imgs) == len(prompts)):
|
| 116 |
+
raise ValueError(
|
| 117 |
+
f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
results = []
|
| 121 |
+
|
| 122 |
+
def _resize_to_preferred(img: Image.Image) -> Image.Image:
|
| 123 |
+
w, h = img.size
|
| 124 |
+
aspect_ratio = w / h if h != 0 else 1.0
|
| 125 |
+
_, target_w, target_h = min(
|
| 126 |
+
(abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
|
| 127 |
+
for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
|
| 128 |
+
)
|
| 129 |
+
return img.resize((target_w, target_h), Image.BICUBIC)
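# Editor's note (illustrative, not part of the original file): for a 1200x800 input
# (aspect ratio 1.5) the closest entry in PREFERRED_KONTEXT_RESOLUTIONS is (832, 1248),
# whose width/height ratio is exactly 1.5, so the image is resized to 1248x832.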
|
| 130 |
+
|
| 131 |
+
# Distributed per-rank assignment: each process handles its own slice of cases
|
| 132 |
+
num_cases = len(prompts)
|
| 133 |
+
logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
|
| 134 |
+
|
| 135 |
+
# Indices assigned to this rank
|
| 136 |
+
rank = accelerator.process_index
|
| 137 |
+
world_size = accelerator.num_processes
|
| 138 |
+
local_indices = list(range(rank, num_cases, world_size))
|
| 139 |
+
|
| 140 |
+
local_images = []
|
| 141 |
+
with autocast_ctx:
|
| 142 |
+
for idx in local_indices:
|
| 143 |
+
try:
|
| 144 |
+
base_img = Image.open(val_imgs[idx]).convert("RGB")
|
| 145 |
+
resized_img = _resize_to_preferred(base_img)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
|
| 148 |
+
|
| 149 |
+
case_args = dict(pipeline_args) if pipeline_args is not None else {}
|
| 150 |
+
case_args.pop("height", None)
|
| 151 |
+
case_args.pop("width", None)
|
| 152 |
+
if resized_img is not None:
|
| 153 |
+
tw, th = resized_img.size
|
| 154 |
+
case_args["height"] = th
|
| 155 |
+
case_args["width"] = tw
|
| 156 |
+
|
| 157 |
+
case_args["prompt"] = prompts[idx]
|
| 158 |
+
img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
|
| 159 |
+
local_images.append(img)
|
| 160 |
+
|
| 161 |
+
# Gather all images per rank (pad to equal count) to main process
|
| 162 |
+
fixed_size = (1024, 1024)
|
| 163 |
+
max_local = int(math.ceil(num_cases / world_size)) if world_size > 0 else len(local_images)
|
| 164 |
+
# Build per-rank batch tensors
|
| 165 |
+
imgs_rank = []
|
| 166 |
+
idx_rank = []
|
| 167 |
+
has_rank = []
|
| 168 |
+
for j in range(max_local):
|
| 169 |
+
if j < len(local_images):
|
| 170 |
+
resized = local_images[j].resize(fixed_size, Image.BICUBIC)
|
| 171 |
+
img_np = np.asarray(resized).astype(np.uint8)
|
| 172 |
+
imgs_rank.append(torch.from_numpy(img_np))
|
| 173 |
+
idx_rank.append(local_indices[j])
|
| 174 |
+
has_rank.append(1)
|
| 175 |
+
else:
|
| 176 |
+
imgs_rank.append(torch.from_numpy(np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)))
|
| 177 |
+
idx_rank.append(-1)
|
| 178 |
+
has_rank.append(0)
|
| 179 |
+
imgs_rank_tensor = torch.stack([t.to(device=accelerator.device) for t in imgs_rank], dim=0) # [max_local, H, W, C]
|
| 180 |
+
idx_rank_tensor = torch.tensor(idx_rank, device=accelerator.device, dtype=torch.long) # [max_local]
|
| 181 |
+
has_rank_tensor = torch.tensor(has_rank, device=accelerator.device, dtype=torch.int) # [max_local]
|
| 182 |
+
|
| 183 |
+
gathered_has = accelerator.gather(has_rank_tensor) # [world * max_local]
|
| 184 |
+
gathered_idx = accelerator.gather(idx_rank_tensor) # [world * max_local]
|
| 185 |
+
gathered_imgs = accelerator.gather(imgs_rank_tensor) # [world * max_local, H, W, C]
|
| 186 |
+
|
| 187 |
+
if accelerator.is_main_process:
|
| 188 |
+
world = int(world_size)
|
| 189 |
+
slots = int(max_local)
|
| 190 |
+
try:
|
| 191 |
+
gathered_has = gathered_has.view(world, slots)
|
| 192 |
+
gathered_idx = gathered_idx.view(world, slots)
|
| 193 |
+
gathered_imgs = gathered_imgs.view(world, slots, fixed_size[1], fixed_size[0], 3)
|
| 194 |
+
except Exception:
|
| 195 |
+
# Fallback: treat as flat if reshape fails
|
| 196 |
+
gathered_has = gathered_has.view(-1, 1)
|
| 197 |
+
gathered_idx = gathered_idx.view(-1, 1)
|
| 198 |
+
gathered_imgs = gathered_imgs.view(-1, 1, fixed_size[1], fixed_size[0], 3)
|
| 199 |
+
world = int(gathered_has.shape[0])
|
| 200 |
+
slots = 1
|
| 201 |
+
for i in range(world):
|
| 202 |
+
for j in range(slots):
|
| 203 |
+
if int(gathered_has[i, j].item()) == 1:
|
| 204 |
+
idx = int(gathered_idx[i, j].item())
|
| 205 |
+
arr = gathered_imgs[i, j].cpu().numpy()
|
| 206 |
+
pil_img = Image.fromarray(arr.astype(np.uint8))
|
| 207 |
+
# Resize back to original validation image size
|
| 208 |
+
try:
|
| 209 |
+
orig = Image.open(val_imgs[idx]).convert("RGB")
|
| 210 |
+
pil_img = pil_img.resize(orig.size, Image.BICUBIC)
|
| 211 |
+
except Exception:
|
| 212 |
+
pass
|
| 213 |
+
results.append(pil_img)
|
| 214 |
+
|
| 215 |
+
# Log results (resize to 1024x1024 for saving or external trackers). Skip TensorBoard per request.
|
| 216 |
+
resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
|
| 217 |
+
for tracker in accelerator.trackers:
|
| 218 |
+
phase_name = "test" if is_final_validation else "validation"
|
| 219 |
+
if tracker.name == "tensorboard":
|
| 220 |
+
continue
|
| 221 |
+
if tracker.name == "wandb":
|
| 222 |
+
tracker.log({
|
| 223 |
+
phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
|
| 224 |
+
})
|
| 225 |
+
|
| 226 |
+
del pipeline
|
| 227 |
+
if torch.cuda.is_available():
|
| 228 |
+
torch.cuda.empty_cache()
|
| 229 |
+
|
| 230 |
+
return results
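# Editor's note (illustrative, not part of the original file): accelerator.gather requires every
# rank to contribute tensors of identical shape, which is why the loop above pads each rank to
# max_local slots with zero images plus index/has flags. A reduced sketch of the pattern,
# with `img_shape` as a placeholder:
#     per_rank = torch.zeros(max_local, *img_shape, device=accelerator.device)    # zero-padded slots
#     flags = torch.zeros(max_local, dtype=torch.int, device=accelerator.device)  # 1 where a slot is real
#     all_imgs, all_flags = accelerator.gather(per_rank), accelerator.gather(flags)
#     kept = [img for img, f in zip(all_imgs, all_flags) if int(f) == 1]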
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def save_with_retry(img: Image.Image, path: str, max_retries: int = 3) -> bool:
|
| 234 |
+
"""Save PIL image with simple retry and exponential backoff to mitigate transient I/O errors."""
|
| 235 |
+
last_err = None
|
| 236 |
+
for attempt in range(max_retries):
|
| 237 |
+
try:
|
| 238 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 239 |
+
img.save(path)
|
| 240 |
+
return True
|
| 241 |
+
except OSError as e:
|
| 242 |
+
last_err = e
|
| 243 |
+
# Exponential backoff: 1.0, 1.5, 2.25 seconds ...
|
| 244 |
+
time.sleep(1.5 ** attempt)
|
| 245 |
+
logger.warning(f"Failed to save {path} after {max_retries} retries: {last_err}")
|
| 246 |
+
return False
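# Editor's note (illustrative, not part of the original file): typical use of the helper above,
# with a placeholder image variable and output path:
#     ok = save_with_retry(validation_image, os.path.join(args.output_dir, "validation", "0.jpg"))
#     if not ok:
#         logger.warning("validation image could not be written; continuing")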
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
|
| 250 |
+
text_encoder_config = transformers.PretrainedConfig.from_pretrained(
|
| 251 |
+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
|
| 252 |
+
)
|
| 253 |
+
model_class = text_encoder_config.architectures[0]
|
| 254 |
+
if model_class == "CLIPTextModel":
|
| 255 |
+
from transformers import CLIPTextModel
|
| 256 |
+
|
| 257 |
+
return CLIPTextModel
|
| 258 |
+
elif model_class == "T5EncoderModel":
|
| 259 |
+
from transformers import T5EncoderModel
|
| 260 |
+
|
| 261 |
+
return T5EncoderModel
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError(f"{model_class} is not supported.")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def parse_args(input_args=None):
|
| 267 |
+
parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
|
| 268 |
+
parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
|
| 269 |
+
|
| 270 |
+
# Dataset arguments
|
| 271 |
+
parser.add_argument("--dataset_mode", type=str, default="mixed", choices=["interactive", "placement", "pexels", "mixed"],
|
| 272 |
+
help="Dataset mode: interactive, placement, pexels, or mixed")
|
| 273 |
+
parser.add_argument("--train_data_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
|
| 274 |
+
help="Path to interactive dataset JSONL")
|
| 275 |
+
parser.add_argument("--placement_data_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
|
| 276 |
+
help="Path to placement dataset JSONL")
|
| 277 |
+
parser.add_argument("--pexels_data_jsonl", type=str, default=None,
|
| 278 |
+
help="Path to pexels dataset JSONL")
|
| 279 |
+
parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
|
| 280 |
+
help="Base directory for interactive dataset")
|
| 281 |
+
parser.add_argument("--placement_base_dir", type=str, default="/robby/share/Editing/lzc/subject_placement",
|
| 282 |
+
help="Base directory for placement dataset")
|
| 283 |
+
parser.add_argument("--pexels_base_dir", type=str, default=None,
|
| 284 |
+
help="Base directory for pexels dataset")
|
| 285 |
+
parser.add_argument("--pexels_relight_base_dir", type=str, default=None,
|
| 286 |
+
help="Base directory for pexels relighted images")
|
| 287 |
+
parser.add_argument("--seg_base_dir", type=str, default=None,
|
| 288 |
+
help="Directory containing segmentation maps for pexels dataset")
|
| 289 |
+
parser.add_argument("--interactive_weight", type=float, default=1.0,
|
| 290 |
+
help="Sampling weight for interactive dataset (default: 1.0)")
|
| 291 |
+
parser.add_argument("--placement_weight", type=float, default=1.0,
|
| 292 |
+
help="Sampling weight for placement dataset (default: 1.0)")
|
| 293 |
+
parser.add_argument("--pexels_weight", type=float, default=0.1,
|
| 294 |
+
help="Sampling weight for pexels dataset (default: 1.0)")
|
| 295 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
|
| 296 |
+
parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
|
| 297 |
+
parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
|
| 298 |
+
parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
|
| 299 |
+
|
| 300 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
| 301 |
+
parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
|
| 302 |
+
parser.add_argument("--kontext", type=str, default="enable")
|
| 303 |
+
parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
|
| 304 |
+
parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
|
| 305 |
+
parser.add_argument("--num_validation_images", type=int, default=4)
|
| 306 |
+
parser.add_argument("--validation_steps", type=int, default=20)
|
| 307 |
+
|
| 308 |
+
parser.add_argument("--ranks", type=int, nargs="+", default=[32], help="LoRA ranks")
|
| 309 |
+
parser.add_argument("--output_dir", type=str, default="", help="Output directory")
|
| 310 |
+
parser.add_argument("--seed", type=int, default=None)
|
| 311 |
+
parser.add_argument("--train_batch_size", type=int, default=1)
|
| 312 |
+
parser.add_argument("--num_train_epochs", type=int, default=50)
|
| 313 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
| 314 |
+
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
| 315 |
+
parser.add_argument("--checkpoints_total_limit", type=int, default=None)
|
| 316 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
| 317 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
| 318 |
+
parser.add_argument("--gradient_checkpointing", action="store_true")
|
| 319 |
+
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
| 320 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
|
| 321 |
+
parser.add_argument("--scale_lr", action="store_true", default=False)
|
| 322 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant")
|
| 323 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
| 324 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1)
|
| 325 |
+
parser.add_argument("--lr_power", type=float, default=1.0)
|
| 326 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=8)
|
| 327 |
+
parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
|
| 328 |
+
parser.add_argument("--logit_mean", type=float, default=0.0)
|
| 329 |
+
parser.add_argument("--logit_std", type=float, default=1.0)
|
| 330 |
+
parser.add_argument("--mode_scale", type=float, default=1.29)
|
| 331 |
+
parser.add_argument("--optimizer", type=str, default="AdamW")
|
| 332 |
+
parser.add_argument("--use_8bit_adam", action="store_true")
|
| 333 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
| 334 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
| 335 |
+
parser.add_argument("--prodigy_beta3", type=float, default=None)
|
| 336 |
+
parser.add_argument("--prodigy_decouple", type=bool, default=True)
|
| 337 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
|
| 338 |
+
parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
|
| 339 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
| 340 |
+
parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
|
| 341 |
+
parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
|
| 342 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 343 |
+
parser.add_argument("--logging_dir", type=str, default="logs")
|
| 344 |
+
parser.add_argument("--cache_latents", action="store_true", default=False)
|
| 345 |
+
parser.add_argument("--report_to", type=str, default="tensorboard")
|
| 346 |
+
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
|
| 347 |
+
parser.add_argument("--upcast_before_saving", action="store_true", default=False)
|
| 348 |
+
|
| 349 |
+
# Blending options for dataset pixel_values
|
| 350 |
+
parser.add_argument("--blend_pixel_values", action="store_true", help="Blend target/source into pixel_values using mask")
|
| 351 |
+
parser.add_argument("--blend_kernel", type=int, default=21, help="Gaussian blur kernel size (must be odd)")
|
| 352 |
+
parser.add_argument("--blend_sigma", type=float, default=10.0, help="Gaussian blur sigma")
|
| 353 |
+
|
| 354 |
+
if input_args is not None:
|
| 355 |
+
args = parser.parse_args(input_args)
|
| 356 |
+
else:
|
| 357 |
+
args = parser.parse_args()
|
| 358 |
+
return args
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def main(args):
|
| 362 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
| 363 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 364 |
+
|
| 365 |
+
if args.output_dir is not None:
|
| 366 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 367 |
+
os.makedirs(args.logging_dir, exist_ok=True)
|
| 368 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 369 |
+
|
| 370 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
| 371 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 372 |
+
accelerator = Accelerator(
|
| 373 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 374 |
+
mixed_precision=args.mixed_precision,
|
| 375 |
+
log_with=args.report_to,
|
| 376 |
+
project_config=accelerator_project_config,
|
| 377 |
+
kwargs_handlers=[kwargs],
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if torch.backends.mps.is_available():
|
| 381 |
+
accelerator.native_amp = False
|
| 382 |
+
|
| 383 |
+
if args.report_to == "wandb":
|
| 384 |
+
if not is_wandb_available():
|
| 385 |
+
raise ImportError("Install wandb for logging during training.")
|
| 386 |
+
|
| 387 |
+
logging.basicConfig(
|
| 388 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 389 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 390 |
+
level=logging.INFO,
|
| 391 |
+
)
|
| 392 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 393 |
+
if accelerator.is_local_main_process:
|
| 394 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 395 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 396 |
+
else:
|
| 397 |
+
transformers.utils.logging.set_verbosity_error()
|
| 398 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 399 |
+
|
| 400 |
+
if args.seed is not None:
|
| 401 |
+
set_seed(args.seed)
|
| 402 |
+
|
| 403 |
+
if accelerator.is_main_process and args.output_dir is not None:
|
| 404 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 405 |
+
|
| 406 |
+
# Tokenizers
|
| 407 |
+
tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
|
| 408 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
| 409 |
+
)
|
| 410 |
+
tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
|
| 411 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# Text encoders
|
| 415 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
|
| 416 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
|
| 417 |
+
|
| 418 |
+
# Scheduler and models
|
| 419 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 420 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
| 421 |
+
text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
|
| 422 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
|
| 423 |
+
transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
|
| 424 |
+
|
| 425 |
+
# Train only LoRA adapters: freeze base transformer/text encoders/vae
|
| 426 |
+
transformer.requires_grad_(False)
|
| 427 |
+
vae.requires_grad_(False)
|
| 428 |
+
text_encoder_one.requires_grad_(False)
|
| 429 |
+
text_encoder_two.requires_grad_(False)
|
| 430 |
+
|
| 431 |
+
weight_dtype = torch.float32
|
| 432 |
+
if accelerator.mixed_precision == "fp16":
|
| 433 |
+
weight_dtype = torch.float16
|
| 434 |
+
elif accelerator.mixed_precision == "bf16":
|
| 435 |
+
weight_dtype = torch.bfloat16
|
| 436 |
+
|
| 437 |
+
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
| 438 |
+
raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
|
| 439 |
+
|
| 440 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 441 |
+
transformer.to(accelerator.device, dtype=weight_dtype)
|
| 442 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
| 443 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
| 444 |
+
|
| 445 |
+
if args.gradient_checkpointing:
|
| 446 |
+
transformer.enable_gradient_checkpointing()
|
| 447 |
+
|
| 448 |
+
# Setup standard PEFT LoRA on FluxTransformer2DModel
|
| 449 |
+
# target_modules = [
|
| 450 |
+
# "attn.to_k",
|
| 451 |
+
# "attn.to_q",
|
| 452 |
+
# "attn.to_v",
|
| 453 |
+
# "attn.to_out.0",
|
| 454 |
+
# "attn.add_k_proj",
|
| 455 |
+
# "attn.add_q_proj",
|
| 456 |
+
# "attn.add_v_proj",
|
| 457 |
+
# "attn.to_add_out",
|
| 458 |
+
# "ff.net.0.proj",
|
| 459 |
+
# "ff.net.2",
|
| 460 |
+
# "ff_context.net.0.proj",
|
| 461 |
+
# "ff_context.net.2",
|
| 462 |
+
# ]
|
| 463 |
+
target_modules = [
|
| 464 |
+
"attn.to_k",
|
| 465 |
+
"attn.to_q",
|
| 466 |
+
"attn.to_v",
|
| 467 |
+
"attn.to_out.0",
|
| 468 |
+
"attn.add_k_proj",
|
| 469 |
+
"attn.add_q_proj",
|
| 470 |
+
"attn.add_v_proj",
|
| 471 |
+
"attn.to_add_out",
|
| 472 |
+
"ff.net.0.proj",
|
| 473 |
+
"ff.net.2",
|
| 474 |
+
"ff_context.net.0.proj",
|
| 475 |
+
"ff_context.net.2",
|
| 476 |
+
# ===========================================================
|
| 477 |
+
# [Addition 1]: layers specific to the single-stream blocks (single_transformer_blocks)
|
| 478 |
+
# ===========================================================
|
| 479 |
+
# Note: the attention layers inside the single-stream blocks (to_q, to_k, to_v) are already covered by the generic names above.
|
| 480 |
+
# What is added here are their dedicated MLP and output layers.
|
| 481 |
+
"proj_mlp",
|
| 482 |
+
"proj_out", # 这个名称也会匹配单流块各自的输出层和模型总输出层
|
| 483 |
+
|
| 484 |
+
# ===========================================================
|
| 485 |
+
# [Addition 2]: all normalization (Norm) layers
|
| 486 |
+
# ===========================================================
|
| 487 |
+
# Note: these layers adjust the feature distribution and matter a lot for style learning.
|
| 488 |
+
# Using "linear" matches every Norm layer ending in ".linear" in one go.
|
| 489 |
+
"linear", # 匹配 norm1.linear, norm1_context.linear, norm.linear, norm_out.linear
|
| 490 |
+
]
|
| 491 |
+
lora_rank = int(args.ranks[0]) if isinstance(args.ranks, list) and len(args.ranks) > 0 else 256
|
| 492 |
+
lora_config = LoraConfig(
|
| 493 |
+
r=lora_rank,
|
| 494 |
+
lora_alpha=lora_rank,
|
| 495 |
+
init_lora_weights="gaussian",
|
| 496 |
+
target_modules=target_modules,
|
| 497 |
+
)
|
| 498 |
+
transformer.add_adapter(lora_config)
|
| 499 |
+
transformer.train()
|
| 500 |
+
print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
|
| 501 |
+
|
| 502 |
+
def unwrap_model(model):
|
| 503 |
+
model = accelerator.unwrap_model(model)
|
| 504 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
| 505 |
+
return model
|
| 506 |
+
|
| 507 |
+
if args.resume_from_checkpoint:
|
| 508 |
+
path = args.resume_from_checkpoint
|
| 509 |
+
global_step = int(path.split("-")[-1])
|
| 510 |
+
initial_global_step = global_step
|
| 511 |
+
else:
|
| 512 |
+
initial_global_step = 0
|
| 513 |
+
global_step = 0
|
| 514 |
+
first_epoch = 0
|
| 515 |
+
|
| 516 |
+
if args.scale_lr:
|
| 517 |
+
args.learning_rate = (
|
| 518 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
if args.mixed_precision == "fp16":
|
| 522 |
+
models = [transformer]
|
| 523 |
+
cast_training_params(models, dtype=torch.float32)
|
| 524 |
+
|
| 525 |
+
params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
|
| 526 |
+
transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
|
| 527 |
+
# print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'parameters')
|
| 528 |
+
|
| 529 |
+
optimizer_class = torch.optim.AdamW
|
| 530 |
+
optimizer = optimizer_class(
|
| 531 |
+
[transformer_parameters_with_lr],
|
| 532 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 533 |
+
weight_decay=args.adam_weight_decay,
|
| 534 |
+
eps=args.adam_epsilon,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
tokenizers = [tokenizer_one, tokenizer_two]
|
| 538 |
+
text_encoders = [text_encoder_one, text_encoder_two]
|
| 539 |
+
|
| 540 |
+
# Create dataset based on mode
|
| 541 |
+
if args.dataset_mode == "mixed":
|
| 542 |
+
# Mixed mode: combine all available datasets
|
| 543 |
+
train_dataset = make_mixed_dataset(
|
| 544 |
+
args,
|
| 545 |
+
tokenizers,
|
| 546 |
+
interactive_jsonl_path=args.train_data_jsonl,
|
| 547 |
+
placement_jsonl_path=args.placement_data_jsonl,
|
| 548 |
+
pexels_jsonl_path=args.pexels_data_jsonl,
|
| 549 |
+
interactive_base_dir=args.interactive_base_dir,
|
| 550 |
+
placement_base_dir=args.placement_base_dir,
|
| 551 |
+
pexels_base_dir=args.pexels_base_dir,
|
| 552 |
+
interactive_weight=args.interactive_weight,
|
| 553 |
+
placement_weight=args.placement_weight,
|
| 554 |
+
pexels_weight=args.pexels_weight,
|
| 555 |
+
accelerator=accelerator
|
| 556 |
+
)
|
| 557 |
+
weights_str = []
|
| 558 |
+
if args.train_data_jsonl:
|
| 559 |
+
weights_str.append(f"Interactive: {args.interactive_weight:.2f}")
|
| 560 |
+
if args.placement_data_jsonl:
|
| 561 |
+
weights_str.append(f"Placement: {args.placement_weight:.2f}")
|
| 562 |
+
if args.pexels_data_jsonl:
|
| 563 |
+
weights_str.append(f"Pexels: {args.pexels_weight:.2f}")
|
| 564 |
+
logger.info(f"Mixed dataset created with weights - {', '.join(weights_str)}")
|
| 565 |
+
elif args.dataset_mode == "pexels":
|
| 566 |
+
if not args.pexels_data_jsonl:
|
| 567 |
+
raise ValueError("pexels_data_jsonl must be provided for pexels mode")
|
| 568 |
+
train_dataset = make_pexels_dataset_subjects(args, tokenizers, accelerator)
|
| 569 |
+
elif args.dataset_mode == "placement":
|
| 570 |
+
if not args.placement_data_jsonl:
|
| 571 |
+
raise ValueError("placement_data_jsonl must be provided for placement mode")
|
| 572 |
+
train_dataset = make_placement_dataset_subjects(args, tokenizers, accelerator)
|
| 573 |
+
else: # interactive mode
|
| 574 |
+
train_dataset = make_interactive_dataset_subjects(args, tokenizers, accelerator)
|
| 575 |
+
|
| 576 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 577 |
+
train_dataset,
|
| 578 |
+
batch_size=args.train_batch_size,
|
| 579 |
+
shuffle=True,
|
| 580 |
+
collate_fn=collate_fn,
|
| 581 |
+
num_workers=args.dataloader_num_workers,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
vae_config_shift_factor = vae.config.shift_factor
|
| 585 |
+
vae_config_scaling_factor = vae.config.scaling_factor
|
| 586 |
+
|
| 587 |
+
overrode_max_train_steps = False
|
| 588 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 589 |
+
if args.resume_from_checkpoint:
|
| 590 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 591 |
+
if args.max_train_steps is None:
|
| 592 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 593 |
+
overrode_max_train_steps = True
|
| 594 |
+
|
| 595 |
+
lr_scheduler = get_scheduler(
|
| 596 |
+
args.lr_scheduler,
|
| 597 |
+
optimizer=optimizer,
|
| 598 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
| 599 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
| 600 |
+
num_cycles=args.lr_num_cycles,
|
| 601 |
+
power=args.lr_power,
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 605 |
+
transformer, optimizer, train_dataloader, lr_scheduler
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 609 |
+
if overrode_max_train_steps:
|
| 610 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 611 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 612 |
+
|
| 613 |
+
# Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
|
| 614 |
+
def _sanitize_hparams(config_dict):
|
| 615 |
+
sanitized = {}
|
| 616 |
+
for key, value in dict(config_dict).items():
|
| 617 |
+
try:
|
| 618 |
+
if value is None:
|
| 619 |
+
continue
|
| 620 |
+
# numpy scalar types
|
| 621 |
+
if isinstance(value, (np.integer,)):
|
| 622 |
+
sanitized[key] = int(value)
|
| 623 |
+
elif isinstance(value, (np.floating,)):
|
| 624 |
+
sanitized[key] = float(value)
|
| 625 |
+
elif isinstance(value, (int, float, bool, str)):
|
| 626 |
+
sanitized[key] = value
|
| 627 |
+
elif isinstance(value, Path):
|
| 628 |
+
sanitized[key] = str(value)
|
| 629 |
+
elif isinstance(value, (list, tuple)):
|
| 630 |
+
# stringify simple sequences; skip if fails
|
| 631 |
+
sanitized[key] = str(value)
|
| 632 |
+
else:
|
| 633 |
+
# best-effort stringify
|
| 634 |
+
sanitized[key] = str(value)
|
| 635 |
+
except Exception:
|
| 636 |
+
# skip unconvertible entries
|
| 637 |
+
continue
|
| 638 |
+
return sanitized
|
| 639 |
+
|
| 640 |
+
if accelerator.is_main_process:
|
| 641 |
+
tracker_name = "Easy_Control_Kontext"
|
| 642 |
+
accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
|
| 643 |
+
|
| 644 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 645 |
+
logger.info("***** Running training *****")
|
| 646 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 647 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 648 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 649 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 650 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 651 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 652 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 653 |
+
|
| 654 |
+
progress_bar = tqdm(
|
| 655 |
+
range(0, args.max_train_steps),
|
| 656 |
+
initial=initial_global_step,
|
| 657 |
+
desc="Steps",
|
| 658 |
+
disable=not accelerator.is_local_main_process,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
| 662 |
+
sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
| 663 |
+
schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
|
| 664 |
+
timesteps = timesteps.to(accelerator.device)
|
| 665 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 666 |
+
sigma = sigmas[step_indices].flatten()
|
| 667 |
+
while len(sigma.shape) < n_dim:
|
| 668 |
+
sigma = sigma.unsqueeze(-1)
|
| 669 |
+
return sigma
|
| 670 |
+
|
| 671 |
+
# Kontext specifics
|
| 672 |
+
vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
|
| 673 |
+
|
| 674 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 675 |
+
transformer.train()
|
| 676 |
+
for step, batch in enumerate(train_dataloader):
|
| 677 |
+
models_to_accumulate = [transformer]
|
| 678 |
+
with accelerator.accumulate(models_to_accumulate):
|
| 679 |
+
tokens = [batch["text_ids_1"], batch["text_ids_2"]]
|
| 680 |
+
prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
|
| 681 |
+
prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 682 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
|
| 683 |
+
text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
|
| 684 |
+
|
| 685 |
+
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
| 686 |
+
height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
|
| 687 |
+
width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
|
| 688 |
+
|
| 689 |
+
model_input = vae.encode(pixel_values).latent_dist.sample()
|
| 690 |
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
| 691 |
+
model_input = model_input.to(dtype=weight_dtype)
|
| 692 |
+
|
| 693 |
+
# Prepare latent ids for transformer (positional encodings)
|
| 694 |
+
latent_image_ids = FluxKontextPipeline._prepare_latent_image_ids(
|
| 695 |
+
batch_size=model_input.shape[0],
|
| 696 |
+
height=model_input.shape[2] // 2,
|
| 697 |
+
width=model_input.shape[3] // 2,
|
| 698 |
+
device=accelerator.device,
|
| 699 |
+
dtype=weight_dtype,
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
noise = torch.randn_like(model_input)
|
| 703 |
+
bsz = model_input.shape[0]
|
| 704 |
+
|
| 705 |
+
u = compute_density_for_timestep_sampling(
|
| 706 |
+
weighting_scheme=args.weighting_scheme,
|
| 707 |
+
batch_size=bsz,
|
| 708 |
+
logit_mean=args.logit_mean,
|
| 709 |
+
logit_std=args.logit_std,
|
| 710 |
+
mode_scale=args.mode_scale,
|
| 711 |
+
)
|
| 712 |
+
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
|
| 713 |
+
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
| 714 |
+
|
| 715 |
+
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
| 716 |
+
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
| 717 |
+
|
| 718 |
+
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
|
| 719 |
+
noisy_model_input,
|
| 720 |
+
batch_size=model_input.shape[0],
|
| 721 |
+
num_channels_latents=model_input.shape[1],
|
| 722 |
+
height=model_input.shape[2],
|
| 723 |
+
width=model_input.shape[3],
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
if accelerator.unwrap_model(transformer).config.guidance_embeds:
|
| 727 |
+
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
| 728 |
+
guidance = guidance.expand(model_input.shape[0])
|
| 729 |
+
else:
|
| 730 |
+
guidance = None
|
| 731 |
+
|
| 732 |
+
# If kontext editing is enabled, append source image latents to the sequence
|
| 733 |
+
latent_model_input = packed_noisy_model_input
|
| 734 |
+
if args.kontext == "enable":
|
| 735 |
+
source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
|
| 736 |
+
source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
|
| 737 |
+
source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
|
| 738 |
+
image_latent_h, image_latent_w = source_image_latents.shape[2:]
|
| 739 |
+
packed_image_latents = FluxKontextPipeline._pack_latents(
|
| 740 |
+
source_image_latents,
|
| 741 |
+
batch_size=source_image_latents.shape[0],
|
| 742 |
+
num_channels_latents=source_image_latents.shape[1],
|
| 743 |
+
height=image_latent_h,
|
| 744 |
+
width=image_latent_w,
|
| 745 |
+
)
|
| 746 |
+
source_image_ids = FluxKontextPipeline._prepare_latent_image_ids(
|
| 747 |
+
batch_size=source_image_latents.shape[0],
|
| 748 |
+
height=image_latent_h // 2,
|
| 749 |
+
width=image_latent_w // 2,
|
| 750 |
+
device=accelerator.device,
|
| 751 |
+
dtype=weight_dtype,
|
| 752 |
+
)
|
| 753 |
+
source_image_ids[..., 0] = 1
|
| 754 |
+
latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
|
| 755 |
+
latent_image_ids = torch.cat([latent_image_ids, source_image_ids], dim=0)
|
| 756 |
+
|
| 757 |
+
# Forward transformer with packed latents and ids
|
| 758 |
+
model_pred = transformer(
|
| 759 |
+
hidden_states=latent_model_input,
|
| 760 |
+
timestep=timesteps / 1000,
|
| 761 |
+
guidance=guidance,
|
| 762 |
+
pooled_projections=pooled_prompt_embeds,
|
| 763 |
+
encoder_hidden_states=prompt_embeds,
|
| 764 |
+
txt_ids=text_ids,
|
| 765 |
+
img_ids=latent_image_ids,
|
| 766 |
+
return_dict=False,
|
| 767 |
+
)[0]
|
| 768 |
+
|
| 769 |
+
model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
|
| 770 |
+
|
| 771 |
+
model_pred = FluxKontextPipeline._unpack_latents(
|
| 772 |
+
model_pred,
|
| 773 |
+
height=int(pixel_values.shape[-2]),
|
| 774 |
+
width=int(pixel_values.shape[-1]),
|
| 775 |
+
vae_scale_factor=vae_scale_factor,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 779 |
+
target = noise - model_input
|
| 780 |
+
|
| 781 |
+
loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
|
| 782 |
+
loss = loss.mean()
|
| 783 |
+
accelerator.backward(loss)
|
| 784 |
+
if accelerator.sync_gradients:
|
| 785 |
+
params_to_clip = (transformer.parameters())
|
| 786 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 787 |
+
|
| 788 |
+
optimizer.step()
|
| 789 |
+
lr_scheduler.step()
|
| 790 |
+
optimizer.zero_grad()
|
| 791 |
+
|
| 792 |
+
if accelerator.sync_gradients:
|
| 793 |
+
progress_bar.update(1)
|
| 794 |
+
global_step += 1
|
| 795 |
+
|
| 796 |
+
if accelerator.is_main_process:
|
| 797 |
+
if global_step % args.checkpointing_steps == 0:
|
| 798 |
+
if args.checkpoints_total_limit is not None:
|
| 799 |
+
checkpoints = os.listdir(args.output_dir)
|
| 800 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
| 801 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
| 802 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
| 803 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
| 804 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
| 805 |
+
logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
|
| 806 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
| 807 |
+
for removing_checkpoint in removing_checkpoints:
|
| 808 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
| 809 |
+
shutil.rmtree(removing_checkpoint)
|
| 810 |
+
|
| 811 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 812 |
+
os.makedirs(save_path, exist_ok=True)
|
| 813 |
+
unwrapped = accelerator.unwrap_model(transformer)
|
| 814 |
+
peft_state = get_peft_model_state_dict(unwrapped)
|
| 815 |
+
# Convert PEFT state dict to diffusers LoRA format for transformer
|
| 816 |
+
diffusers_lora = convert_state_dict_to_diffusers(peft_state)
|
| 817 |
+
save_file(diffusers_lora, os.path.join(save_path, "pytorch_lora_weights.safetensors"))
|
| 818 |
+
logger.info(f"Saved state to {save_path}")
|
| 819 |
+
|
| 820 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 821 |
+
progress_bar.set_postfix(**logs)
|
| 822 |
+
accelerator.log(logs, step=global_step)
|
| 823 |
+
|
| 824 |
+
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
| 825 |
+
# Create pipeline on every rank to run validation in parallel
|
| 826 |
+
pipeline = FluxKontextPipeline.from_pretrained(
|
| 827 |
+
args.pretrained_model_name_or_path,
|
| 828 |
+
vae=vae,
|
| 829 |
+
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
| 830 |
+
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
| 831 |
+
transformer=accelerator.unwrap_model(transformer),
|
| 832 |
+
revision=args.revision,
|
| 833 |
+
variant=args.variant,
|
| 834 |
+
torch_dtype=weight_dtype,
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
pipeline_args = {
|
| 838 |
+
"prompt": args.validation_prompt,
|
| 839 |
+
"guidance_scale": 3.5,
|
| 840 |
+
"num_inference_steps": 20,
|
| 841 |
+
"max_sequence_length": 128,
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
images = log_validation(
|
| 845 |
+
pipeline=pipeline,
|
| 846 |
+
args=args,
|
| 847 |
+
accelerator=accelerator,
|
| 848 |
+
pipeline_args=pipeline_args,
|
| 849 |
+
step=global_step,
|
| 850 |
+
torch_dtype=weight_dtype,
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
# Only main process saves/logs
|
| 854 |
+
if accelerator.is_main_process:
|
| 855 |
+
save_path = os.path.join(args.output_dir, "validation")
|
| 856 |
+
os.makedirs(save_path, exist_ok=True)
|
| 857 |
+
save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
|
| 858 |
+
os.makedirs(save_folder, exist_ok=True)
|
| 859 |
+
for idx, img in enumerate(images):
|
| 860 |
+
out_path = os.path.join(save_folder, f"{idx}.jpg")
|
| 861 |
+
save_with_retry(img, out_path)
|
| 862 |
+
del pipeline
|
| 863 |
+
|
| 864 |
+
accelerator.wait_for_everyone()
|
| 865 |
+
accelerator.end_training()
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
if __name__ == "__main__":
|
| 869 |
+
args = parse_args()
|
| 870 |
+
main(args)
|
| 871 |
+
|
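Note: each checkpoint saved by the training loop above is written as pytorch_lora_weights.safetensors in the diffusers LoRA format, so it can typically be loaded back onto a FluxKontextPipeline for inference. The snippet below is only a minimal sketch and is not part of this commit: the checkpoint directory, input image and prompt are placeholders, the base model id is shown as the usual Kontext checkpoint, and it assumes a diffusers version that ships FluxKontextPipeline plus a CUDA device.

# Minimal sketch: run inference with a LoRA produced by the script above (paths and prompt are placeholders).
import torch
from PIL import Image
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load the adapter saved as checkpoint-<step>/pytorch_lora_weights.safetensors
pipe.load_lora_weights("output/checkpoint-1000", weight_name="pytorch_lora_weights.safetensors")

result = pipe(
    image=Image.open("input.png").convert("RGB"),  # source image to edit
    prompt="a short editing instruction",
    guidance_scale=3.5,
    num_inference_steps=20,
).images[0]
result.save("lora_sample.jpg")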
util.py
ADDED
|
@@ -0,0 +1,188 @@
|
| 1 |
+
import random
|
| 2 |
+
from collections import Counter
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
import cv2 # OpenCV
|
| 6 |
+
import torch
|
| 7 |
+
import re
|
| 8 |
+
import io
|
| 9 |
+
import base64
|
| 10 |
+
from PIL import Image, ImageOps
|
| 11 |
+
from src.pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
|
| 12 |
+
|
| 13 |
+
def get_bounding_box_from_mask(mask, padded=False):
|
| 14 |
+
mask = mask.squeeze()
|
| 15 |
+
rows, cols = torch.where(mask > 0.5)
|
| 16 |
+
if len(rows) == 0 or len(cols) == 0:
|
| 17 |
+
return (0, 0, 0, 0)
|
| 18 |
+
height, width = mask.shape
|
| 19 |
+
if padded:
|
| 20 |
+
padded_size = max(width, height)
|
| 21 |
+
if width < height:
|
| 22 |
+
offset_x = (padded_size - width) / 2
|
| 23 |
+
offset_y = 0
|
| 24 |
+
else:
|
| 25 |
+
offset_y = (padded_size - height) / 2
|
| 26 |
+
offset_x = 0
|
| 27 |
+
top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
|
| 28 |
+
bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
|
| 29 |
+
top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
|
| 30 |
+
bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
|
| 31 |
+
else:
|
| 32 |
+
offset_x = 0
|
| 33 |
+
offset_y = 0
|
| 34 |
+
|
| 35 |
+
top_left_x = round(float(torch.min(cols).item() / width), 3)
|
| 36 |
+
bottom_right_x = round(float(torch.max(cols).item() / width), 3)
|
| 37 |
+
top_left_y = round(float(torch.min(rows).item() / height), 3)
|
| 38 |
+
bottom_right_y = round(float(torch.max(rows).item() / height), 3)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
| 42 |
+
|
| 43 |
+
def extract_bbox(text):
|
| 44 |
+
pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
|
| 45 |
+
match = re.search(pattern, text)
|
| 46 |
+
return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))
|
| 47 |
+
|
| 48 |
+
def resize_bbox(bbox, width_ratio, height_ratio):
|
| 49 |
+
x1, y1, x2, y2 = bbox
|
| 50 |
+
new_x1 = int(x1 * width_ratio)
|
| 51 |
+
new_y1 = int(y1 * height_ratio)
|
| 52 |
+
new_x2 = int(x2 * width_ratio)
|
| 53 |
+
new_y2 = int(y2 * height_ratio)
|
| 54 |
+
|
| 55 |
+
return (new_x1, new_y1, new_x2, new_y2)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def tensor_to_base64(tensor, quality=80, method=6):
|
| 59 |
+
tensor = tensor.squeeze(0).clone().detach().cpu()
|
| 60 |
+
|
| 61 |
+
if tensor.dtype == torch.float32 or tensor.dtype == torch.float64 or tensor.dtype == torch.float16:
|
| 62 |
+
tensor *= 255
|
| 63 |
+
tensor = tensor.to(torch.uint8)
|
| 64 |
+
|
| 65 |
+
if tensor.ndim == 2: # grayscale image
|
| 66 |
+
pil_image = Image.fromarray(tensor.numpy(), 'L')
|
| 67 |
+
pil_image = pil_image.convert('RGB')
|
| 68 |
+
elif tensor.ndim == 3:
|
| 69 |
+
if tensor.shape[2] == 1: # single channel
|
| 70 |
+
pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
|
| 71 |
+
pil_image = pil_image.convert('RGB')
|
| 72 |
+
elif tensor.shape[2] == 3: # RGB
|
| 73 |
+
pil_image = Image.fromarray(tensor.numpy(), 'RGB')
|
| 74 |
+
elif tensor.shape[2] == 4: # RGBA
|
| 75 |
+
pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
|
| 76 |
+
else:
|
| 77 |
+
raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
|
| 80 |
+
|
| 81 |
+
buffered = io.BytesIO()
|
| 82 |
+
pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
|
| 83 |
+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 84 |
+
return img_str
|
| 85 |
+
|
| 86 |
+
def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
|
| 87 |
+
image = Image.open(image_path)
|
| 88 |
+
image = ImageOps.exif_transpose(image)
|
| 89 |
+
|
| 90 |
+
if image.mode == 'RGBA':
|
| 91 |
+
background = Image.new('RGBA', image.size, (255, 255, 255, 255))
|
| 92 |
+
image = Image.alpha_composite(background, image)
|
| 93 |
+
image = image.convert(convert_to)
|
| 94 |
+
image_array = np.array(image).astype(np.float32) / 255.0
|
| 95 |
+
|
| 96 |
+
if has_alpha and convert_to == 'RGBA':
|
| 97 |
+
image_tensor = torch.from_numpy(image_array)[None,]
|
| 98 |
+
else:
|
| 99 |
+
if len(image_array.shape) == 3 and image_array.shape[2] > 3:
|
| 100 |
+
image_array = image_array[:, :, :3]
|
| 101 |
+
image_tensor = torch.from_numpy(image_array)[None,]
|
| 102 |
+
|
| 103 |
+
return image_tensor
|
| 104 |
+
|
| 105 |
+
def process_background(base64_image, convert_to='RGB', size=None):
|
| 106 |
+
image_data = read_base64_image(base64_image)
|
| 107 |
+
image = Image.open(image_data)
|
| 108 |
+
image = ImageOps.exif_transpose(image)
|
| 109 |
+
image = image.convert(convert_to)
|
| 110 |
+
|
| 111 |
+
# Select preferred size by closest aspect ratio, then snap to multiple_of
|
| 112 |
+
w0, h0 = image.size
|
| 113 |
+
aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
|
| 114 |
+
# Choose the (w, h) whose aspect ratio is closest to the input
|
| 115 |
+
_, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
|
| 116 |
+
multiple_of = 16 # default: vae_scale_factor (8) * 2
|
| 117 |
+
tw = (tw // multiple_of) * multiple_of
|
| 118 |
+
th = (th // multiple_of) * multiple_of
|
| 119 |
+
|
| 120 |
+
if (w0, h0) != (tw, th):
|
| 121 |
+
image = image.resize((tw, th), resample=Image.BICUBIC)
|
| 122 |
+
|
| 123 |
+
image_array = np.array(image).astype(np.uint8)
|
| 124 |
+
image_tensor = torch.from_numpy(image_array)[None,]
|
| 125 |
+
return image_tensor
|
| 126 |
+
|
| 127 |
+
def read_base64_image(base64_image):
|
| 128 |
+
if base64_image.startswith("data:image/png;base64,"):
|
| 129 |
+
base64_image = base64_image.split(",")[1]
|
| 130 |
+
elif base64_image.startswith("data:image/jpeg;base64,"):
|
| 131 |
+
base64_image = base64_image.split(",")[1]
|
| 132 |
+
elif base64_image.startswith("data:image/webp;base64,"):
|
| 133 |
+
base64_image = base64_image.split(",")[1]
|
| 134 |
+
else:
|
| 135 |
+
raise ValueError("Unsupported image format.")
|
| 136 |
+
image_data = base64.b64decode(base64_image)
|
| 137 |
+
return io.BytesIO(image_data)
|
| 138 |
+
|
| 139 |
+
def create_alpha_mask(image_path):
|
| 140 |
+
"""Create an alpha mask from the alpha channel of an image."""
|
| 141 |
+
image = Image.open(image_path)
|
| 142 |
+
image = ImageOps.exif_transpose(image)
|
| 143 |
+
mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
|
| 144 |
+
if 'A' in image.getbands():
|
| 145 |
+
alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
|
| 146 |
+
mask[0] = 1.0 - torch.from_numpy(alpha_channel)
|
| 147 |
+
return mask
|
| 148 |
+
|
| 149 |
+
def get_mask_bbox(mask_tensor, padding=10):
|
| 150 |
+
assert len(mask_tensor.shape) == 3 and mask_tensor.shape[0] == 1
|
| 151 |
+
_, H, W = mask_tensor.shape
|
| 152 |
+
mask_2d = mask_tensor.squeeze(0)
|
| 153 |
+
|
| 154 |
+
y_coords, x_coords = torch.where(mask_2d > 0)
|
| 155 |
+
|
| 156 |
+
if len(y_coords) == 0:
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
x_min = int(torch.min(x_coords))
|
| 160 |
+
y_min = int(torch.min(y_coords))
|
| 161 |
+
x_max = int(torch.max(x_coords))
|
| 162 |
+
y_max = int(torch.max(y_coords))
|
| 163 |
+
|
| 164 |
+
x_min = max(0, x_min - padding)
|
| 165 |
+
y_min = max(0, y_min - padding)
|
| 166 |
+
x_max = min(W - 1, x_max + padding)
|
| 167 |
+
y_max = min(H - 1, y_max + padding)
|
| 168 |
+
|
| 169 |
+
return x_min, y_min, x_max, y_max
|
| 170 |
+
|
| 171 |
+
def tensor_to_pil(tensor):
|
| 172 |
+
tensor = tensor.squeeze(0).clone().detach().cpu()
|
| 173 |
+
if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
|
| 174 |
+
if tensor.max() <= 1.0:
|
| 175 |
+
tensor *= 255
|
| 176 |
+
tensor = tensor.to(torch.uint8)
|
| 177 |
+
|
| 178 |
+
if tensor.ndim == 2: # grayscale image [H, W]
|
| 179 |
+
return Image.fromarray(tensor.numpy(), 'L')
|
| 180 |
+
elif tensor.ndim == 3:
|
| 181 |
+
if tensor.shape[2] == 1: # single channel [H, W, 1]
|
| 182 |
+
return Image.fromarray(tensor.numpy().squeeze(2), 'L')
|
| 183 |
+
elif tensor.shape[2] >= 3: # RGB or RGBA [H, W, >=3]
|
| 184 |
+
return Image.fromarray(tensor.numpy()[:, :, :3], 'RGB') # keep only the first three channels so the array matches mode 'RGB'
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(f"不支持的通道数: {tensor.shape[2]}")
|
| 187 |
+
else:
|
| 188 |
+
raise ValueError(f"不支持的tensor维度: {tensor.ndim}")
|
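As a quick illustration of the helper above, get_bounding_box_from_mask returns (x1, y1, x2, y2) normalized to [0, 1]. The toy values below are made up; running the snippet assumes the Space's requirements are installed, since importing util.py pulls in cv2, torchvision and the local src package.

# Toy usage of util.get_bounding_box_from_mask (hypothetical mask, run from the Space root).
import torch
from util import get_bounding_box_from_mask

mask = torch.zeros(1, 100, 200)          # [1, H, W]
mask[0, 20:40, 50:100] = 1.0             # a 20x50 foreground patch
print(get_bounding_box_from_mask(mask))  # -> (0.25, 0.2, 0.495, 0.39)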
utils_node.py
ADDED
|
@@ -0,0 +1,199 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import trange
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
import scipy.ndimage
|
| 10 |
+
import cv2
|
| 11 |
+
from train.src.condition.util import HWC3, common_input_validate
|
| 12 |
+
|
| 13 |
+
def check_image_mask(image, mask, name):
|
| 14 |
+
if len(image.shape) < 4:
|
| 15 |
+
# image tensor shape should be [B, H, W, C], but batch somehow is missing
|
| 16 |
+
image = image[None,:,:,:]
|
| 17 |
+
|
| 18 |
+
if len(mask.shape) > 3:
|
| 19 |
+
# mask tensor shape should be [B, H, W] but we got [B, H, W, C]; it may actually be an image
|
| 20 |
+
# take first mask, red channel
|
| 21 |
+
mask = (mask[:,:,:,0])[:,:,:]
|
| 22 |
+
elif len(mask.shape) < 3:
|
| 23 |
+
# mask tensor shape should be [B, H, W] but batch somehow is missing
|
| 24 |
+
mask = mask[None,:,:]
|
| 25 |
+
|
| 26 |
+
if image.shape[0] > mask.shape[0]:
|
| 27 |
+
print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
|
| 28 |
+
if mask.shape[0] == 1:
|
| 29 |
+
print(name, "will copy the mask to fill batch")
|
| 30 |
+
mask = torch.cat([mask] * image.shape[0], dim=0)
|
| 31 |
+
else:
|
| 32 |
+
print(name, "will add empty masks to fill batch")
|
| 33 |
+
empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
|
| 34 |
+
mask = torch.cat([mask, empty_mask], dim=0)
|
| 35 |
+
elif image.shape[0] < mask.shape[0]:
|
| 36 |
+
print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
|
| 37 |
+
mask = mask[:image.shape[0],:,:]
|
| 38 |
+
|
| 39 |
+
return (image, mask)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def cv2_resize_shortest_edge(image, size):
|
| 43 |
+
h, w = image.shape[:2]
|
| 44 |
+
if h < w:
|
| 45 |
+
new_h = size
|
| 46 |
+
new_w = int(round(w / h * size))
|
| 47 |
+
else:
|
| 48 |
+
new_w = size
|
| 49 |
+
new_h = int(round(h / w * size))
|
| 50 |
+
resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 51 |
+
return resized_image
|
| 52 |
+
|
| 53 |
+
def apply_color(img, res=512):
|
| 54 |
+
img = cv2_resize_shortest_edge(img, res)
|
| 55 |
+
h, w = img.shape[:2]
|
| 56 |
+
|
| 57 |
+
input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
|
| 58 |
+
input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
|
| 59 |
+
return input_img_color
|
| 60 |
+
|
| 61 |
+
# Color preprocessor (T2I-Adapter "color" style): works on multiples of 64; upscale methods are fixed.
|
| 62 |
+
class ColorDetector:
|
| 63 |
+
def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
|
| 64 |
+
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
|
| 65 |
+
input_image = HWC3(input_image)
|
| 66 |
+
detected_map = HWC3(apply_color(input_image, detect_resolution))
|
| 67 |
+
|
| 68 |
+
if output_type == "pil":
|
| 69 |
+
detected_map = Image.fromarray(detected_map)
|
| 70 |
+
|
| 71 |
+
return detected_map
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class InpaintPreprocessor:
|
| 75 |
+
def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
|
| 76 |
+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
|
| 77 |
+
mask = mask.movedim(1,-1).expand((-1,-1,-1,3))
|
| 78 |
+
image = image.clone()
|
| 79 |
+
if black_pixel_for_xinsir_cn:
|
| 80 |
+
masked_pixel = 0.0
|
| 81 |
+
else:
|
| 82 |
+
masked_pixel = -1.0
|
| 83 |
+
image[mask > 0.5] = masked_pixel
|
| 84 |
+
return (image,)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BlendInpaint:
|
| 88 |
+
def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
|
| 89 |
+
|
| 90 |
+
original, mask = check_image_mask(original, mask, 'Blend Inpaint')
|
| 91 |
+
|
| 92 |
+
if len(inpaint.shape) < 4:
|
| 93 |
+
# image tensor shape should be [B, H, W, C], but batch somehow is missing
|
| 94 |
+
inpaint = inpaint[None,:,:,:]
|
| 95 |
+
|
| 96 |
+
if inpaint.shape[0] < original.shape[0]:
|
| 97 |
+
print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
|
| 98 |
+
original = original[:inpaint.shape[0], :, :]
|
| 99 |
+
mask = mask[:inpaint.shape[0],:,:]
|
| 100 |
+
|
| 101 |
+
if inpaint.shape[0] > original.shape[0]:
|
| 102 |
+
# batch over inpaint
|
| 103 |
+
count = 0
|
| 104 |
+
original_list = []
|
| 105 |
+
mask_list = []
|
| 106 |
+
origin_list = []
|
| 107 |
+
while (count < inpaint.shape[0]):
|
| 108 |
+
for i in range(original.shape[0]):
|
| 109 |
+
original_list.append(original[i][None,:,:,:])
|
| 110 |
+
mask_list.append(mask[i][None,:,:])
|
| 111 |
+
if origin is not None:
|
| 112 |
+
origin_list.append(origin[i][None,:])
|
| 113 |
+
count += 1
|
| 114 |
+
if count >= inpaint.shape[0]:
|
| 115 |
+
break
|
| 116 |
+
original = torch.concat(original_list, dim=0)
|
| 117 |
+
mask = torch.concat(mask_list, dim=0)
|
| 118 |
+
if origin is not None:
|
| 119 |
+
origin = torch.concat(origin_list, dim=0)
|
| 120 |
+
|
| 121 |
+
if kernel % 2 == 0:
|
| 122 |
+
kernel += 1
|
| 123 |
+
transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
|
| 124 |
+
|
| 125 |
+
ret = []
|
| 126 |
+
blurred = []
|
| 127 |
+
for i in range(inpaint.shape[0]):
|
| 128 |
+
if origin is None:
|
| 129 |
+
blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
|
| 130 |
+
blurred.append(blurred_mask[0])
|
| 131 |
+
|
| 132 |
+
result = torch.nn.functional.interpolate(
|
| 133 |
+
inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
|
| 134 |
+
size=(
|
| 135 |
+
original[i].shape[0],
|
| 136 |
+
original[i].shape[1],
|
| 137 |
+
)
|
| 138 |
+
).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
|
| 139 |
+
else:
|
| 140 |
+
# got mask from CutForInpaint
|
| 141 |
+
height, width, _ = original[i].shape
|
| 142 |
+
x0 = origin[i][0].item()
|
| 143 |
+
y0 = origin[i][1].item()
|
| 144 |
+
|
| 145 |
+
if mask[i].shape[0] < height or mask[i].shape[1] < width:
|
| 146 |
+
padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
|
| 147 |
+
y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
|
| 148 |
+
else:
|
| 149 |
+
padded_mask = mask[i]
|
| 150 |
+
blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
|
| 151 |
+
blurred.append(blurred_mask[0][0])
|
| 152 |
+
|
| 153 |
+
result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
|
| 154 |
+
y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
|
| 155 |
+
result = result[None,:,:,:].to(original.device).to(original.dtype)
|
| 156 |
+
|
| 157 |
+
ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
|
| 158 |
+
|
| 159 |
+
return (torch.stack(ret), torch.stack(blurred), )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def resize_mask(mask, shape):
|
| 163 |
+
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
| 164 |
+
|
| 165 |
+
class JoinImageWithAlpha:
|
| 166 |
+
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
|
| 167 |
+
batch_size = min(len(image), len(alpha))
|
| 168 |
+
out_images = []
|
| 169 |
+
|
| 170 |
+
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
| 171 |
+
for i in range(batch_size):
|
| 172 |
+
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
| 173 |
+
|
| 174 |
+
result = (torch.stack(out_images),)
|
| 175 |
+
return result
|
| 176 |
+
|
| 177 |
+
class GrowMask:
|
| 178 |
+
def expand_mask(self, mask, expand, tapered_corners):
|
| 179 |
+
c = 0 if tapered_corners else 1
|
| 180 |
+
kernel = np.array([[c, 1, c],
|
| 181 |
+
[1, 1, 1],
|
| 182 |
+
[c, 1, c]])
|
| 183 |
+
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
| 184 |
+
out = []
|
| 185 |
+
for m in mask:
|
| 186 |
+
output = m.numpy()
|
| 187 |
+
for _ in range(abs(expand)):
|
| 188 |
+
if expand < 0:
|
| 189 |
+
output = scipy.ndimage.grey_erosion(output, footprint=kernel)
|
| 190 |
+
else:
|
| 191 |
+
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
| 192 |
+
output = torch.from_numpy(output)
|
| 193 |
+
out.append(output)
|
| 194 |
+
return (torch.stack(out, dim=0),)
|
| 195 |
+
|
| 196 |
+
class InvertMask:
|
| 197 |
+
def invert(self, mask):
|
| 198 |
+
out = 1.0 - mask
|
| 199 |
+
return (out,)
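
As a small usage sketch for the mask helpers above (toy values, not part of the commit; importing utils_node.py assumes the train.src.condition package and scipy are available):

# Dilate a [B, H, W] mask with GrowMask, then invert it with InvertMask.
import torch
from utils_node import GrowMask, InvertMask

mask = torch.zeros(1, 64, 64)
mask[0, 28:36, 28:36] = 1.0

grown, = GrowMask().expand_mask(mask, expand=4, tapered_corners=True)  # roughly 4 px of dilation
inverted, = InvertMask().invert(grown)                                 # 1.0 - mask
print(grown.shape, float(inverted.min()))                              # torch.Size([1, 64, 64]) 0.0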
|