LiuZichen committed
Commit
f460ce6
1 Parent(s): c57bc42
Files changed (41)
  1. README.md +2 -2
  2. app.py +439 -4
  3. edit_space.py +461 -0
  4. requirements.txt +28 -0
  5. src/__init__.py +0 -0
  6. src/layers_cache.py +406 -0
  7. src/lora_helper.py +194 -0
  8. src/pipeline_flux_kontext_control.py +1230 -0
  9. src/transformer_flux.py +608 -0
  10. train/default_config.yaml +16 -0
  11. train/src/__init__.py +0 -0
  12. train/src/condition/edge_extraction.py +356 -0
  13. train/src/condition/hed.py +56 -0
  14. train/src/condition/informative_drawing.py +279 -0
  15. train/src/condition/lineart.py +86 -0
  16. train/src/condition/pidi.py +681 -0
  17. train/src/condition/ted.py +296 -0
  18. train/src/condition/util.py +202 -0
  19. train/src/generate_diff_mask.py +301 -0
  20. train/src/jsonl_datasets_kontext_color.py +166 -0
  21. train/src/jsonl_datasets_kontext_complete_lora.py +363 -0
  22. train/src/jsonl_datasets_kontext_edge.py +225 -0
  23. train/src/jsonl_datasets_kontext_interactive_lora.py +1332 -0
  24. train/src/jsonl_datasets_kontext_local.py +312 -0
  25. train/src/layers.py +279 -0
  26. train/src/lora_helper.py +196 -0
  27. train/src/masks_integrated.py +322 -0
  28. train/src/pipeline_flux_kontext_control.py +1009 -0
  29. train/src/prompt_helper.py +205 -0
  30. train/src/transformer_flux.py +625 -0
  31. train/train_kontext_color.py +858 -0
  32. train/train_kontext_color.sh +25 -0
  33. train/train_kontext_complete_lora.sh +20 -0
  34. train/train_kontext_edge.py +814 -0
  35. train/train_kontext_edge.sh +25 -0
  36. train/train_kontext_interactive_lora.sh +18 -0
  37. train/train_kontext_local.py +876 -0
  38. train/train_kontext_local.sh +26 -0
  39. train/train_kontext_lora.py +871 -0
  40. util.py +188 -0
  41. utils_node.py +199 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: MagicQuillV2
- emoji: 🏆
+ emoji: 🪶
  colorFrom: blue
  colorTo: blue
  sdk: gradio
- sdk_version: 6.0.1
+ sdk_version: 5.4.0
  app_file: app.py
  pinned: false
  ---
app.py CHANGED
@@ -1,7 +1,442 @@
 
 
  import gradio as gr

- def greet(name):
- return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
1
+ import sys
2
+ import os
3
  import gradio as gr
4
+ import spaces
5
+ import tempfile
6
+ import numpy as np
7
+ import io
8
+ import base64
9
+ from gradio_client import Client, handle_file
10
+ from huggingface_hub import snapshot_download
11
+ from gradio_magicquillv2 import MagicQuillV2
12
+ from fastapi import FastAPI, Request
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ import uvicorn
15
+ import requests
16
+ from PIL import Image, ImageOps
17
+ import random
18
+ import time
19
+ import torch
20
+ import json
21
 
22
+ # Local editing model and helper utilities
23
+ from edit_space import KontextEditModel
24
+ from util import (
25
+ load_and_preprocess_image,
26
+ read_base64_image as read_base64_image_utils,
27
+ create_alpha_mask,
28
+ tensor_to_base64,
29
+ get_mask_bbox
30
+ )
31
 
32
+ # Initialize models
33
+ print("Downloading models...")
34
+ hf_token = os.environ.get("hf_token")
35
+ snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models", token=hf_token)
36
+
37
+ print("Initializing models...")
38
+ kontext_model = KontextEditModel()
39
+
40
+ # Initialize the client for the remote SAM helper Space
+ sam_client = Client("LiuZichen/MagicQuillHelper")
43
+ print("Models initialized.")
44
+
45
+ css = """
46
+ .ms {
47
+ width: 60%;
48
+ margin: auto
49
+ }
50
+ """
51
+
52
+ url = "http://localhost:7860"
53
+
54
+ @spaces.GPU
55
+ def generate(merged_image, total_mask, original_image, add_color_image, add_edge_mask, remove_edge_mask, fill_mask, add_prop_image, positive_prompt, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
56
+ print("prompt is:", positive_prompt)
57
+ print("other parameters:", negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg)
58
+
59
+ if kontext_model is None:
60
+ raise RuntimeError("KontextEditModel not initialized")
61
+
62
+ # Preprocess inputs: read_base64_image returns a BytesIO object, which both
+ # create_alpha_mask and load_and_preprocess_image accept via Image.open.
66
+
67
+ merged_image_tensor = load_and_preprocess_image(read_base64_image_utils(merged_image))
68
+ total_mask_tensor = create_alpha_mask(read_base64_image_utils(total_mask))
69
+ original_image_tensor = load_and_preprocess_image(read_base64_image_utils(original_image))
70
+
71
+ if add_color_image:
72
+ add_color_image_tensor = load_and_preprocess_image(read_base64_image_utils(add_color_image))
73
+ else:
74
+ add_color_image_tensor = original_image_tensor
75
+
76
+ add_mask = create_alpha_mask(read_base64_image_utils(add_edge_mask)) if add_edge_mask else torch.zeros_like(total_mask_tensor)
77
+ remove_mask = create_alpha_mask(read_base64_image_utils(remove_edge_mask)) if remove_edge_mask else torch.zeros_like(total_mask_tensor)
78
+ add_prop_mask = create_alpha_mask(read_base64_image_utils(add_prop_image)) if add_prop_image else torch.zeros_like(total_mask_tensor)
79
+ fill_mask_tensor = create_alpha_mask(read_base64_image_utils(fill_mask)) if fill_mask else torch.zeros_like(total_mask_tensor)
80
+
81
+ # Determine flag and modify prompt
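+ # "foreground": an object was dragged in (add_prop mask); "local": inpaint the fill mask;
+ # "removal": only erase strokes were drawn; "precise_edit": edge/color guided edit;
+ # "kontext": plain prompt-driven edit with no spatial control.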
82
+ flag = "kontext"
83
+ if torch.sum(add_prop_mask) > 0:
84
+ flag = "foreground"
85
+ positive_prompt = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary. " + positive_prompt
86
+ elif torch.sum(fill_mask_tensor).item() > 0:
87
+ flag = "local"
88
+ elif (torch.sum(remove_mask).item() > 0 and torch.sum(add_mask).item() == 0):
89
+ positive_prompt = "remove the instance"
90
+ flag = "removal"
91
+ elif (torch.sum(add_mask).item() > 0 or torch.sum(remove_mask).item() > 0 or (not torch.equal(original_image_tensor, add_color_image_tensor))):
92
+ flag = "precise_edit"
93
+
94
+ print("positive prompt: ", positive_prompt)
95
+ print("current flag: ", flag)
96
+
97
+ final_image, condition, mask = kontext_model.process(
98
+ original_image_tensor,
99
+ add_color_image_tensor,
100
+ merged_image_tensor,
101
+ positive_prompt,
102
+ total_mask_tensor,
103
+ add_mask,
104
+ remove_mask,
105
+ add_prop_mask,
106
+ fill_mask_tensor,
107
+ fine_edge,
108
+ fix_perspective,
109
+ edge_strength,
110
+ color_strength,
111
+ local_strength,
112
+ grow_size,
113
+ seed,
114
+ steps,
115
+ cfg,
116
+ flag,
117
+ )
118
+
119
+ # tensor_to_base64 returns pure base64 string
120
+ res_base64 = tensor_to_base64(final_image)
121
+ return res_base64
122
+
123
+ def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
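+ # x is the MagicQuillV2 component value: the editor sends images and masks under
+ # x['from_frontend'], and the generated result is written back to x['from_backend'].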
124
+ merged_image = x['from_frontend']['img']
125
+ total_mask = x['from_frontend']['total_mask']
126
+ original_image = x['from_frontend']['original_image']
127
+ add_color_image = x['from_frontend']['add_color_image']
128
+ add_edge_mask = x['from_frontend']['add_edge_mask']
129
+ remove_edge_mask = x['from_frontend']['remove_edge_mask']
130
+ fill_mask = x['from_frontend']['fill_mask']
131
+ add_prop_image = x['from_frontend']['add_prop_image']
132
+ positive_prompt = x['from_backend']['prompt']
133
+
134
+ try:
135
+ res_base64 = generate(
136
+ merged_image,
137
+ total_mask,
138
+ original_image,
139
+ add_color_image,
140
+ add_edge_mask,
141
+ remove_edge_mask,
142
+ fill_mask,
143
+ add_prop_image,
144
+ positive_prompt,
145
+ negative_prompt,
146
+ fine_edge,
147
+ fix_perspective,
148
+ grow_size,
149
+ edge_strength,
150
+ color_strength,
151
+ local_strength,
152
+ seed,
153
+ steps,
154
+ cfg
155
+ )
156
+ x["from_backend"]["generated_image"] = res_base64
157
+ except Exception as e:
158
+ print(f"Error in generation: {e}")
159
+ x["from_backend"]["generated_image"] = None
160
+
161
+ return x
162
+
163
+
164
+ with gr.Blocks(title="MagicQuill V2") as demo:
165
+ with gr.Row():
166
+ ms = MagicQuillV2()
167
+
168
+ with gr.Row():
169
+ with gr.Column():
170
+ btn = gr.Button("Run", variant="primary")
171
+ with gr.Column():
172
+ with gr.Accordion("parameters", open=False):
173
+ negative_prompt = gr.Textbox(
174
+ label="Negative Prompt",
175
+ value="",
176
+ interactive=True
177
+ )
178
+ fine_edge = gr.Radio(
179
+ label="Fine Edge",
180
+ choices=['enable', 'disable'],
181
+ value='disable',
182
+ interactive=True
183
+ )
184
+ fix_perspective = gr.Radio(
185
+ label="Fix Perspective",
186
+ choices=['enable', 'disable'],
187
+ value='disable',
188
+ interactive=True
189
+ )
190
+ grow_size = gr.Slider(
191
+ label="Grow Size",
192
+ minimum=10,
193
+ maximum=100,
194
+ value=50,
195
+ step=1,
196
+ interactive=True
197
+ )
198
+ edge_strength = gr.Slider(
199
+ label="Edge Strength",
200
+ minimum=0.0,
201
+ maximum=5.0,
202
+ value=0.6,
203
+ step=0.01,
204
+ interactive=True
205
+ )
206
+ color_strength = gr.Slider(
207
+ label="Color Strength",
208
+ minimum=0.0,
209
+ maximum=5.0,
210
+ value=1.5,
211
+ step=0.01,
212
+ interactive=True
213
+ )
214
+ local_strength = gr.Slider(
215
+ label="Local Strength",
216
+ minimum=0.0,
217
+ maximum=5.0,
218
+ value=1.0,
219
+ step=0.01,
220
+ interactive=True
221
+ )
222
+ seed = gr.Number(
223
+ label="Seed",
224
+ value=-1,
225
+ precision=0,
226
+ interactive=True
227
+ )
228
+ steps = gr.Slider(
229
+ label="Steps",
230
+ minimum=0,
231
+ maximum=50,
232
+ value=20,
233
+ interactive=True
234
+ )
235
+ cfg = gr.Slider(
236
+ label="CFG",
237
+ minimum=0.0,
238
+ maximum=20.0,
239
+ value=3.5,
240
+ step=0.1,
241
+ interactive=True
242
+ )
243
+
244
+ btn.click(generate_image_handler, inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg], outputs=ms)
245
+
246
+ app = FastAPI()
247
+ app.add_middleware(
248
+ CORSMiddleware,
249
+ allow_origins=['*'],
250
+ allow_credentials=True,
251
+ allow_methods=["*"],
252
+ allow_headers=["*"],
253
+ )
254
+
255
+ def get_root_url(
256
+ request: Request, route_path: str, root_path: str | None
257
+ ):
258
+ print(root_path)
259
+ return root_path
260
+ import gradio.route_utils
261
+ gr.route_utils.get_root_url = get_root_url
262
+
263
+ gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")
264
+
265
+ @app.post("/magic_quill/generate_image")
266
+ async def generate_image(request: Request):
267
+ data = await request.json()
268
+ res = generate(
269
+ data["merged_image"],
270
+ data["total_mask"],
271
+ data["original_image"],
272
+ data["add_color_image"],
273
+ data["add_edge_mask"],
274
+ data["remove_edge_mask"],
275
+ data["fill_mask"],
276
+ data["add_prop_image"],
277
+ data["positive_prompt"],
278
+ data["negative_prompt"],
279
+ data["fine_edge"],
280
+ data["fix_perspective"],
281
+ data["grow_size"],
282
+ data["edge_strength"],
283
+ data["color_strength"],
284
+ data["local_strength"],
285
+ data["seed"],
286
+ data["steps"],
287
+ data["cfg"]
288
+ )
289
+ return {'res': res}
290
+
291
+ @app.post("/magic_quill/process_background_img")
292
+ async def process_background_img(request: Request):
293
+ img = await request.json()
294
+ from util import process_background
295
+ # process_background returns tensor [1, H, W, 3] in uint8 or float
296
+ resized_img_tensor = process_background(img)
297
+
298
+ # tensor_to_base64 from util expects tensor
299
+ resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
300
+ resized_img_tensor,
301
+ quality=80,
302
+ method=6
303
+ )
304
+ return resized_img_base64
305
+
306
+ @app.post("/magic_quill/segmentation")
307
+ async def segmentation(request: Request):
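+ # Delegate SAM inference to the remote helper Space: save the input image to a
+ # temporary WebP file, request a mask via the /segment API, then merge the mask
+ # back into the image as an alpha channel and return it as a base64 PNG.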
308
+ json_data = await request.json()
309
+ image_base64 = json_data.get("image", None)
310
+ coordinates_positive = json_data.get("coordinates_positive", None)
311
+ coordinates_negative = json_data.get("coordinates_negative", None)
312
+ bboxes = json_data.get("bboxes", None)
313
+
314
+ if sam_client is None:
315
+ return {"error": "sam client not initialized"}
316
+
317
+ # Process coordinates and bboxes
318
+ pos_coordinates = None
319
+ if coordinates_positive and len(coordinates_positive) > 0:
320
+ pos_coordinates = []
321
+ for coord in coordinates_positive:
322
+ coord['x'] = int(round(coord['x']))
323
+ coord['y'] = int(round(coord['y']))
324
+ pos_coordinates.append({'x': coord['x'], 'y': coord['y']})
325
+ pos_coordinates = json.dumps(pos_coordinates)
326
+
327
+ neg_coordinates = None
328
+ if coordinates_negative and len(coordinates_negative) > 0:
329
+ neg_coordinates = []
330
+ for coord in coordinates_negative:
331
+ coord['x'] = int(round(coord['x']))
332
+ coord['y'] = int(round(coord['y']))
333
+ neg_coordinates.append({'x': coord['x'], 'y': coord['y']})
334
+ neg_coordinates = json.dumps(neg_coordinates)
335
+
336
+ bboxes_xyxy = None
337
+ if bboxes and len(bboxes) > 0:
338
+ valid_bboxes = []
339
+ for bbox in bboxes:
340
+ if (bbox.get("startX") is None or
341
+ bbox.get("startY") is None or
342
+ bbox.get("endX") is None or
343
+ bbox.get("endY") is None):
344
+ continue
345
+ else:
346
+ x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
347
+ y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
348
+ # Image dimensions are not available here, so clamp only the lower bound to 0
+ # and let the SAM Space handle any upper-bound clipping.
+ x_max = max(int(bbox["startX"]), int(bbox["endX"]))
+ y_max = max(int(bbox["startY"]), int(bbox["endY"]))
353
+ valid_bboxes.append((x_min, y_min, x_max, y_max))
354
+
355
+ bboxes_xyxy = valid_bboxes
+
+ # Serialize as a JSON string for consistency with the coordinate arguments
362
+ if bboxes_xyxy:
363
+ bboxes_xyxy = json.dumps(bboxes_xyxy)
364
+
365
+ print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
366
+
367
+ try:
368
+ # Save base64 image to temp file
369
+ image_bytes = read_base64_image_utils(image_base64)
370
+ # Image.open to verify and save as WebP (smaller size)
371
+ pil_image = Image.open(image_bytes)
372
+ with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
373
+ pil_image.save(temp_in.name, format="WEBP", quality=80)
374
+ temp_in_path = temp_in.name
375
+
376
+ # Execute segmentation via Client
377
+ # The remote Space returns a filepath to the predicted mask image (white = selected, black = background)
379
+ result_path = sam_client.predict(
380
+ handle_file(temp_in_path),
381
+ pos_coordinates,
382
+ neg_coordinates,
383
+ bboxes_xyxy,
384
+ api_name="/segment"
385
+ )
386
+
387
+ # Clean up input temp
388
+ os.unlink(temp_in_path)
389
+
390
+ # Process result
391
+ # result_path should be a generic object, usually a tuple (image_path, mask_path) or just image_path
392
+ # Depending on how the remote space is implemented.
393
+ if isinstance(result_path, (list, tuple)):
394
+ result_path = result_path[0] # Take the first return value if multiple
395
+
396
+ if not result_path or not os.path.exists(result_path):
397
+ raise RuntimeError("Client returned invalid result path")
398
+
399
+ # result_path is the Mask Image (White=Selected, Black=Background)
400
+ mask_pil = Image.open(result_path)
401
+ if mask_pil.mode != 'L':
402
+ mask_pil = mask_pil.convert('L')
403
+
404
+ pil_image = pil_image.convert("RGB")
405
+ if pil_image.size != mask_pil.size:
406
+ mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
407
+
408
+ r, g, b = pil_image.split()
409
+ res_pil = Image.merge("RGBA", (r, g, b, mask_pil))
410
+
411
+ # Extract bbox from mask (alpha)
412
+ mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
413
+ mask_bbox = get_mask_bbox(mask_tensor)
414
+ if mask_bbox:
415
+ x_min, y_min, x_max, y_max = mask_bbox
416
+ seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
417
+ else:
418
+ seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
419
+
420
+ print(seg_bbox)
421
+
422
+ # Convert result to base64
423
+ # We need to convert the PIL image to base64 string
424
+ buffered = io.BytesIO()
425
+ res_pil.save(buffered, format="PNG")
426
+ image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
427
+
428
+ return {
429
+ "error": False,
430
+ "segmentation_image": "data:image/png;base64," + image_base64_res,
431
+ "segmentation_bbox": seg_bbox
432
+ }
433
+
434
+ except Exception as e:
435
+ print(f"Error in segmentation: {e}")
436
+ return {"error": str(e)}
437
+
438
+ app = gr.mount_gradio_app(app, demo, "/")
439
+
440
+ if __name__ == "__main__":
441
+ uvicorn.run(app, host="0.0.0.0", port=7860)
442
+ # demo.launch()
edit_space.py ADDED
@@ -0,0 +1,461 @@
1
+ import os
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import sys
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import json
9
+
10
+
11
+ # New imports for the diffuser pipeline
12
+ from src.pipeline_flux_kontext_control import FluxKontextControlPipeline
13
+ from src.transformer_flux import FluxTransformer2DModel
14
+
15
+ import tempfile
16
+ from safetensors.torch import load_file, save_file
17
+
18
+ _original_load_lora_weights = FluxKontextControlPipeline.load_lora_weights
19
+
20
+ def _patched_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
21
+ """自动转换混合格式的 LoRA 并添加 transformer 前缀"""
22
+ weight_name = kwargs.get("weight_name", "pytorch_lora_weights.safetensors")
23
+
24
+ if isinstance(pretrained_model_name_or_path_or_dict, str):
25
+ if os.path.isdir(pretrained_model_name_or_path_or_dict):
26
+ lora_file = os.path.join(pretrained_model_name_or_path_or_dict, weight_name)
27
+ else:
28
+ lora_file = pretrained_model_name_or_path_or_dict
29
+
30
+ if os.path.exists(lora_file):
31
+ state_dict = load_file(lora_file)
32
+
33
+ # Check whether the keys need format conversion or a 'transformer.' prefix
34
+ needs_format_conversion = any('lora_A.weight' in k or 'lora_B.weight' in k for k in state_dict.keys())
35
+ needs_prefix = not any(k.startswith('transformer.') for k in state_dict.keys())
36
+
37
+ if needs_format_conversion or needs_prefix:
38
+ print(f"🔄 Processing LoRA: {lora_file}")
39
+ if needs_format_conversion:
40
+ print(f" - Converting PEFT format to diffusers format")
41
+ if needs_prefix:
42
+ print(f" - Adding 'transformer.' prefix to keys")
43
+
44
+ converted_state = {}
45
+ converted_count = 0
46
+
47
+ for key, value in state_dict.items():
48
+ new_key = key
49
+
50
+ # Step 1: convert PEFT-style keys to the diffusers format
51
+ if 'lora_A.weight' in new_key:
52
+ new_key = new_key.replace('lora_A.weight', 'lora.down.weight')
53
+ converted_count += 1
54
+ elif 'lora_B.weight' in new_key:
55
+ new_key = new_key.replace('lora_B.weight', 'lora.up.weight')
56
+ converted_count += 1
57
+
58
+ # Step 2: add the 'transformer.' prefix (if not already present)
59
+ if not new_key.startswith('transformer.'):
60
+ new_key = f'transformer.{new_key}'
61
+
62
+ converted_state[new_key] = value
63
+
64
+ if needs_format_conversion:
65
+ print(f" ✅ Converted {converted_count} PEFT keys")
66
+ print(f" ✅ Total keys: {len(converted_state)}")
67
+
68
+ with tempfile.TemporaryDirectory() as temp_dir:
69
+ temp_file = os.path.join(temp_dir, weight_name)
70
+ save_file(converted_state, temp_file)
71
+ return _original_load_lora_weights(self, temp_dir, **kwargs)
72
+ else:
73
+ print(f"✅ LoRA already in correct format: {lora_file}")
74
+
75
+ # No conversion needed; fall back to the original loader
76
+ return _original_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs)
77
+
78
+ # Apply the monkey patch
79
+ FluxKontextControlPipeline.load_lora_weights = _patched_load_lora_weights
80
+ print("✅ Monkey patch applied to FluxKontextPipeline.load_lora_weights")
81
+
82
+ current_dir = os.path.dirname(os.path.abspath(__file__))
83
+ sys.path.append(current_dir)
84
+ sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
85
+ sys.path.append(os.path.abspath(os.path.join(current_dir, '..', '..', 'comfy_extras')))
86
+
87
+ from train.src.condition.edge_extraction import InformativeDetector, HEDDetector
88
+ from utils_node import BlendInpaint, JoinImageWithAlpha, GrowMask, InvertMask, ColorDetector
89
+
90
+ TEST_MODE = False
91
+
92
+ class KontextEditModel():
93
+ def __init__(self, base_model_path="/data0/lzc/FLUX.1-Kontext-dev", device="cuda",
94
+ aux_lora_dir="models/v2_ckpt", easycontrol_base_dir="models/v2_ckpt",
95
+ aux_lora_weight_name="puzzle_lora.safetensors",
96
+ aux_lora_weight=1.0):
97
+ # Keep necessary preprocessors
98
+ self.mask_processor = GrowMask()
99
+ self.scribble_processor = HEDDetector.from_pretrained()
100
+ self.lineart_processor = InformativeDetector.from_pretrained()
101
+ self.color_processor = ColorDetector()
102
+ self.blender = BlendInpaint()
103
+
104
+ # Initialize the new pipeline (Kontext version)
105
+ self.device = device
106
+ self.pipe = FluxKontextControlPipeline.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
107
+ transformer = FluxTransformer2DModel.from_pretrained(
108
+ base_model_path,
109
+ subfolder="transformer",
110
+ torch_dtype=torch.bfloat16,
111
+ device=self.device
112
+ )
113
+ self.pipe.transformer = transformer
114
+ self.pipe.to(self.device, dtype=torch.bfloat16)
115
+
116
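+ # Register one control LoRA per spatial condition type; during inference,
+ # control_dict["type"] names the condition that the provided spatial_images
+ # and gammas belong to (see edge_edit / local_edit / object_removal below).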
+ control_lora_config = {
117
+ "local": {
118
+ "path": os.path.join(easycontrol_base_dir, "local_lora.safetensors"),
119
+ "lora_weights": [1.0],
120
+ "cond_size": 512,
121
+ },
122
+ "removal": {
123
+ "path": os.path.join(easycontrol_base_dir, "removal_lora.safetensors"),
124
+ "lora_weights": [1.0],
125
+ "cond_size": 512,
126
+ },
127
+ "edge": {
128
+ "path": os.path.join(easycontrol_base_dir, "edge_lora.safetensors"),
129
+ "lora_weights": [1.0],
130
+ "cond_size": 512,
131
+ },
132
+ "color": {
133
+ "path": os.path.join(easycontrol_base_dir, "color_lora.safetensors"),
134
+ "lora_weights": [1.0],
135
+ "cond_size": 512,
136
+ },
137
+ }
138
+ self.pipe.load_control_loras(control_lora_config)
139
+
140
+ # Aux LoRA for foreground mode
141
+ self.aux_lora_weight_name = aux_lora_weight_name
142
+ self.aux_lora_dir = aux_lora_dir
143
+ self.aux_lora_weight = aux_lora_weight
144
+ self.aux_adapter_name = "aux"
145
+
146
+ from safetensors.torch import load_file as _sft_load
147
+ aux_path = os.path.join(self.aux_lora_dir, self.aux_lora_weight_name)
148
+ if os.path.isfile(aux_path):
149
+ self.pipe.load_lora_weights(aux_path, adapter_name=self.aux_adapter_name)
150
+ print(f"Loaded aux LoRA: {aux_path}")
151
+ # Ensure aux LoRA is disabled by default; it will be enabled only in foreground_edit
152
+ self._disable_aux_lora()
153
+ else:
154
+ print(f"Aux LoRA not found at {aux_path}, foreground mode will run without it.")
155
+
156
+
157
+ # gamma is now applied inside the pipeline based on control_dict
158
+
159
+ def _tensor_to_pil(self, tensor_image):
160
+ # Converts a ComfyUI-style tensor [1, H, W, 3] to a PIL Image
161
+ return Image.fromarray(np.clip(255. * tensor_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
162
+
163
+ def _pil_to_tensor(self, pil_image):
164
+ # Converts a PIL image to a ComfyUI-style tensor [1, H, W, 3]
165
+ return torch.from_numpy(np.array(pil_image).astype(np.float32) / 255.0).unsqueeze(0)
166
+
167
+ def clear_cache(self):
168
+ for name, attn_processor in self.pipe.transformer.attn_processors.items():
169
+ if hasattr(attn_processor, 'bank_kv'):
170
+ attn_processor.bank_kv.clear()
171
+ if hasattr(attn_processor, 'bank_attn'):
172
+ attn_processor.bank_attn = None
173
+
174
+ def _enable_aux_lora(self):
175
+ self.pipe.enable_lora()
176
+ self.pipe.set_adapters([self.aux_adapter_name], adapter_weights=[self.aux_lora_weight])
177
+ print(f"Enabled aux LoRA '{self.aux_adapter_name}' with weight {self.aux_lora_weight}")
178
+
179
+ def _disable_aux_lora(self):
180
+ self.pipe.disable_lora()
181
+ print("Disabled aux LoRA")
182
+
183
+ def _expand_mask(self, mask_tensor: torch.Tensor, expand: int = 0) -> torch.Tensor:
184
+ if expand <= 0:
185
+ return mask_tensor
186
+ expanded = self.mask_processor.expand_mask(mask_tensor, expand=expand, tapered_corners=True)[0]
187
+ return expanded
188
+
189
+ def _tensor_mask_to_pil3(self, mask_tensor: torch.Tensor) -> Image.Image:
190
+ mask_01 = torch.clamp(mask_tensor, 0.0, 1.0)
191
+ if mask_01.ndim == 3 and mask_01.shape[-1] == 3:
192
+ mask_01 = mask_01[..., 0]
193
+ if mask_01.ndim == 3 and mask_01.shape[0] == 1:
194
+ mask_01 = mask_01[0]
195
+ pil = self._tensor_to_pil(mask_01.unsqueeze(-1).repeat(1, 1, 3))
196
+ return pil
197
+
198
+ def _apply_black_mask(self, image_tensor: torch.Tensor, binary_mask: torch.Tensor) -> Image.Image:
199
+ # image_tensor: [1, H, W, 3] in [0,1]
200
+ # binary_mask: [H, W] or [1, H, W], 1=mask area (white)
201
+ if binary_mask.ndim == 3:
202
+ binary_mask = binary_mask[0]
203
+ mask_bool = (binary_mask > 0.5)
204
+ img = image_tensor.clone()
205
+ img[0][mask_bool] = 0.0
206
+ return self._tensor_to_pil(img)
207
+
208
+ def edge_edit(self,
209
+ image, colored_image, positive_prompt,
210
+ base_mask, add_mask, remove_mask,
211
+ fine_edge,
212
+ edge_strength, color_strength,
213
+ seed, steps, cfg):
214
+
215
+ generator = torch.Generator(device=self.device).manual_seed(seed)
216
+
217
+ # Prepare mask and original image
218
+ original_image_tensor = image.clone()
219
+ original_mask = self._expand_mask(base_mask, expand=25)
221
+
222
+ image_pil = self._tensor_to_pil(image)
223
+ # image_pil.save("image_pil.png")
224
+ control_dict = {}
225
+ lineart_output = None
226
+
227
+ # Determine control type: color or edge
228
+ if not torch.equal(image, colored_image):
229
+ print("Apply color control")
230
+ colored_image_pil = self._tensor_to_pil(colored_image)
231
+ # Create color block condition
232
+ color_image_np = np.array(colored_image_pil)
233
+ downsampled = cv2.resize(color_image_np, (32, 32), interpolation=cv2.INTER_AREA)
234
+ upsampled = cv2.resize(downsampled, (256, 256), interpolation=cv2.INTER_NEAREST)
235
+ color_block = Image.fromarray(upsampled)
236
+ # Create grayscale condition
237
+
238
+ control_dict = {
239
+ "type": "color",
240
+ "spatial_images": [color_block],
241
+ "gammas": [color_strength]
242
+ }
243
+ else:
244
+ print("Apply edge control")
245
+ if fine_edge == "enable":
246
+ lineart_image = self.lineart_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), detect_resolution=1024, style="contour", output_type="pil")
247
+ lineart_output = self._pil_to_tensor(lineart_image)
248
+ else:
249
+ scribble_image = self.scribble_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), safe=True, resolution=512, output_type="pil")
250
+ lineart_output = self._pil_to_tensor(scribble_image)
251
+
252
+ if lineart_output is None:
253
+ raise ValueError("Preprocessor failed to generate lineart.")
254
+
255
+ # Apply user sketches to the lineart
256
+ add_mask_resized = F.interpolate(add_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)
257
+ remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)
258
+
259
+ bool_add_mask_resized = (add_mask_resized > 0.5)
260
+ bool_remove_mask_resized = (remove_mask_resized > 0.5)
261
+
262
+ lineart_output[bool_remove_mask_resized] = 0.0
263
+ lineart_output[bool_add_mask_resized] = 1.0
264
+
265
+ control_dict = {
266
+ "type": "edge",
267
+ "spatial_images": [self._tensor_to_pil(lineart_output)],
268
+ "gammas": [edge_strength]
269
+ }
270
+
271
+ # Prepare debug/output images
272
+ debug_image = lineart_output if lineart_output is not None else self.color_processor.execute(colored_image, resolution=1024)[0]
273
+
274
+ # Run inference
275
+ result_pil = self.pipe(
276
+ prompt=positive_prompt,
277
+ image=image_pil,
278
+ height=image_pil.height,
279
+ width=image_pil.width,
280
+ guidance_scale=cfg,
281
+ num_inference_steps=steps,
282
+ generator=generator,
283
+ max_sequence_length=128,
284
+ control_dict=control_dict,
285
+ ).images[0]
286
+
287
+ self.clear_cache()
288
+
289
+ # result_pil.save("result_pil.png")
290
+ result_tensor = self._pil_to_tensor(result_pil)
291
+ # final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
292
+ final_image = result_tensor
293
+ return (final_image, debug_image, original_mask)
294
+
295
+ def object_removal(self,
296
+ image, positive_prompt,
297
+ remove_mask,
298
+ local_strength,
299
+ seed, steps, cfg):
300
+
301
+ generator = torch.Generator(device=self.device).manual_seed(seed)
302
+
303
+ original_image_tensor = image.clone()
304
+ original_mask = self._expand_mask(remove_mask, expand=25)
306
+
307
+ image_pil = self._tensor_to_pil(image)
308
+ # image_pil.save("image_pil.png")
309
+ # Prepare spatial image: original masked to black in the remove area
310
+ spatial_pil = self._apply_black_mask(image, original_mask)
311
+ # spatial_pil.save("spatial_pil.png")
312
+ # Note: mask is not passed to pipeline; we use it only for blending
313
+ control_dict = {
314
+ "type": "removal",
315
+ "spatial_images": [spatial_pil],
316
+ "gammas": [local_strength]
317
+ }
318
+
319
+ result_pil = self.pipe(
320
+ prompt=positive_prompt,
321
+ image=image_pil,
322
+ height=image_pil.height,
323
+ width=image_pil.width,
324
+ guidance_scale=cfg,
325
+ num_inference_steps=steps,
326
+ generator=generator,
327
+ control_dict=control_dict,
328
+ ).images[0]
329
+
330
+ self.clear_cache()
331
+
332
+ result_tensor = self._pil_to_tensor(result_pil)
333
+ final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
334
+ # final_image = result_tensor
335
+ return (final_image, self._pil_to_tensor(spatial_pil), original_mask)
336
+
337
+ def local_edit(self,
338
+ image, positive_prompt, fill_mask, local_strength,
339
+ seed, steps, cfg):
340
+ generator = torch.Generator(device=self.device).manual_seed(seed)
341
+ original_image_tensor = image.clone()
342
+ original_mask = self._expand_mask(fill_mask, expand=25)
343
+ image_pil = self._tensor_to_pil(image)
344
+ # image_pil.save("image_pil.png")
345
+
346
+ spatial_pil = self._apply_black_mask(image, original_mask)
347
+ # spatial_pil.save("spatial_pil.png")
348
+ control_dict = {
349
+ "type": "local",
350
+ "spatial_images": [spatial_pil],
351
+ "gammas": [local_strength]
352
+ }
353
+
354
+ result_pil = self.pipe(
355
+ prompt=positive_prompt,
356
+ image=image_pil,
357
+ height=image_pil.height,
358
+ width=image_pil.width,
359
+ guidance_scale=cfg,
360
+ num_inference_steps=steps,
361
+ generator=generator,
362
+ max_sequence_length=128,
363
+ control_dict=control_dict,
364
+ ).images[0]
365
+
366
+ self.clear_cache()
367
+ result_tensor = self._pil_to_tensor(result_pil)
368
+ final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
369
+ # final_image = result_tensor
370
+ return (final_image, self._pil_to_tensor(spatial_pil), original_mask)
371
+
372
+ def foreground_edit(self,
373
+ merged_image, positive_prompt,
374
+ add_prop_mask, fill_mask, fix_perspective, grow_size,
375
+ seed, steps, cfg):
376
+ generator = torch.Generator(device=self.device).manual_seed(seed)
377
+
378
+ edit_mask = torch.clamp(self._expand_mask(add_prop_mask, expand=grow_size) + fill_mask, 0.0, 1.0)
379
+ final_mask = self._expand_mask(edit_mask, expand=25)
380
+ if fix_perspective == "enable":
381
+ positive_prompt = positive_prompt + " Fix the perspective if necessary."
382
+ # Prepare edited input image: inside edit_mask but outside add_prop_mask set to white
383
+ img = merged_image.clone()
384
+ base_mask = (edit_mask > 0.5)
385
+ add_only = (add_prop_mask <= 0.5) & base_mask # [1, H, W] bool
386
+ add_only_3 = add_only.squeeze(0).unsqueeze(-1).expand(-1, -1, img.shape[-1]) # [H, W, 3]
387
+ img[0] = torch.where(add_only_3, torch.ones_like(img[0]), img[0])
388
+
389
+ image_pil = self._tensor_to_pil(img)
390
+ # image_pil.save("image_pil.png")
391
+
392
+ # Enable aux LoRA only for foreground
393
+ self._enable_aux_lora()
394
+
395
+ result_pil = self.pipe(
396
+ prompt=positive_prompt,
397
+ image=image_pil,
398
+ height=image_pil.height,
399
+ width=image_pil.width,
400
+ guidance_scale=cfg,
401
+ num_inference_steps=steps,
402
+ generator=generator,
403
+ max_sequence_length=128,
404
+ control_dict=None,
405
+ ).images[0]
406
+
407
+ # Disable aux LoRA afterwards
408
+ self._disable_aux_lora()
409
+
410
+ self.clear_cache()
411
+ final_image = self._pil_to_tensor(result_pil)
412
+ # final_image = self.blender.blend_inpaint(final_image, img, final_mask, kernel=10, sigma=10)[0]
413
+ return (final_image, self._pil_to_tensor(image_pil), edit_mask)
414
+
415
+ def kontext_edit(self,
416
+ image, positive_prompt,
417
+ seed, steps, cfg):
418
+ generator = torch.Generator(device=self.device).manual_seed(seed)
419
+ image_pil = self._tensor_to_pil(image)
420
+
421
+ result_pil = self.pipe(
422
+ prompt=positive_prompt,
423
+ image=image_pil,
424
+ height=image_pil.height,
425
+ width=image_pil.width,
426
+ guidance_scale=cfg,
427
+ num_inference_steps=steps,
428
+ generator=generator,
429
+ max_sequence_length=128,
430
+ control_dict=None,
431
+ ).images[0]
432
+
433
+ final_image = self._pil_to_tensor(result_pil)
434
+ mask = torch.zeros((1, final_image.shape[1], final_image.shape[2]), dtype=torch.float32, device=final_image.device)
435
+ return (final_image, image, mask)
436
+
437
+ def process(self, image, colored_image,
438
+ merged_image, positive_prompt,
439
+ total_mask, add_mask, remove_mask, add_prop_mask, fill_mask,
440
+ fine_edge, fix_perspective, edge_strength, color_strength, local_strength, grow_size,
441
+ seed, steps, cfg, flag="precise_edit"):
442
+ if flag == "foreground":
443
+ return self.foreground_edit(merged_image, positive_prompt, add_prop_mask, fill_mask, fix_perspective, grow_size, seed, steps, cfg)
444
+ elif flag == "local":
445
+ return self.local_edit(image, positive_prompt, fill_mask, local_strength, seed, steps, cfg)
446
+ elif flag == "removal":
447
+ return self.object_removal(image, positive_prompt, remove_mask, local_strength, seed, steps, cfg)
448
+ elif flag == "precise_edit":
449
+ return self.edge_edit(
450
+ image, colored_image, positive_prompt,
451
+ total_mask, add_mask, remove_mask,
452
+ fine_edge,
453
+ edge_strength, color_strength,
455
+ seed, steps, cfg
456
+ )
457
+ elif flag == "kontext":
458
+ return self.kontext_edit(image, positive_prompt, seed, steps, cfg)
459
+ else:
460
+ raise ValueError("Invalid Editing Type: {}".format(flag))
461
+
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ accelerate
+ datasets
+ diffusers
+ easydict
+ einops
+ fastapi
+ gradio==5.4.0
+ gradio_client
+ huggingface_hub
+ numpy
+ opencv-python
+ peft
+ pillow
+ protobuf
+ requests
+ safetensors
+ scikit-image
+ scipy
+ git+https://github.com/facebookresearch/segment-anything.git
+ sentencepiece
+ spaces
+ torch
+ torchaudio
+ torchvision
+ tqdm
+ transformers
+ uvicorn
+ ./gradio_magicquillv2-0.0.1-py3-none-any.whl
src/__init__.py ADDED
File without changes
src/layers_cache.py ADDED
@@ -0,0 +1,406 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union, Any, Dict
4
+ from einops import rearrange
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from diffusers.models.attention_processor import Attention
10
+
11
+ TXTLEN = 128
12
+ KONTEXT = False
13
+
14
+ class LoRALinearLayer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ out_features: int,
19
+ rank: int = 4,
20
+ network_alpha: Optional[float] = None,
21
+ device: Optional[Union[torch.device, str]] = None,
22
+ dtype: Optional[torch.dtype] = None,
23
+ cond_widths: Optional[List[int]] = None,
24
+ cond_heights: Optional[List[int]] = None,
25
+ lora_index: int = 0,
26
+ n_loras: int = 1,
27
+ ):
28
+ super().__init__()
29
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
30
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
31
+ self.network_alpha = network_alpha
32
+ self.rank = rank
33
+ self.out_features = out_features
34
+ self.in_features = in_features
35
+
36
+ nn.init.normal_(self.down.weight, std=1 / rank)
37
+ nn.init.zeros_(self.up.weight)
38
+
39
+ self.cond_heights = cond_heights if cond_heights is not None else [512]
40
+ self.cond_widths = cond_widths if cond_widths is not None else [512]
41
+ self.lora_index = lora_index
42
+ self.n_loras = n_loras
43
+
44
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
45
+ orig_dtype = hidden_states.dtype
46
+ dtype = self.down.weight.dtype
47
+
48
+ batch_size = hidden_states.shape[0]
49
+
50
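+ # Each (w, h) condition image contributes (w/16) * (h/16) latent tokens after the
+ # 8x VAE downsample and 2x2 patchify; the mask below restricts this LoRA's output
+ # to the token slice belonging to its own lora_index, leaving other tokens untouched.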
+ cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
51
+ total_cond_size = sum(cond_sizes)
52
+ block_size = hidden_states.shape[1] - total_cond_size
53
+
54
+ offset = sum(cond_sizes[:self.lora_index])
55
+ current_cond_size = cond_sizes[self.lora_index]
56
+
57
+ shape = (batch_size, hidden_states.shape[1], 3072)
58
+ mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
59
+
60
+ mask[:, :block_size + offset, :] = 0
61
+ mask[:, block_size + offset + current_cond_size:, :] = 0
62
+
63
+ hidden_states = mask * hidden_states
64
+
65
+ down_hidden_states = self.down(hidden_states.to(dtype))
66
+ up_hidden_states = self.up(down_hidden_states)
67
+
68
+ if self.network_alpha is not None:
69
+ up_hidden_states *= self.network_alpha / self.rank
70
+
71
+ return up_hidden_states.to(orig_dtype)
72
+
73
+
74
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
75
+ def __init__(self, dim: int, ranks: List[int], lora_weights: List[float], network_alphas: List[float], device=None, dtype=None, cond_widths: Optional[List[int]] = None, cond_heights: Optional[List[int]] = None, n_loras=1):
76
+ super().__init__()
77
+ self.n_loras = n_loras
78
+ self.cond_widths = cond_widths if cond_widths is not None else [512]
79
+ self.cond_heights = cond_heights if cond_heights is not None else [512]
80
+
81
+ self.q_loras = nn.ModuleList([
82
+ LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
83
+ for i in range(n_loras)
84
+ ])
85
+ self.k_loras = nn.ModuleList([
86
+ LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
87
+ for i in range(n_loras)
88
+ ])
89
+ self.v_loras = nn.ModuleList([
90
+ LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras)
91
+ for i in range(n_loras)
92
+ ])
93
+ self.lora_weights = lora_weights
94
+ self.bank_attn = None
95
+ self.bank_kv: List[torch.Tensor] = []
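+ # bank_kv / bank_attn cache the condition tokens' key/value and attention outputs
+ # on the first denoising step; later steps reuse them instead of recomputing, and
+ # KontextEditModel.clear_cache() resets them between requests.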
96
+
97
+
98
+ def __call__(self,
99
+ attn: Attention,
100
+ hidden_states: torch.Tensor,
101
+ encoder_hidden_states: Optional[torch.Tensor] = None,
102
+ attention_mask: Optional[torch.Tensor] = None,
103
+ image_rotary_emb: Optional[torch.Tensor] = None,
104
+ use_cond = False
105
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
106
+
107
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
108
+ scaled_seq_len = hidden_states.shape[1]
109
+
110
+ cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
111
+ total_cond_size = sum(cond_sizes)
112
+ block_size = scaled_seq_len - total_cond_size
113
+
114
+ scaled_cond_sizes = cond_sizes
115
+ scaled_block_size = block_size
116
+
117
+ global TXTLEN
118
+ global KONTEXT
119
+ if KONTEXT:
120
+ img_start, img_end = TXTLEN, (TXTLEN + block_size) // 2
121
+ else:
122
+ img_start, img_end = TXTLEN, block_size
123
+ cond_start, cond_end = block_size, scaled_seq_len
124
+
125
+ cache = len(self.bank_kv) == 0
126
+
127
+ if cache:
128
+ query = attn.to_q(hidden_states)
129
+ key = attn.to_k(hidden_states)
130
+ value = attn.to_v(hidden_states)
131
+ for i in range(self.n_loras):
132
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
133
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
134
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
135
+
136
+ inner_dim = key.shape[-1]
137
+ head_dim = inner_dim // attn.heads
138
+
139
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
140
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
141
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
142
+
143
+ self.bank_kv.extend([key[:, :, scaled_block_size:, :], value[:, :, scaled_block_size:, :]])
144
+
145
+ if attn.norm_q is not None: query = attn.norm_q(query)
146
+ if attn.norm_k is not None: key = attn.norm_k(key)
147
+
148
+ if image_rotary_emb is not None:
149
+ from diffusers.models.embeddings import apply_rotary_emb
150
+ query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)
151
+
152
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
153
+ mask[ :scaled_block_size, :] = 0
154
+
155
+ current_offset = 0
156
+ for i in range(self.n_loras):
157
+ start, end = scaled_block_size + current_offset, scaled_block_size + current_offset + scaled_cond_sizes[i]
158
+ mask[start:end, start:end] = 0
159
+ current_offset += scaled_cond_sizes[i]
160
+
161
+ mask *= -1e20
162
+
163
+ c_factor = getattr(self, "c_factor", None)
164
+ if c_factor is not None:
165
+ # print(f"Using c_factor: {c_factor}")
166
+ current_offset = 0
167
+ for i in range(self.n_loras):
168
+ bias = torch.log(c_factor[i])
169
+ cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
170
+ mask[img_start:img_end, cond_i_start:cond_i_end] = bias
171
+ current_offset += scaled_cond_sizes[i]
172
+
173
+ # c_factor_kontext = getattr(self, "c_factor_kontext", None)
174
+ # if c_factor_kontext is not None:
175
+ # bias = torch.log(c_factor_kontext)
176
+ # kontext_start, kontext_end = img_end, block_size
177
+ # mask[img_start:img_end, kontext_start:kontext_end] = bias
178
+ # mask[kontext_start:kontext_end, img_start:img_end] = bias
179
+
180
+ # mask[kontext_start:kontext_end, kontext_end:] = -1e20
181
+
182
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask.to(query.dtype))
183
+ self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
184
+
185
+ else:
186
+ query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
187
+
188
+ inner_dim = query.shape[-1]
189
+ head_dim = inner_dim // attn.heads
190
+
191
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
192
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+
195
+ key = torch.cat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2)
196
+ value = torch.cat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2)
197
+
198
+ if attn.norm_q is not None: query = attn.norm_q(query)
199
+ if attn.norm_k is not None: key = attn.norm_k(key)
200
+
201
+ if image_rotary_emb is not None:
202
+ from diffusers.models.embeddings import apply_rotary_emb
203
+ query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)
204
+
205
+ query = query[:, :, :scaled_block_size, :]
206
+
207
+ attn_mask = None
208
+ c_factor = getattr(self, "c_factor", None)
209
+ if c_factor is not None:
210
+ # print(f"Using c_factor: {c_factor}")
211
+ attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
212
+ current_offset = 0
213
+ for i in range(self.n_loras):
214
+ bias = torch.log(c_factor[i])
215
+ cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
216
+ attn_mask[img_start:img_end, cond_i_start:cond_i_end] = bias
217
+ current_offset += scaled_cond_sizes[i]
218
+
219
+ # c_factor_kontext = getattr(self, "c_factor_kontext", None)
220
+ # if c_factor_kontext is not None:
221
+ # if attn_mask is None:
222
+ # attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
223
+ # bias = torch.log(c_factor_kontext)
224
+ # kontext_start, kontext_end = img_end, block_size
225
+ # attn_mask[img_start:img_end, kontext_start:kontext_end] = bias
226
+ # attn_mask[kontext_start:kontext_end, img_start:img_end] = bias
227
+
228
+ # attn_mask[kontext_start:kontext_end, kontext_end:] = -1e20
229
+
230
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attn_mask)
231
+ if self.bank_attn is not None: hidden_states = torch.cat([hidden_states, self.bank_attn], dim=-2)
232
+
233
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
234
+ hidden_states = hidden_states.to(query.dtype)
235
+
236
+ cond_hidden_states = hidden_states[:, block_size:,:]
237
+ hidden_states = hidden_states[:, : block_size,:]
238
+
239
+ return (hidden_states, cond_hidden_states) if use_cond else hidden_states
240
+
241
+
242
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
243
+ def __init__(self, dim: int, ranks: List[int], lora_weights: List[float], network_alphas: List[float], device=None, dtype=None, cond_widths: Optional[List[int]] = None, cond_heights: Optional[List[int]] = None, n_loras=1):
244
+ super().__init__()
245
+
246
+ self.n_loras = n_loras
247
+ self.cond_widths = cond_widths if cond_widths is not None else [512]
248
+ self.cond_heights = cond_heights if cond_heights is not None else [512]
249
+ self.q_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
250
+ self.k_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
251
+ self.v_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
252
+ self.proj_loras = nn.ModuleList([LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_widths=self.cond_widths, cond_heights=self.cond_heights, lora_index=i, n_loras=n_loras) for i in range(n_loras)])
253
+ self.lora_weights = lora_weights
254
+ self.bank_attn = None
255
+ self.bank_kv: List[torch.Tensor] = []
256
+
257
+
258
+ def __call__(self,
259
+ attn: Attention,
260
+ hidden_states: torch.Tensor,
261
+ encoder_hidden_states: Optional[torch.Tensor] = None,
262
+ attention_mask: Optional[torch.Tensor] = None,
263
+ image_rotary_emb: Optional[torch.Tensor] = None,
264
+ use_cond=False,
265
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
266
+
267
+ global TXTLEN
268
+ global KONTEXT
269
+ TXTLEN = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 128
270
+
271
+ batch_size, _, _ = hidden_states.shape
272
+
273
+ cond_sizes = [(w // 8 * h // 8 * 16 // 64) for w, h in zip(self.cond_widths, self.cond_heights)]
274
+ block_size = hidden_states.shape[1] - sum(cond_sizes)
275
+
276
+ scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1]
277
+ scaled_cond_sizes = cond_sizes
278
+ scaled_block_size = scaled_seq_len - sum(scaled_cond_sizes)
279
+
280
+ if KONTEXT:
281
+ img_start, img_end = TXTLEN, (TXTLEN + block_size) // 2
282
+ else:
283
+ img_start, img_end = TXTLEN, block_size
284
+ cond_start, cond_end = scaled_block_size, scaled_seq_len
285
+
286
+ inner_dim, head_dim = 3072, 3072 // attn.heads
287
+
288
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
289
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
290
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
291
+
292
+ if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
293
+ if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
294
+
295
+ cache = len(self.bank_kv) == 0
296
+
297
+ if cache:
298
+ query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
299
+ for i in range(self.n_loras):
300
+ query, key, value = query + self.lora_weights[i] * self.q_loras[i](hidden_states), key + self.lora_weights[i] * self.k_loras[i](hidden_states), value + self.lora_weights[i] * self.v_loras[i](hidden_states)
301
+
302
+ query, key, value = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
303
+
304
+ self.bank_kv.extend([key[:, :, block_size:, :], value[:, :, block_size:, :]])
305
+
306
+ if attn.norm_q is not None: query = attn.norm_q(query)
307
+ if attn.norm_k is not None: key = attn.norm_k(key)
308
+
309
+ query, key, value = torch.cat([encoder_hidden_states_query_proj, query], dim=2), torch.cat([encoder_hidden_states_key_proj, key], dim=2), torch.cat([encoder_hidden_states_value_proj, value], dim=2)
310
+
311
+ if image_rotary_emb is not None:
312
+ from diffusers.models.embeddings import apply_rotary_emb
313
+ query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)
314
+
315
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
316
+ mask[:scaled_block_size, :] = 0
317
+
318
+ current_offset = 0
319
+ for i in range(self.n_loras):
320
+ start, end = scaled_block_size + current_offset, scaled_block_size + current_offset + scaled_cond_sizes[i]
321
+ mask[start:end, start:end] = 0
322
+ current_offset += scaled_cond_sizes[i]
323
+
324
+ mask *= -1e20
325
+
326
+ c_factor = getattr(self, "c_factor", None)
327
+ if c_factor is not None:
328
+ # print(f"Using c_factor: {c_factor}")
329
+ current_offset = 0
330
+ for i in range(self.n_loras):
331
+ bias = torch.log(c_factor[i])
332
+ cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
333
+ mask[img_start:img_end, cond_i_start:cond_i_end] = bias
334
+ current_offset += scaled_cond_sizes[i]
335
+
336
+ # c_factor_kontext = getattr(self, "c_factor_kontext", None)
337
+ # if c_factor_kontext is not None:
338
+ # bias = torch.log(c_factor_kontext)
339
+ # kontext_start, kontext_end = img_end, block_size
340
+ # mask[img_start:img_end, kontext_start:kontext_end] = bias
341
+ # mask[kontext_start:kontext_end, img_start:img_end] = bias
342
+
343
+ # mask[kontext_start:kontext_end, kontext_end:] = -1e20
344
+
345
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask.to(query.dtype))
346
+ self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
347
+
348
+ else:
349
+ query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
350
+
351
+ query, key, value = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2), value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
352
+
353
+ key, value = torch.cat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2), torch.cat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)
354
+
355
+ if attn.norm_q is not None: query = attn.norm_q(query)
356
+ if attn.norm_k is not None: key = attn.norm_k(key)
357
+
358
+ query, key, value = torch.cat([encoder_hidden_states_query_proj, query], dim=2), torch.cat([encoder_hidden_states_key_proj, key], dim=2), torch.cat([encoder_hidden_states_value_proj, value], dim=2)
359
+
360
+ if image_rotary_emb is not None:
361
+ from diffusers.models.embeddings import apply_rotary_emb
362
+ query, key = apply_rotary_emb(query, image_rotary_emb), apply_rotary_emb(key, image_rotary_emb)
363
+
364
+ query = query[:, :, :scaled_block_size, :]
365
+
366
+ attn_mask = None
367
+ c_factor = getattr(self, "c_factor", None)
368
+ if c_factor is not None:
369
+ # print(f"Using c_factor: {c_factor}")
370
+ attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
371
+ current_offset = 0
372
+ for i in range(self.n_loras):
373
+ bias = torch.log(c_factor[i])
374
+ cond_i_start, cond_i_end = cond_start + current_offset, cond_start + current_offset + scaled_cond_sizes[i]
375
+ attn_mask[img_start:img_end, cond_i_start:cond_i_end] = bias
376
+ current_offset += scaled_cond_sizes[i]
377
+
378
+ # c_factor_kontext = getattr(self, "c_factor_kontext", None)
379
+ # if c_factor_kontext is not None:
380
+ # if attn_mask is None:
381
+ # attn_mask = torch.zeros((query.shape[2], key.shape[2]), device=query.device, dtype=query.dtype)
382
+ # bias = torch.log(c_factor_kontext)
383
+ # kontext_start, kontext_end = img_end, block_size
384
+ # attn_mask[img_start:img_end, kontext_start:kontext_end] = bias
385
+ # attn_mask[kontext_start:kontext_end, img_start:img_end] = bias
386
+
387
+ # attn_mask[kontext_start:kontext_end, kontext_end:] = -1e20
388
+
389
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attn_mask)
390
+ if self.bank_attn is not None: hidden_states = torch.cat([hidden_states, self.bank_attn], dim=-2)
391
+
392
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
393
+ hidden_states = hidden_states.to(query.dtype)
394
+
395
+ encoder_hidden_states, hidden_states = hidden_states[:, :encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1]:]
396
+
397
+ hidden_states = attn.to_out[0](hidden_states)
398
+ for i in range(self.n_loras):
399
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
400
+ hidden_states = attn.to_out[1](hidden_states)
401
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
402
+
403
+ cond_hidden_states = hidden_states[:, block_size:,:]
404
+ hidden_states = hidden_states[:, :block_size,:]
405
+
406
+ return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
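The additive attention mask built above is how the control strength enters the model: entries set to -1e20 are effectively masked out, while adding log(c_factor) to the image-to-condition logits scales those attention weights by c_factor after the softmax. A minimal, self-contained sketch of that equivalence (toy tensors, not the processor's real shapes):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[1.0, 0.5, 0.2, 0.1]])   # one query over 4 keys; last 2 are "condition" keys
    c = 2.0                                          # hypothetical control strength (gamma)

    bias = torch.zeros_like(logits)
    bias[:, 2:] = torch.log(torch.tensor(c))         # same trick as mask[img, cond] = log(c_factor)

    plain = F.softmax(logits, dim=-1)
    biased = F.softmax(logits + bias, dim=-1)

    # The condition-vs-rest attention mass grows by exactly a factor of c.
    ratio_plain = plain[:, 2:].sum() / plain[:, :2].sum()
    ratio_biased = biased[:, 2:].sum() / biased[:, :2].sum()
    assert torch.allclose(ratio_biased, c * ratio_plain)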
src/lora_helper.py ADDED
@@ -0,0 +1,194 @@
1
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
2
+ from safetensors.torch import load_file
3
+ import re
4
+ import torch
5
+ from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
6
+
7
+ device = "cuda"
8
+
9
+ def load_safetensors(path):
10
+ """Safely loads tensors from a file and maps them to the CPU."""
11
+ return load_file(path, device="cpu")
12
+
13
+ def get_lora_count_from_checkpoint(checkpoint):
14
+ """
15
+ Infers the number of LoRA modules stored in a checkpoint by inspecting its keys.
16
+ Also prints a sample of keys for debugging.
17
+ """
18
+ lora_indices = set()
19
+ # Regex to find '..._loras.X.' where X is a number.
20
+ indexed_pattern = re.compile(r'._loras\.(\d+)\.')
21
+ found_keys = []
22
+
23
+ for key in checkpoint.keys():
24
+ match = indexed_pattern.search(key)
25
+ if match:
26
+ lora_indices.add(int(match.group(1)))
27
+ if len(found_keys) < 5 and key not in found_keys:
28
+ found_keys.append(key)
29
+
30
+ if lora_indices:
31
+ lora_count = max(lora_indices) + 1
32
+ print("INFO: Auto-detected indexed LoRA keys in checkpoint.")
33
+ print(f" Found {lora_count} LoRA module(s).")
34
+ print(" Sample keys:", found_keys)
35
+ return lora_count
36
+
37
+ # Fallback for legacy, non-indexed checkpoints.
38
+ legacy_found = False
39
+ legacy_key_sample = ""
40
+ for key in checkpoint.keys():
41
+ if '.q_lora.' in key:
42
+ legacy_found = True
43
+ legacy_key_sample = key
44
+ break
45
+
46
+ if legacy_found:
47
+ print("INFO: Auto-detected legacy (non-indexed) LoRA keys in checkpoint.")
48
+ print(" Assuming 1 LoRA module.")
49
+ print(" Sample key:", legacy_key_sample)
50
+ return 1
51
+
52
+ print("WARNING: No LoRA keys found in the checkpoint.")
53
+ return 0
54
+
55
+ def get_lora_ranks(checkpoint, num_loras):
56
+ """
57
+ Determines the rank for each LoRA module from the checkpoint.
58
+ It supports both indexed (e.g., 'loras.0') and legacy non-indexed formats.
59
+ """
60
+ ranks = {}
61
+
62
+ # First, try to find ranks for all indexed LoRA modules.
63
+ for i in range(num_loras):
64
+ # Find a key that uniquely identifies the i-th LoRA's down projection.
65
+ rank_pattern = re.compile(rf'._loras\.({i})\.down\.weight')
66
+ for k, v in checkpoint.items():
67
+ if rank_pattern.search(k):
68
+ ranks[i] = v.shape[0]
69
+ break
70
+
71
+ # If not all ranks were found, there might be legacy keys or a mismatch.
72
+ if len(ranks) != num_loras:
73
+ # Fallback for single, non-indexed LoRA checkpoints.
74
+ if num_loras == 1:
75
+ for k, v in checkpoint.items():
76
+ if ".q_lora.down.weight" in k:
77
+ return [v.shape[0]]
78
+
79
+ # If still unresolved, use the rank of the very first LoRA found as a default for all.
80
+ first_found_rank = next((v.shape[0] for k, v in checkpoint.items() if k.endswith(".down.weight")), None)
81
+
82
+ if first_found_rank is None:
83
+ raise ValueError("Could not determine any LoRA rank from the provided checkpoint.")
84
+
85
+ # Return a list where missing ranks are filled with the first one found.
86
+ return [ranks.get(i, first_found_rank) for i in range(num_loras)]
87
+
88
+ # Return the list of ranks sorted by LoRA index.
89
+ return [ranks[i] for i in range(num_loras)]
90
+
91
+
92
+ def load_checkpoint(local_path):
93
+ if local_path is not None:
94
+ if '.safetensors' in local_path:
95
+ print(f"Loading .safetensors checkpoint from {local_path}")
96
+ checkpoint = load_safetensors(local_path)
97
+ else:
98
+ print(f"Loading checkpoint from {local_path}")
99
+ checkpoint = torch.load(local_path, map_location='cpu')
100
+ return checkpoint
101
+
102
+
103
+ def prepare_lora_processors(checkpoint, lora_weights, transformer, cond_size, number=None):
104
+ # Ensure processors match the transformer's device and dtype
105
+ try:
106
+ first_param = next(transformer.parameters())
107
+ target_device = first_param.device
108
+ target_dtype = first_param.dtype
109
+ except StopIteration:
110
+ target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
+ target_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
112
+
113
+ if number is None:
114
+ number = get_lora_count_from_checkpoint(checkpoint)
115
+ if number == 0:
116
+ return {}
117
+
118
+ if lora_weights and len(lora_weights) != number:
119
+ print(f"WARNING: Provided `lora_weights` length ({len(lora_weights)}) differs from detected LoRA count ({number}).")
120
+ final_weights = (lora_weights + [1.0] * number)[:number]
121
+ print(f" Adjusting weights to: {final_weights}")
122
+ lora_weights = final_weights
123
+ elif not lora_weights:
124
+ print(f"INFO: No `lora_weights` provided. Defaulting to weights of 1.0 for all {number} LoRAs.")
125
+ lora_weights = [1.0] * number
126
+
127
+ ranks = get_lora_ranks(checkpoint, number)
128
+ print("INFO: Determined ranks for LoRA modules:", ranks)
129
+
130
+ cond_widths = cond_size if isinstance(cond_size, list) else [cond_size] * number
131
+ cond_heights = cond_size if isinstance(cond_size, list) else [cond_size] * number
132
+
133
+ lora_attn_procs = {}
134
+ double_blocks_idx = list(range(19))
135
+ single_blocks_idx = list(range(38))
136
+
137
+ # Get all attention processor names from the transformer to iterate over
138
+ for name in transformer.attn_processors.keys():
139
+ match = re.search(r'\.(\d+)\.', name)
140
+ if not match:
141
+ continue
142
+ layer_index = int(match.group(1))
143
+
144
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
145
+ lora_state_dicts = {
146
+ key: value for key, value in checkpoint.items()
147
+ if f"transformer_blocks.{layer_index}." in key
148
+ }
149
+
150
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
151
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights,
152
+ device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
153
+ )
154
+
155
+ for n in range(number):
156
+ lora_prefix_q = f"{name}.q_loras.{n}"
157
+ lora_prefix_k = f"{name}.k_loras.{n}"
158
+ lora_prefix_v = f"{name}.v_loras.{n}"
159
+ lora_prefix_proj = f"{name}.proj_loras.{n}"
160
+
161
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
162
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
163
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
164
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
165
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
166
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
167
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.down.weight')
168
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.up.weight')
169
+ lora_attn_procs[name].to(device=target_device, dtype=target_dtype)
170
+
171
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
172
+ lora_state_dicts = {
173
+ key: value for key, value in checkpoint.items()
174
+ if f"single_transformer_blocks.{layer_index}." in key
175
+ }
176
+
177
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
178
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights,
179
+ device=target_device, dtype=target_dtype, cond_widths=cond_widths, cond_heights=cond_heights, n_loras=number
180
+ )
181
+
182
+ for n in range(number):
183
+ lora_prefix_q = f"{name}.q_loras.{n}"
184
+ lora_prefix_k = f"{name}.k_loras.{n}"
185
+ lora_prefix_v = f"{name}.v_loras.{n}"
186
+
187
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
188
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
189
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
190
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
191
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
192
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
193
+ lora_attn_procs[name].to(device=target_device, dtype=target_dtype)
194
+ return lora_attn_procs
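Putting the helpers together, a typical (hedged) flow for installing control LoRAs on a Flux transformer looks roughly like this; the checkpoint path and the `pipe` object are placeholders, not fixed names from this repo:

    from src.lora_helper import load_checkpoint, prepare_lora_processors

    ckpt = load_checkpoint("checkpoints/edge_control.safetensors")   # placeholder path
    procs = prepare_lora_processors(
        checkpoint=ckpt,
        lora_weights=[1.0],             # one entry per LoRA; padded/truncated to the detected count
        transformer=pipe.transformer,   # `pipe` is an already-loaded FluxKontextControlPipeline
        cond_size=512,
    )
    pipe.transformer.set_attn_processor(procs)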
src/pipeline_flux_kontext_control.py ADDED
@@ -0,0 +1,1230 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import (
7
+ CLIPImageProcessor,
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ CLIPVisionModelWithProjection,
11
+ T5EncoderModel,
12
+ T5TokenizerFast,
13
+ )
14
+
15
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
16
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
17
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
18
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
19
+ from diffusers.utils import (
20
+ USE_PEFT_BACKEND,
21
+ is_torch_xla_available,
22
+ logging,
23
+ replace_example_docstring,
24
+ scale_lora_layers,
25
+ unscale_lora_layers,
26
+ )
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
30
+ from torchvision.transforms.functional import pad
31
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
32
+ from .lora_helper import prepare_lora_processors, load_checkpoint
33
+ from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
34
+ import re
35
+
36
+
37
+ if is_torch_xla_available():
38
+ import torch_xla.core.xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else:
42
+ XLA_AVAILABLE = False
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+ PREFERRED_KONTEXT_RESOLUTIONS = [
48
+ (672, 1568),
49
+ (688, 1504),
50
+ (720, 1456),
51
+ (752, 1392),
52
+ (800, 1328),
53
+ (832, 1248),
54
+ (880, 1184),
55
+ (944, 1104),
56
+ (1024, 1024),
57
+ (1104, 944),
58
+ (1184, 880),
59
+ (1248, 832),
60
+ (1328, 800),
61
+ (1392, 752),
62
+ (1456, 720),
63
+ (1504, 688),
64
+ (1568, 672),
65
+ ]
66
+
67
+
68
+ def calculate_shift(
69
+ image_seq_len,
70
+ base_seq_len: int = 256,
71
+ max_seq_len: int = 4096,
72
+ base_shift: float = 0.5,
73
+ max_shift: float = 1.15,
74
+ ):
75
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
76
+ b = base_shift - m * base_seq_len
77
+ mu = image_seq_len * m + b
78
+ return mu
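`calculate_shift` is just linear interpolation between (base_seq_len, base_shift) and (max_seq_len, max_shift), so the endpoints recover the base values; a quick sanity check:

    assert abs(calculate_shift(256) - 0.5) < 1e-6     # base_seq_len -> base_shift
    assert abs(calculate_shift(4096) - 1.15) < 1e-6   # max_seq_len  -> max_shift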
79
+
80
+
81
+ def prepare_latent_image_ids_(height, width, device, dtype):
82
+ latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
83
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None] # y
84
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :] # x
85
+ return latent_image_ids
86
+
87
+
88
+ def prepare_latent_subject_ids(height, width, device, dtype):
89
+ latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
90
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
91
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
92
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
93
+ latent_image_ids = latent_image_ids.reshape(
94
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
95
+ )
96
+ return latent_image_ids.to(device=device, dtype=dtype)
97
+
98
+
99
+ def resize_position_encoding(
100
+ batch_size, original_height, original_width, target_height, target_width, device, dtype
101
+ ):
102
+ latent_image_ids = prepare_latent_image_ids_(original_height // 2, original_width // 2, device, dtype)
103
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
104
+ latent_image_ids = latent_image_ids.reshape(
105
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
106
+ )
107
+
108
+ scale_h = original_height / target_height
109
+ scale_w = original_width / target_width
110
+ latent_image_ids_resized = torch.zeros(target_height // 2, target_width // 2, 3, device=device, dtype=dtype)
111
+ latent_image_ids_resized[..., 1] = (
112
+ latent_image_ids_resized[..., 1] + torch.arange(target_height // 2, device=device)[:, None] * scale_h
113
+ )
114
+ latent_image_ids_resized[..., 2] = (
115
+ latent_image_ids_resized[..., 2] + torch.arange(target_width // 2, device=device)[None, :] * scale_w
116
+ )
117
+
118
+ cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = (
119
+ latent_image_ids_resized.shape
120
+ )
121
+ cond_latent_image_ids = latent_image_ids_resized.reshape(
122
+ cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
123
+ )
124
+ return latent_image_ids, cond_latent_image_ids
125
+
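`resize_position_encoding` keeps the condition tokens on the same coordinate grid as the full-resolution latents by stretching their (y, x) indices with scale_h/scale_w. An illustrative call with made-up latent sizes (a 128x128 noise grid paired with a 64x64 condition grid):

    import torch
    ids, cond_ids = resize_position_encoding(1, 128, 128, 64, 64, device="cpu", dtype=torch.float32)
    # ids:      (64*64, 3)  with integer coordinates 0..63
    # cond_ids: (32*32, 3)  with coordinates 0, 2, 4, ..., 62 -- same spatial span, half the tokens per axis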
126
+
127
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
128
+ def retrieve_timesteps(
129
+ scheduler,
130
+ num_inference_steps: Optional[int] = None,
131
+ device: Optional[Union[str, torch.device]] = None,
132
+ timesteps: Optional[List[int]] = None,
133
+ sigmas: Optional[List[float]] = None,
134
+ **kwargs,
135
+ ):
136
+ r"""
137
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
138
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
139
+
140
+ Args:
141
+ scheduler (`SchedulerMixin`):
142
+ The scheduler to get timesteps from.
143
+ num_inference_steps (`int`):
144
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
145
+ must be `None`.
146
+ device (`str` or `torch.device`, *optional*):
147
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
148
+ timesteps (`List[int]`, *optional*):
149
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
150
+ `num_inference_steps` and `sigmas` must be `None`.
151
+ sigmas (`List[float]`, *optional*):
152
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
153
+ `num_inference_steps` and `timesteps` must be `None`.
154
+
155
+ Returns:
156
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
157
+ second element is the number of inference steps.
158
+ """
159
+ if timesteps is not None and sigmas is not None:
160
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
161
+ if timesteps is not None:
162
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
163
+ if not accepts_timesteps:
164
+ raise ValueError(
165
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
166
+ f" timestep schedules. Please check whether you are using the correct scheduler."
167
+ )
168
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
169
+ timesteps = scheduler.timesteps
170
+ num_inference_steps = len(timesteps)
171
+ elif sigmas is not None:
172
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
173
+ if not accept_sigmas:
174
+ raise ValueError(
175
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
176
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
177
+ )
178
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
179
+ timesteps = scheduler.timesteps
180
+ num_inference_steps = len(timesteps)
181
+ else:
182
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
183
+ timesteps = scheduler.timesteps
184
+ return timesteps, num_inference_steps
185
+
186
+
187
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
188
+ def retrieve_latents(
189
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
190
+ ):
191
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
192
+ return encoder_output.latent_dist.sample(generator)
193
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
194
+ return encoder_output.latent_dist.mode()
195
+ elif hasattr(encoder_output, "latents"):
196
+ return encoder_output.latents
197
+ else:
198
+ raise AttributeError("Could not access latents of provided encoder_output")
199
+
200
+
201
+ class FluxKontextControlPipeline(
202
+ DiffusionPipeline,
203
+ FluxLoraLoaderMixin,
204
+ FromSingleFileMixin,
205
+ TextualInversionLoaderMixin,
206
+ ):
207
+ r"""
208
+ The Flux Kontext pipeline for image-to-image and text-to-image generation with a control module.
209
+
210
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
211
+
212
+ Args:
213
+ transformer ([`FluxTransformer2DModel`]):
214
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
215
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
216
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
217
+ vae ([`AutoencoderKL`]):
218
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
219
+ text_encoder ([`CLIPTextModel`]):
220
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
221
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
222
+ text_encoder_2 ([`T5EncoderModel`]):
223
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
224
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
225
+ tokenizer (`CLIPTokenizer`):
226
+ Tokenizer of class
227
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
228
+ tokenizer_2 (`T5TokenizerFast`):
229
+ Second Tokenizer of class
230
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
231
+ """
232
+
233
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
234
+ _optional_components = []
235
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
236
+
237
+ def __init__(
238
+ self,
239
+ scheduler: FlowMatchEulerDiscreteScheduler,
240
+ vae: AutoencoderKL,
241
+ text_encoder: CLIPTextModel,
242
+ tokenizer: CLIPTokenizer,
243
+ text_encoder_2: T5EncoderModel,
244
+ tokenizer_2: T5TokenizerFast,
245
+ transformer: FluxTransformer2DModel,
246
+ image_encoder: CLIPVisionModelWithProjection = None,
247
+ feature_extractor: CLIPImageProcessor = None,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.register_modules(
252
+ vae=vae,
253
+ text_encoder=text_encoder,
254
+ text_encoder_2=text_encoder_2,
255
+ tokenizer=tokenizer,
256
+ tokenizer_2=tokenizer_2,
257
+ transformer=transformer,
258
+ scheduler=scheduler,
259
+ image_encoder=None,
260
+ feature_extractor=None,
261
+ )
262
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
263
+ # Flux latents are packed into 2x2 patches, so use VAE factor multiplied by patch size for image processing
264
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
265
+ self.tokenizer_max_length = (
266
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
267
+ )
268
+ self.default_sample_size = 128
269
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
270
+ self.control_lora_processors: Dict[str, Dict[str, Any]] = {}
271
+ self.control_lora_cond_sizes: Dict[str, Any] = {}
272
+ self.control_lora_weights: Dict[str, Any] = {}
273
+ self.current_control_type: Optional[Union[str, List[str]]] = None
274
+
275
+ def load_control_loras(self, lora_config: Dict[str, Dict[str, Any]]):
276
+ """
277
+ Loads and prepares LoRA attention processors for different control types.
278
+ Args:
279
+ lora_config: A dict where keys are control types (e.g., 'edge') and values are dicts
280
+ containing 'path', 'lora_weights', and 'cond_size'.
281
+ """
282
+ for control_type, config in lora_config.items():
283
+ print(f"Loading LoRA for control type: {control_type}")
284
+ checkpoint = load_checkpoint(config["path"])
285
+ processors = prepare_lora_processors(
286
+ checkpoint=checkpoint,
287
+ lora_weights=config["lora_weights"],
288
+ transformer=self.transformer,
289
+ cond_size=config["cond_size"],
290
+ number=len(config["lora_weights"]) if config.get("lora_weights") is not None else None,
291
+ )
292
+ self.control_lora_processors[control_type] = processors
293
+ self.control_lora_cond_sizes[control_type] = config["cond_size"]
294
+ self.control_lora_weights[control_type] = config["lora_weights"]
295
+ print("All control LoRAs loaded and prepared.")
296
+
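A hedged example of the `lora_config` dict that `load_control_loras` expects (control-type names, paths, and sizes here are placeholders):

    pipe.load_control_loras({
        "edge":  {"path": "checkpoints/edge_lora.safetensors",  "lora_weights": [1.0], "cond_size": 512},
        "color": {"path": "checkpoints/color_lora.safetensors", "lora_weights": [1.0], "cond_size": 512},
    })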
297
+ def _combine_control_loras(self, control_types: List[str]):
298
+ """
299
+ Combines multiple control LoRAs into a single set of attention processors.
300
+ """
301
+ if not control_types:
302
+ return FluxAttnProcessor2_0()
303
+
304
+ try:
305
+ first_param = next(self.transformer.parameters())
306
+ target_device = first_param.device
307
+ target_dtype = first_param.dtype
308
+ except StopIteration:
309
+ target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
310
+ target_dtype = torch.float32
311
+
312
+ combined_procs = {}
313
+ # LoRA weights must come from configuration, not from gammas (which control strength)
314
+ all_lora_weights = []
315
+
316
+ # Determine total number of LoRAs and ranks across all control types
317
+ total_loras = 0
318
+ all_ranks = []
319
+ all_cond_sizes = []
320
+
321
+ for control_type in control_types:
322
+ procs = self.control_lora_processors.get(control_type)
323
+ if not procs:
324
+ raise ValueError(f"Control type '{control_type}' not loaded.")
325
+ # Collect configured LoRA weights for this control type
326
+ conf_weights = self.control_lora_weights.get(control_type)
327
+ if conf_weights is None:
328
+ raise ValueError(f"Control type '{control_type}' has no configured lora_weights.")
329
+ all_lora_weights.extend(conf_weights)
330
+
331
+ # Get n_loras from the first processor
332
+ first_proc = next(iter(procs.values()))
333
+ n_loras_in_control = first_proc.n_loras
334
+ total_loras += n_loras_in_control
335
+
336
+ # Correctly get ranks from the processor's LoRA layers
337
+ proc_ranks = [lora.down.weight.shape[0] for lora in first_proc.q_loras]
338
+ all_ranks.extend(proc_ranks)
339
+
340
+ cond_size = self.control_lora_cond_sizes[control_type]
341
+ cond_sizes = [cond_size] * n_loras_in_control if not isinstance(cond_size, list) else cond_size
342
+ all_cond_sizes.extend(cond_sizes)
343
+
344
+ for name in self.transformer.attn_processors.keys():
345
+ match = re.search(r'\.(\d+)\.', name)
346
+ if not match:
347
+ continue
348
+ layer_index = int(match.group(1))
349
+
350
+ if name.startswith("transformer_blocks"):
351
+ new_proc = MultiDoubleStreamBlockLoraProcessor(
352
+ dim=3072, ranks=all_ranks, network_alphas=all_ranks, lora_weights=all_lora_weights,
353
+ device=target_device, dtype=target_dtype,
354
+ cond_widths=all_cond_sizes, cond_heights=all_cond_sizes, n_loras=total_loras
355
+ )
356
+ elif name.startswith("single_transformer_blocks"):
357
+ new_proc = MultiSingleStreamBlockLoraProcessor(
358
+ dim=3072, ranks=all_ranks, network_alphas=all_ranks, lora_weights=all_lora_weights,
359
+ device=target_device, dtype=target_dtype,
360
+ cond_widths=all_cond_sizes, cond_heights=all_cond_sizes, n_loras=total_loras
361
+ )
362
+ else:
363
+ continue
364
+
365
+ lora_idx_offset = 0
366
+ for control_type in control_types:
367
+ source_proc = self.control_lora_processors[control_type][name]
368
+ for i in range(source_proc.n_loras):
369
+ current_lora_idx = lora_idx_offset + i
370
+ # Copy weights for q, k, v, proj
371
+ new_proc.q_loras[current_lora_idx].load_state_dict(source_proc.q_loras[i].state_dict())
372
+ new_proc.k_loras[current_lora_idx].load_state_dict(source_proc.k_loras[i].state_dict())
373
+ new_proc.v_loras[current_lora_idx].load_state_dict(source_proc.v_loras[i].state_dict())
374
+ if hasattr(new_proc, 'proj_loras'):
375
+ new_proc.proj_loras[current_lora_idx].load_state_dict(source_proc.proj_loras[i].state_dict())
376
+
377
+ lora_idx_offset += source_proc.n_loras
378
+
379
+ combined_procs[name] = new_proc.to(device=target_device, dtype=target_dtype)
380
+
381
+ return combined_procs
382
+
383
+ def set_gamma_values(self, gammas: List[float]):
384
+ """
385
+ Set gamma values for bias control modulation on current attention processors and attention modules.
386
+ """
387
+ print(f"Setting gamma values to: {gammas}")
388
+ # Resolve device/dtype robustly from model parameters
389
+ try:
390
+ first_param = next(self.transformer.parameters())
391
+ device = first_param.device
392
+ dtype = first_param.dtype
393
+ except StopIteration:
394
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
395
+ dtype = torch.float32
396
+ gamma_tensor = torch.tensor(gammas, device=device, dtype=dtype)
397
+ for name, attn_processor in self.transformer.attn_processors.items():
398
+ if hasattr(attn_processor, 'q_loras'):
399
+ setattr(attn_processor, 'c_factor', gamma_tensor)
400
+ # print(f" Set c_factor {gamma_tensor} on processor {name}")
401
+
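The gamma list passed to `set_gamma_values` is expected to line up one-to-one with the LoRAs of the currently installed (possibly combined) processors; for instance, assuming `pipe` has one or two controls activated via `control_dict["type"]`:

    pipe.set_gamma_values([1.0])        # a single active control LoRA
    pipe.set_gamma_values([1.2, 0.8])   # two combined controls, in the same order as control_dict["type"]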
402
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
403
+ def _get_t5_prompt_embeds(
404
+ self,
405
+ prompt: Union[str, List[str]] = None,
406
+ num_images_per_prompt: int = 1,
407
+ max_sequence_length: int = 512,
408
+ device: Optional[torch.device] = None,
409
+ dtype: Optional[torch.dtype] = None,
410
+ ):
411
+ device = device or self._execution_device
412
+ dtype = dtype or self.text_encoder.dtype
413
+
414
+ prompt = [prompt] if isinstance(prompt, str) else prompt
415
+ batch_size = len(prompt)
416
+
417
+ if isinstance(self, TextualInversionLoaderMixin):
418
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
419
+
420
+ text_inputs = self.tokenizer_2(
421
+ prompt,
422
+ padding="max_length",
423
+ max_length=max_sequence_length,
424
+ truncation=True,
425
+ return_length=False,
426
+ return_overflowing_tokens=False,
427
+ return_tensors="pt",
428
+ )
429
+ text_input_ids = text_inputs.input_ids
430
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
431
+
432
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
433
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
434
+ logger.warning(
435
+ "The following part of your input was truncated because `max_sequence_length` is set to "
436
+ f" {max_sequence_length} tokens: {removed_text}"
437
+ )
438
+
439
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
440
+
441
+ dtype = self.text_encoder_2.dtype
442
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
443
+
444
+ _, seq_len, _ = prompt_embeds.shape
445
+
446
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
447
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
448
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
449
+
450
+ return prompt_embeds
451
+
452
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
453
+ def _get_clip_prompt_embeds(
454
+ self,
455
+ prompt: Union[str, List[str]],
456
+ num_images_per_prompt: int = 1,
457
+ device: Optional[torch.device] = None,
458
+ ):
459
+ device = device or self._execution_device
460
+
461
+ prompt = [prompt] if isinstance(prompt, str) else prompt
462
+ batch_size = len(prompt)
463
+
464
+ if isinstance(self, TextualInversionLoaderMixin):
465
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
466
+
467
+ text_inputs = self.tokenizer(
468
+ prompt,
469
+ padding="max_length",
470
+ max_length=self.tokenizer_max_length,
471
+ truncation=True,
472
+ return_overflowing_tokens=False,
473
+ return_length=False,
474
+ return_tensors="pt",
475
+ )
476
+
477
+ text_input_ids = text_inputs.input_ids
478
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
479
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
480
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
481
+ logger.warning(
482
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
483
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
484
+ )
485
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
486
+
487
+ # Use pooled output of CLIPTextModel
488
+ prompt_embeds = prompt_embeds.pooler_output
489
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
490
+
491
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
492
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
493
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
494
+
495
+ return prompt_embeds
496
+
497
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
498
+ def encode_prompt(
499
+ self,
500
+ prompt: Union[str, List[str]],
501
+ prompt_2: Union[str, List[str]],
502
+ device: Optional[torch.device] = None,
503
+ num_images_per_prompt: int = 1,
504
+ prompt_embeds: Optional[torch.FloatTensor] = None,
505
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
506
+ max_sequence_length: int = 512,
507
+ lora_scale: Optional[float] = None,
508
+ ):
509
+ r"""
510
+
511
+ Args:
512
+ prompt (`str` or `List[str]`, *optional*):
513
+ prompt to be encoded
514
+ prompt_2 (`str` or `List[str]`, *optional*):
515
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
516
+ used in all text-encoders
517
+ device: (`torch.device`):
518
+ torch device
519
+ num_images_per_prompt (`int`):
520
+ number of images that should be generated per prompt
521
+ prompt_embeds (`torch.FloatTensor`, *optional*):
522
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
523
+ provided, text embeddings will be generated from `prompt` input argument.
524
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
525
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
526
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
527
+ lora_scale (`float`, *optional*):
528
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
529
+ """
530
+ device = device or self._execution_device
531
+
532
+ # set lora scale so that monkey patched LoRA
533
+ # function of text encoder can correctly access it
534
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
535
+ self._lora_scale = lora_scale
536
+
537
+ # dynamically adjust the LoRA scale
538
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
539
+ scale_lora_layers(self.text_encoder, lora_scale)
540
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
541
+ scale_lora_layers(self.text_encoder_2, lora_scale)
542
+
543
+ prompt = [prompt] if isinstance(prompt, str) else prompt
544
+
545
+ if prompt_embeds is None:
546
+ prompt_2 = prompt_2 or prompt
547
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
548
+
549
+ # We only use the pooled prompt output from the CLIPTextModel
550
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
551
+ prompt=prompt,
552
+ device=device,
553
+ num_images_per_prompt=num_images_per_prompt,
554
+ )
555
+ prompt_embeds = self._get_t5_prompt_embeds(
556
+ prompt=prompt_2,
557
+ num_images_per_prompt=num_images_per_prompt,
558
+ max_sequence_length=max_sequence_length,
559
+ device=device,
560
+ )
561
+
562
+ if self.text_encoder is not None:
563
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
564
+ # Retrieve the original scale by scaling back the LoRA layers
565
+ unscale_lora_layers(self.text_encoder, lora_scale)
566
+
567
+ if self.text_encoder_2 is not None:
568
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
569
+ # Retrieve the original scale by scaling back the LoRA layers
570
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
571
+
572
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
573
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
574
+
575
+ return prompt_embeds, pooled_prompt_embeds, text_ids
576
+
577
+ # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
578
+ def check_inputs(
579
+ self,
580
+ prompt,
581
+ prompt_2,
582
+ height,
583
+ width,
584
+ prompt_embeds=None,
585
+ pooled_prompt_embeds=None,
586
+ callback_on_step_end_tensor_inputs=None,
587
+ max_sequence_length=None,
588
+ ):
589
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
590
+ raise ValueError(
591
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
592
+ )
593
+
594
+ if callback_on_step_end_tensor_inputs is not None and not all(
595
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
596
+ ):
597
+ raise ValueError(
598
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
599
+ )
600
+
601
+ if prompt is not None and prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
604
+ " only forward one of the two."
605
+ )
606
+ elif prompt_2 is not None and prompt_embeds is not None:
607
+ raise ValueError(
608
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
609
+ " only forward one of the two."
610
+ )
611
+ elif prompt is None and prompt_embeds is None:
612
+ raise ValueError(
613
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
614
+ )
615
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
616
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
617
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
618
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
619
+
620
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
621
+ raise ValueError(
622
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
623
+ )
624
+
625
+ if max_sequence_length is not None and max_sequence_length > 512:
626
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
627
+
628
+ @staticmethod
629
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
630
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
631
+ latent_image_ids = torch.zeros(height, width, 3)
632
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
633
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
634
+
635
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
636
+
637
+ latent_image_ids = latent_image_ids.reshape(
638
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
639
+ )
640
+
641
+ return latent_image_ids.to(device=device, dtype=dtype)
642
+
643
+ @staticmethod
644
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
645
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
646
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
647
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
648
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
649
+
650
+ return latents
651
+
652
+ @staticmethod
653
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
654
+ def _unpack_latents(latents, height, width, vae_scale_factor):
655
+ batch_size, num_patches, channels = latents.shape
656
+
657
+ # VAE applies 8x compression on images but we must also account for packing which requires
658
+ # latent height and width to be divisible by 2.
659
+ height = 2 * (int(height) // (vae_scale_factor * 2))
660
+ width = 2 * (int(width) // (vae_scale_factor * 2))
661
+
662
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
663
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
664
+
665
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
666
+
667
+ return latents
668
+
669
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
670
+ if isinstance(generator, list):
671
+ image_latents = [
672
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
673
+ for i in range(image.shape[0])
674
+ ]
675
+ image_latents = torch.cat(image_latents, dim=0)
676
+ else:
677
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
678
+
679
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
680
+
681
+ return image_latents
682
+
683
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
684
+ def enable_vae_slicing(self):
685
+ r"""
686
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
687
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
688
+ """
689
+ self.vae.enable_slicing()
690
+
691
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
692
+ def disable_vae_slicing(self):
693
+ r"""
694
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
695
+ computing decoding in one step.
696
+ """
697
+ self.vae.disable_slicing()
698
+
699
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
700
+ def enable_vae_tiling(self):
701
+ r"""
702
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
703
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
704
+ processing larger images.
705
+ """
706
+ self.vae.enable_tiling()
707
+
708
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
709
+ def disable_vae_tiling(self):
710
+ r"""
711
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
712
+ computing decoding in one step.
713
+ """
714
+ self.vae.disable_tiling()
715
+
716
+ def prepare_latents(
717
+ self,
718
+ batch_size,
719
+ num_channels_latents,
720
+ height,
721
+ width,
722
+ dtype,
723
+ device,
724
+ generator,
725
+ image,
726
+ subject_images,
727
+ spatial_images,
728
+ latents=None,
729
+ cond_size=512,
730
+ num_subject_images: int = 0,
731
+ num_spatial_images: int = 0,
732
+ ):
733
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
734
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
735
+ height_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
736
+ width_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
737
+
738
+ image_latents = image_ids = None
739
+ image_latent_h = 0 # Initialize to handle case where image is None
740
+
741
+ # Prepare noise latents
742
+ shape = (batch_size, num_channels_latents, height, width)
743
+ if latents is None:
744
+ noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
745
+ else:
746
+ noise_latents = latents.to(device=device, dtype=dtype)
747
+
748
+ noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
749
+ # print(noise_latents.shape)
750
+ noise_latent_image_ids, cond_latent_image_ids_resized = resize_position_encoding(
751
+ batch_size, height, width, height_cond, width_cond, device, dtype
752
+ )
753
+ # noise IDs are marked with 0 in the first channel
754
+ noise_latent_image_ids[..., 0] = 0
755
+
756
+ cond_latents_to_concat = []
757
+ latents_ids_to_concat = [noise_latent_image_ids]
758
+
759
+ # 1. Prepare `image` (Kontext) latents
760
+ if image is not None:
761
+ image = image.to(device=device, dtype=dtype)
762
+ if image.shape[1] != self.latent_channels:
763
+ image_latents = self._encode_vae_image(image=image, generator=generator)
764
+ else:
765
+ image_latents = image
766
+
767
+ image_latent_h, image_latent_w = image_latents.shape[2:]
768
+ image_latents = self._pack_latents(
769
+ image_latents, batch_size, num_channels_latents, image_latent_h, image_latent_w
770
+ )
771
+ image_ids = self._prepare_latent_image_ids(
772
+ batch_size, image_latent_h // 2, image_latent_w // 2, device, dtype
773
+ )
774
+ image_ids[..., 0] = 1 # Mark as condition
775
+ latents_ids_to_concat.append(image_ids)
776
+
777
+ # 2. Prepare `subject_images` latents
778
+ if subject_images is not None and num_subject_images > 0:
779
+ subject_images = subject_images.to(device=device, dtype=dtype)
780
+ subject_image_latents = self._encode_vae_image(image=subject_images, generator=generator)
781
+ subject_latent_h, subject_latent_w = subject_image_latents.shape[2:]
782
+ subject_latents = self._pack_latents(
783
+ subject_image_latents, batch_size, num_channels_latents, subject_latent_h, subject_latent_w
784
+ )
785
+
786
+ latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, device, dtype)
787
+ latent_subject_ids[..., 0] = 1
788
+ latent_subject_ids[:, 1] += image_latent_h // 2
789
+ subject_latent_image_ids = torch.cat([latent_subject_ids for _ in range(num_subject_images)], dim=0)
790
+
791
+ cond_latents_to_concat.append(subject_latents)
792
+ latents_ids_to_concat.append(subject_latent_image_ids)
793
+
794
+ # 3. Prepare `spatial_images` latents
795
+ if spatial_images is not None and num_spatial_images > 0:
796
+ spatial_images = spatial_images.to(device=device, dtype=dtype)
797
+ spatial_image_latents = self._encode_vae_image(image=spatial_images, generator=generator)
798
+ spatial_latent_h, spatial_latent_w = spatial_image_latents.shape[2:]
799
+ cond_latents = self._pack_latents(
800
+ spatial_image_latents, batch_size, num_channels_latents, spatial_latent_h, spatial_latent_w
801
+ )
802
+ cond_latent_image_ids_resized[..., 0] = 2 # Mark as condition
803
+ cond_latent_image_ids = torch.cat(
804
+ [cond_latent_image_ids_resized for _ in range(num_spatial_images)], dim=0
805
+ )
806
+
807
+ cond_latents_to_concat.append(cond_latents)
808
+ latents_ids_to_concat.append(cond_latent_image_ids)
809
+
810
+ cond_latents = torch.cat(cond_latents_to_concat, dim=1) if cond_latents_to_concat else None
811
+ latent_image_ids = torch.cat(latents_ids_to_concat, dim=0)
812
+
813
+ return noise_latents, image_latents, cond_latents, latent_image_ids
814
+
815
+ @property
816
+ def guidance_scale(self):
817
+ return self._guidance_scale
818
+
819
+ @property
820
+ def joint_attention_kwargs(self):
821
+ return self._joint_attention_kwargs
822
+
823
+ @property
824
+ def num_timesteps(self):
825
+ return self._num_timesteps
826
+
827
+ @property
828
+ def current_timestep(self):
829
+ return self._current_timestep
830
+
831
+ @property
832
+ def interrupt(self):
833
+ return self._interrupt
834
+
835
+ @torch.no_grad()
836
+ def __call__(
837
+ self,
838
+ image: Optional[PipelineImageInput] = None,
839
+ prompt: Union[str, List[str]] = None,
840
+ prompt_2: Optional[Union[str, List[str]]] = None,
841
+ height: Optional[int] = None,
842
+ width: Optional[int] = None,
843
+ num_inference_steps: int = 28,
844
+ sigmas: Optional[List[float]] = None,
845
+ guidance_scale: float = 3.5,
846
+ num_images_per_prompt: Optional[int] = 1,
847
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
848
+ latents: Optional[torch.FloatTensor] = None,
849
+ prompt_embeds: Optional[torch.FloatTensor] = None,
850
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
851
+ output_type: Optional[str] = "pil",
852
+ return_dict: bool = True,
853
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
854
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
855
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
856
+ max_sequence_length: int = 512,
857
+ cond_size: int = 512,
858
+ control_dict: Optional[Dict[str, Any]] = None,
859
+ ):
860
+ r"""
861
+ Function invoked when calling the pipeline for generation.
862
+
863
+ Args:
864
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
865
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
866
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
867
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
868
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
869
+ latents as `image`, but if passing latents directly it is not encoded again.
870
+ prompt (`str` or `List[str]`, *optional*):
871
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
872
+ instead.
873
+ prompt_2 (`str` or `List[str]`, *optional*):
874
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
875
+ will be used instead.
876
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
877
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
878
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
879
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
880
+ num_inference_steps (`int`, *optional*, defaults to 50):
881
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
882
+ expense of slower inference.
883
+ sigmas (`List[float]`, *optional*):
884
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
885
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
886
+ will be used.
887
+ guidance_scale (`float`, *optional*, defaults to 3.5):
888
+ Guidance scale as defined in [Classifier-Free Diffusion
889
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
890
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
891
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
892
+ the text `prompt`, usually at the expense of lower image quality.
893
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
894
+ The number of images to generate per prompt.
895
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
896
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
897
+ to make generation deterministic.
898
+ latents (`torch.FloatTensor`, *optional*):
899
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
900
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
901
+ tensor will ge generated by sampling using the supplied random `generator`.
902
+ prompt_embeds (`torch.FloatTensor`, *optional*):
903
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
904
+ provided, text embeddings will be generated from `prompt` input argument.
905
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
906
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
907
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
908
+ output_type (`str`, *optional*, defaults to `"pil"`):
909
+ The output format of the generate image. Choose between
910
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
911
+ return_dict (`bool`, *optional*, defaults to `True`):
912
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
913
+ joint_attention_kwargs (`dict`, *optional*):
914
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
915
+ `self.processor` in
916
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
917
+ callback_on_step_end (`Callable`, *optional*):
918
+ A function that calls at the end of each denoising steps during the inference. The function is called
919
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
920
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
921
+ `callback_on_step_end_tensor_inputs`.
922
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
923
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
924
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
925
+ `._callback_tensor_inputs` attribute of your pipeline class.
926
+ max_sequence_length (`int` defaults to 512):
927
+ Maximum sequence length to use with the `prompt`.
928
+ cond_size (`int`, *optional*, defaults to 512):
929
+ The size for conditioning images.
930
+
931
+ Examples:
932
+
933
+ Returns:
934
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
935
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
936
+ images.
937
+ """
938
+
939
+ height = height or self.default_sample_size * self.vae_scale_factor
940
+ width = width or self.default_sample_size * self.vae_scale_factor
941
+
942
+ # 1. Check inputs. Raise error if not correct
943
+ self.check_inputs(
944
+ prompt,
945
+ prompt_2,
946
+ height,
947
+ width,
948
+ prompt_embeds=prompt_embeds,
949
+ pooled_prompt_embeds=pooled_prompt_embeds,
950
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
951
+ max_sequence_length=max_sequence_length,
952
+ )
953
+
954
+ self._guidance_scale = guidance_scale
955
+ self._joint_attention_kwargs = joint_attention_kwargs
956
+ self._current_timestep = None
957
+ self._interrupt = False
958
+
959
+ # Normalize control_dict to an empty dict so kontext-only inference works without controls
960
+ control_dict = control_dict or {}
961
+
962
+ spatial_images = control_dict.get("spatial_images", [])
963
+ num_spatial_images = len(spatial_images)
964
+ subject_images = control_dict.get("subject_images", [])
965
+ num_subject_images = len(subject_images)
966
+
967
+ requested_control_type = control_dict.get("type") or None
968
+
969
+ # Normalize to list for unified handling
970
+ if requested_control_type and isinstance(requested_control_type, str):
971
+ requested_control_type = [requested_control_type]
972
+
973
+ # Revert to default if no control type is requested and a control is active
974
+ if not requested_control_type and self.current_control_type:
975
+ print("Reverting to default attention processors.")
976
+ self.transformer.set_attn_processor(FluxAttnProcessor2_0())
977
+ self.current_control_type = None
978
+ # Switch processors only if the control type(s) have changed
979
+ elif requested_control_type != self.current_control_type:
980
+ if requested_control_type:
981
+ print(f"Switching to LoRA control type(s): {requested_control_type}")
982
+ processors = self._combine_control_loras(requested_control_type)
983
+ self.transformer.set_attn_processor(processors)
984
+ # For cond_size, we assume they are compatible and just use the first one.
985
+ self.cond_size = self.control_lora_cond_sizes[requested_control_type[0]]
986
+ self.current_control_type = requested_control_type
987
+
988
+ # Align cond_size to selected control type (if any)
989
+ if hasattr(self, "cond_size"):
990
+ selected_cond_size = self.cond_size
991
+ if isinstance(selected_cond_size, list) and len(selected_cond_size) > 0:
992
+ cond_size = int(selected_cond_size[0])
993
+ elif isinstance(selected_cond_size, int):
994
+ cond_size = selected_cond_size
995
+
996
+ # Set gamma values simply based on provided control_dict['gammas'].
997
+ if requested_control_type:
998
+ raw_gammas = control_dict.get("gammas", [])
999
+ if not isinstance(raw_gammas, list):
1000
+ raw_gammas = [raw_gammas]
1001
+ # flatten one level
1002
+ flattened_gammas: List[float] = []
1003
+ for g in raw_gammas:
1004
+ if isinstance(g, (list, tuple)):
1005
+ flattened_gammas.extend([float(x) for x in g])
1006
+ else:
1007
+ flattened_gammas.append(float(g))
1008
+ if len(flattened_gammas) > 0:
1009
+ self.set_gamma_values(flattened_gammas)
1010
+
1011
+ # 2. Define call parameters
1012
+ if prompt is not None and isinstance(prompt, str):
1013
+ batch_size = 1
1014
+ elif prompt is not None and isinstance(prompt, list):
1015
+ batch_size = len(prompt)
1016
+ else:
1017
+ batch_size = prompt_embeds.shape[0]
1018
+
1019
+ device = self._execution_device
1020
+
1021
+ lora_scale = (
1022
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1023
+ )
1024
+ (
1025
+ prompt_embeds,
1026
+ pooled_prompt_embeds,
1027
+ text_ids,
1028
+ ) = self.encode_prompt(
1029
+ prompt=prompt,
1030
+ prompt_2=prompt_2,
1031
+ prompt_embeds=prompt_embeds,
1032
+ pooled_prompt_embeds=pooled_prompt_embeds,
1033
+ device=device,
1034
+ num_images_per_prompt=num_images_per_prompt,
1035
+ max_sequence_length=max_sequence_length,
1036
+ lora_scale=lora_scale,
1037
+ )
1038
+
1039
+ # 3. Preprocess images
1040
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
1041
+ img = image[0] if isinstance(image, list) else image
1042
+ image_height, image_width = self.image_processor.get_default_height_width(img)
1043
+ aspect_ratio = image_width / image_height
1044
+            # Kontext is trained on specific resolutions; using one of them is recommended
1045
+ _, image_width, image_height = min(
1046
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
1047
+ )
1048
+ multiple_of = self.vae_scale_factor * 2
1049
+ image_width = image_width // multiple_of * multiple_of
1050
+ image_height = image_height // multiple_of * multiple_of
1051
+ image = self.image_processor.resize(image, image_height, image_width)
1052
+ image = self.image_processor.preprocess(image, image_height, image_width)
1053
+
1054
+ if len(subject_images) > 0:
1055
+ subject_image_ls = []
1056
+ for subject_image in subject_images:
1057
+ w, h = subject_image.size[:2]
1058
+ scale = cond_size / max(h, w)
1059
+ new_h, new_w = int(h * scale), int(w * scale)
1060
+ subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
1061
+ subject_image = subject_image.to(dtype=self.vae.dtype)
1062
+ pad_h = cond_size - subject_image.shape[-2]
1063
+ pad_w = cond_size - subject_image.shape[-1]
1064
+ subject_image = pad(
1065
+ subject_image, padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)), fill=0
1066
+ )
1067
+ subject_image_ls.append(subject_image)
1068
+ subject_images = torch.cat(subject_image_ls, dim=-2)
1069
+ else:
1070
+ subject_images = None
1071
+
1072
+ if len(spatial_images) > 0:
1073
+ condition_image_ls = []
1074
+ for img in spatial_images:
1075
+ condition_image = self.image_processor.preprocess(img, height=cond_size, width=cond_size)
1076
+ condition_image = condition_image.to(dtype=self.vae.dtype)
1077
+ condition_image_ls.append(condition_image)
1078
+ spatial_images = torch.cat(condition_image_ls, dim=-2)
1079
+ else:
1080
+ spatial_images = None
1081
+
1082
+ # 4. Prepare latent variables
1083
+ num_channels_latents = self.transformer.config.in_channels // 4
1084
+ latents, image_latents, cond_latents, latent_image_ids = self.prepare_latents(
1085
+ batch_size * num_images_per_prompt,
1086
+ num_channels_latents,
1087
+ height,
1088
+ width,
1089
+ prompt_embeds.dtype,
1090
+ device,
1091
+ generator,
1092
+ image,
1093
+ subject_images,
1094
+ spatial_images,
1095
+ latents,
1096
+ cond_size,
1097
+ num_subject_images=num_subject_images,
1098
+ num_spatial_images=num_spatial_images,
1099
+ )
1100
+
1101
+ # 5. Prepare timesteps
1102
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1103
+ # sigmas = np.array([1.0000, 0.9836, 0.9660, 0.9471, 0.9266, 0.9045, 0.8805, 0.8543, 0.8257, 0.7942, 0.7595, 0.7210, 0.6780, 0.6297, 0.5751, 0.5128, 0.4412, 0.3579, 0.2598, 0.1425])
1104
+ image_seq_len = latents.shape[1]
1105
+ mu = calculate_shift(
1106
+ image_seq_len,
1107
+ self.scheduler.config.get("base_image_seq_len", 256),
1108
+ self.scheduler.config.get("max_image_seq_len", 4096),
1109
+ self.scheduler.config.get("base_shift", 0.5),
1110
+ self.scheduler.config.get("max_shift", 1.15),
1111
+ )
1112
+ timesteps, num_inference_steps = retrieve_timesteps(
1113
+ self.scheduler,
1114
+ num_inference_steps,
1115
+ device,
1116
+ sigmas=sigmas,
1117
+ mu=mu,
1118
+ )
1119
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1120
+ self._num_timesteps = len(timesteps)
1121
+
1122
+ # handle guidance
1123
+ if self.transformer.config.guidance_embeds:
1124
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1125
+ guidance = guidance.expand(latents.shape[0])
1126
+ else:
1127
+ guidance = None
1128
+
1129
+ if self.joint_attention_kwargs is None:
1130
+ self._joint_attention_kwargs = {}
1131
+
1132
+ # K/V Caching
1133
+ for name, attn_processor in self.transformer.attn_processors.items():
1134
+ if hasattr(attn_processor, "bank_kv"):
1135
+ attn_processor.bank_kv.clear()
1136
+ if hasattr(attn_processor, "bank_attn"):
1137
+ attn_processor.bank_attn = None
1138
+
1139
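+        # Warm-up pass: run the transformer once at the first timestep with the condition latents so the
+        # control attention processors can refill the K/V banks that were cleared just above.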
+ if cond_latents is not None:
1140
+ latent_model_input = latents
1141
+ if image_latents is not None:
1142
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
1143
+ print(latent_model_input.shape)
1144
+ warmup_latents = latent_model_input
1145
+ warmup_latent_ids = latent_image_ids
1146
+ t = torch.tensor([timesteps[0]], device=device)
1147
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1148
+ _ = self.transformer(
1149
+ hidden_states=warmup_latents,
1150
+ cond_hidden_states=cond_latents,
1151
+ timestep=timestep / 1000,
1152
+ guidance=guidance,
1153
+ pooled_projections=pooled_prompt_embeds,
1154
+ encoder_hidden_states=prompt_embeds,
1155
+ txt_ids=text_ids,
1156
+ img_ids=warmup_latent_ids,
1157
+ joint_attention_kwargs=self.joint_attention_kwargs,
1158
+ return_dict=False,
1159
+ )[0]
1160
+
1161
+ # 6. Denoising loop
1162
+ self.scheduler.set_begin_index(0)
1163
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1164
+ for i, t in enumerate(timesteps):
1165
+ if self.interrupt:
1166
+ continue
1167
+
1168
+ latent_model_input = latents
1169
+ if image_latents is not None:
1170
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
1171
+
1172
+ self._current_timestep = t
1173
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1174
+ noise_pred = self.transformer(
1175
+ hidden_states=latent_model_input,
1176
+ cond_hidden_states=cond_latents,
1177
+ timestep=timestep / 1000,
1178
+ guidance=guidance,
1179
+ pooled_projections=pooled_prompt_embeds,
1180
+ encoder_hidden_states=prompt_embeds,
1181
+ txt_ids=text_ids,
1182
+ img_ids=latent_image_ids,
1183
+ joint_attention_kwargs=self.joint_attention_kwargs,
1184
+ return_dict=False,
1185
+ )[0]
1186
+
1187
+ noise_pred = noise_pred[:, : latents.size(1)]
1188
+
1189
+ # compute the previous noisy sample x_t -> x_t-1
1190
+ latents_dtype = latents.dtype
1191
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1192
+
1193
+ if latents.dtype != latents_dtype:
1194
+ if torch.backends.mps.is_available():
1195
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1196
+ latents = latents.to(latents_dtype)
1197
+
1198
+ if callback_on_step_end is not None:
1199
+ callback_kwargs = {}
1200
+ for k in callback_on_step_end_tensor_inputs:
1201
+ callback_kwargs[k] = locals()[k]
1202
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1203
+
1204
+ latents = callback_outputs.pop("latents", latents)
1205
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1206
+
1207
+ # call the callback, if provided
1208
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1209
+ progress_bar.update()
1210
+
1211
+ if XLA_AVAILABLE:
1212
+ xm.mark_step()
1213
+
1214
+ self._current_timestep = None
1215
+
1216
+ if output_type == "latent":
1217
+ image = latents
1218
+ else:
1219
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1220
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1221
+ image = self.vae.decode(latents, return_dict=False)[0]
1222
+ image = self.image_processor.postprocess(image, output_type=output_type)
1223
+
1224
+ # Offload all models
1225
+ self.maybe_free_model_hooks()
1226
+
1227
+ if not return_dict:
1228
+ return (image,)
1229
+
1230
+ return FluxPipelineOutput(images=image)
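
A minimal usage sketch for the `__call__` above. The pipeline class name, the checkpoint id, the control type key, and the image paths are illustrative assumptions rather than values taken from this commit; `control_dict` follows the keys the code reads (`type`, `spatial_images`, `subject_images`, `gammas`).

import torch
from PIL import Image
from src.pipeline_flux_kontext_control import FluxKontextControlPipeline  # class name assumed

# Load the Kontext base weights (checkpoint id assumed) onto the GPU.
pipe = FluxKontextControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

init_image = Image.open("input.png").convert("RGB")  # Kontext reference image
edge_map = Image.open("edge.png").convert("RGB")     # spatial control image

control_dict = {
    "type": ["edge"],              # must match a control LoRA registered on the pipeline (key assumed)
    "spatial_images": [edge_map],  # preprocessed to `cond_size` x `cond_size` internally
    "subject_images": [],
    "gammas": [1.0],               # per-control strength, flattened and passed to set_gamma_values
}

result = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    image=init_image,
    control_dict=control_dict,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator("cuda").manual_seed(0),
)
result.images[0].save("output.png")
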
src/transformer_flux.py ADDED
@@ -0,0 +1,608 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ FluxAttnProcessor2_0,
15
+ FluxAttnProcessor2_0_NPU,
16
+ FusedFluxAttnProcessor2_0,
17
+ )
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
21
+ from diffusers.utils.import_utils import is_torch_npu_available
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+ @maybe_allow_in_graph
29
+ class FluxSingleTransformerBlock(nn.Module):
30
+
31
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
32
+ super().__init__()
33
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
34
+
35
+ self.norm = AdaLayerNormZeroSingle(dim)
36
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
37
+ self.act_mlp = nn.GELU(approximate="tanh")
38
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
39
+
40
+ if is_torch_npu_available():
41
+ processor = FluxAttnProcessor2_0_NPU()
42
+ else:
43
+ processor = FluxAttnProcessor2_0()
44
+ self.attn = Attention(
45
+ query_dim=dim,
46
+ cross_attention_dim=None,
47
+ dim_head=attention_head_dim,
48
+ heads=num_attention_heads,
49
+ out_dim=dim,
50
+ bias=True,
51
+ processor=processor,
52
+ qk_norm="rms_norm",
53
+ eps=1e-6,
54
+ pre_only=True,
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ cond_hidden_states: torch.Tensor,
61
+ temb: torch.Tensor,
62
+ cond_temb: torch.Tensor,
63
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
64
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ ) -> torch.Tensor:
66
+ use_cond = cond_hidden_states is not None
67
+
68
+ residual = hidden_states
69
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
70
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
71
+
72
+ if use_cond:
73
+ residual_cond = cond_hidden_states
74
+ norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
75
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
76
+ norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
77
+ else:
78
+ norm_hidden_states_concat = norm_hidden_states
79
+
80
+ joint_attention_kwargs = joint_attention_kwargs or {}
81
+ if use_cond:
82
+ attn_output = self.attn(
83
+ hidden_states=norm_hidden_states_concat,
84
+ image_rotary_emb=image_rotary_emb,
85
+ use_cond=use_cond,
86
+ **joint_attention_kwargs,
87
+ )
88
+ else:
89
+ attn_output = self.attn(
90
+ hidden_states=norm_hidden_states_concat,
91
+ image_rotary_emb=image_rotary_emb,
92
+ **joint_attention_kwargs,
93
+ )
94
+ if use_cond:
95
+ attn_output, cond_attn_output = attn_output
96
+
97
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
98
+ gate = gate.unsqueeze(1)
99
+ hidden_states = gate * self.proj_out(hidden_states)
100
+ hidden_states = residual + hidden_states
101
+
102
+ if use_cond:
103
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
104
+ cond_gate = cond_gate.unsqueeze(1)
105
+ condition_latents = cond_gate * self.proj_out(condition_latents)
106
+ condition_latents = residual_cond + condition_latents
107
+
108
+ if hidden_states.dtype == torch.float16:
109
+ hidden_states = hidden_states.clip(-65504, 65504)
110
+
111
+ return hidden_states, condition_latents if use_cond else None
112
+
113
+
114
+ @maybe_allow_in_graph
115
+ class FluxTransformerBlock(nn.Module):
116
+ def __init__(
117
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
118
+ ):
119
+ super().__init__()
120
+
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ self.norm1_context = AdaLayerNormZero(dim)
124
+
125
+ if hasattr(F, "scaled_dot_product_attention"):
126
+ processor = FluxAttnProcessor2_0()
127
+ else:
128
+ raise ValueError(
129
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
130
+ )
131
+ self.attn = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=None,
134
+ added_kv_proj_dim=dim,
135
+ dim_head=attention_head_dim,
136
+ heads=num_attention_heads,
137
+ out_dim=dim,
138
+ context_pre_only=False,
139
+ bias=True,
140
+ processor=processor,
141
+ qk_norm=qk_norm,
142
+ eps=eps,
143
+ )
144
+
145
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
146
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
147
+
148
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
149
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
150
+
151
+ # let chunk size default to None
152
+ self._chunk_size = None
153
+ self._chunk_dim = 0
154
+
155
+ def forward(
156
+ self,
157
+ hidden_states: torch.Tensor,
158
+ cond_hidden_states: torch.Tensor,
159
+ encoder_hidden_states: torch.Tensor,
160
+ temb: torch.Tensor,
161
+ cond_temb: torch.Tensor,
162
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
163
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
164
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
165
+ use_cond = cond_hidden_states is not None
166
+
167
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
168
+ if use_cond:
169
+ (
170
+ norm_cond_hidden_states,
171
+ cond_gate_msa,
172
+ cond_shift_mlp,
173
+ cond_scale_mlp,
174
+ cond_gate_mlp,
175
+ ) = self.norm1(cond_hidden_states, emb=cond_temb)
176
+ norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
177
+
178
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
179
+ encoder_hidden_states, emb=temb
180
+ )
181
+ joint_attention_kwargs = joint_attention_kwargs or {}
182
+ # Attention.
183
+ if use_cond:
184
+ attention_outputs = self.attn(
185
+ hidden_states=norm_hidden_states,
186
+ encoder_hidden_states=norm_encoder_hidden_states,
187
+ image_rotary_emb=image_rotary_emb,
188
+ use_cond=use_cond,
189
+ **joint_attention_kwargs,
190
+ )
191
+ else:
192
+ attention_outputs = self.attn(
193
+ hidden_states=norm_hidden_states,
194
+ encoder_hidden_states=norm_encoder_hidden_states,
195
+ image_rotary_emb=image_rotary_emb,
196
+ **joint_attention_kwargs,
197
+ )
198
+
199
+ attn_output, context_attn_output = attention_outputs[:2]
200
+ cond_attn_output = attention_outputs[2] if use_cond else None
201
+
202
+ # Process attention outputs for the `hidden_states`.
203
+ attn_output = gate_msa.unsqueeze(1) * attn_output
204
+ hidden_states = hidden_states + attn_output
205
+
206
+ if use_cond:
207
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
208
+ cond_hidden_states = cond_hidden_states + cond_attn_output
209
+
210
+ norm_hidden_states = self.norm2(hidden_states)
211
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
212
+
213
+ if use_cond:
214
+ norm_cond_hidden_states = self.norm2(cond_hidden_states)
215
+ norm_cond_hidden_states = (
216
+ norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
217
+ + cond_shift_mlp[:, None]
218
+ )
219
+
220
+ ff_output = self.ff(norm_hidden_states)
221
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
222
+ hidden_states = hidden_states + ff_output
223
+
224
+ if use_cond:
225
+ cond_ff_output = self.ff(norm_cond_hidden_states)
226
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
227
+ cond_hidden_states = cond_hidden_states + cond_ff_output
228
+
229
+ # Process attention outputs for the `encoder_hidden_states`.
230
+
231
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
232
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
233
+
234
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
235
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
236
+
237
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
238
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
239
+ if encoder_hidden_states.dtype == torch.float16:
240
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
241
+
242
+ return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
243
+
244
+
245
+ class FluxTransformer2DModel(
246
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
247
+ ):
248
+ _supports_gradient_checkpointing = True
249
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
250
+
251
+ @register_to_config
252
+ def __init__(
253
+ self,
254
+ patch_size: int = 1,
255
+ in_channels: int = 64,
256
+ out_channels: Optional[int] = None,
257
+ num_layers: int = 19,
258
+ num_single_layers: int = 38,
259
+ attention_head_dim: int = 128,
260
+ num_attention_heads: int = 24,
261
+ joint_attention_dim: int = 4096,
262
+ pooled_projection_dim: int = 768,
263
+ guidance_embeds: bool = False,
264
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
265
+ ):
266
+ super().__init__()
267
+ self.out_channels = out_channels or in_channels
268
+ self.inner_dim = num_attention_heads * attention_head_dim
269
+
270
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
271
+
272
+ text_time_guidance_cls = (
273
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
274
+ )
275
+ self.time_text_embed = text_time_guidance_cls(
276
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
277
+ )
278
+
279
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
280
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
281
+
282
+ self.transformer_blocks = nn.ModuleList(
283
+ [
284
+ FluxTransformerBlock(
285
+ dim=self.inner_dim,
286
+ num_attention_heads=num_attention_heads,
287
+ attention_head_dim=attention_head_dim,
288
+ )
289
+ for _ in range(num_layers)
290
+ ]
291
+ )
292
+
293
+ self.single_transformer_blocks = nn.ModuleList(
294
+ [
295
+ FluxSingleTransformerBlock(
296
+ dim=self.inner_dim,
297
+ num_attention_heads=num_attention_heads,
298
+ attention_head_dim=attention_head_dim,
299
+ )
300
+ for _ in range(num_single_layers)
301
+ ]
302
+ )
303
+
304
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
305
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
306
+
307
+ self.gradient_checkpointing = False
308
+
309
+ @property
310
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
311
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
312
+ r"""
313
+ Returns:
314
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
315
+ indexed by its weight name.
316
+ """
317
+ # set recursively
318
+ processors = {}
319
+
320
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
321
+ if hasattr(module, "get_processor"):
322
+ processors[f"{name}.processor"] = module.get_processor()
323
+
324
+ for sub_name, child in module.named_children():
325
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
326
+
327
+ return processors
328
+
329
+ for name, module in self.named_children():
330
+ fn_recursive_add_processors(name, module, processors)
331
+
332
+ return processors
333
+
334
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
335
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
336
+ r"""
337
+ Sets the attention processor to use to compute attention.
338
+
339
+ Parameters:
340
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
341
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
342
+ for **all** `Attention` layers.
343
+
344
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
345
+ processor. This is strongly recommended when setting trainable attention processors.
346
+
347
+ """
348
+ count = len(self.attn_processors.keys())
349
+
350
+ if isinstance(processor, dict) and len(processor) != count:
351
+ raise ValueError(
352
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
353
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
354
+ )
355
+
356
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
357
+ if hasattr(module, "set_processor"):
358
+ if not isinstance(processor, dict):
359
+ module.set_processor(processor)
360
+ else:
361
+ module.set_processor(processor.pop(f"{name}.processor"))
362
+
363
+ for sub_name, child in module.named_children():
364
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
365
+
366
+ # Make a copy of the processor dictionary to avoid destructive changes to the original.
367
+ if isinstance(processor, dict):
368
+ processor = processor.copy()
369
+
370
+ for name, module in self.named_children():
371
+ fn_recursive_attn_processor(name, module, processor)
372
+
373
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
374
+ def fuse_qkv_projections(self):
375
+ """
376
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
377
+ are fused. For cross-attention modules, key and value projection matrices are fused.
378
+
379
+ <Tip warning={true}>
380
+
381
+ This API is 🧪 experimental.
382
+
383
+ </Tip>
384
+ """
385
+ self.original_attn_processors = None
386
+
387
+ for _, attn_processor in self.attn_processors.items():
388
+ if "Added" in str(attn_processor.__class__.__name__):
389
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
390
+
391
+ self.original_attn_processors = self.attn_processors
392
+
393
+ for module in self.modules():
394
+ if isinstance(module, Attention):
395
+ module.fuse_projections(fuse=True)
396
+
397
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
398
+
399
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
400
+ def unfuse_qkv_projections(self):
401
+ """Disables the fused QKV projection if enabled.
402
+
403
+ <Tip warning={true}>
404
+
405
+ This API is 🧪 experimental.
406
+
407
+ </Tip>
408
+
409
+ """
410
+ if self.original_attn_processors is not None:
411
+ self.set_attn_processor(self.original_attn_processors)
412
+
413
+ def _set_gradient_checkpointing(self, module, value=False):
414
+ if hasattr(module, "gradient_checkpointing"):
415
+ module.gradient_checkpointing = value
416
+
417
+ def forward(
418
+ self,
419
+ hidden_states: torch.Tensor,
420
+ cond_hidden_states: torch.Tensor = None,
421
+ encoder_hidden_states: torch.Tensor = None,
422
+ pooled_projections: torch.Tensor = None,
423
+ timestep: torch.LongTensor = None,
424
+ img_ids: torch.Tensor = None,
425
+ txt_ids: torch.Tensor = None,
426
+ guidance: torch.Tensor = None,
427
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
428
+ controlnet_block_samples=None,
429
+ controlnet_single_block_samples=None,
430
+ return_dict: bool = True,
431
+ controlnet_blocks_repeat: bool = False,
432
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
433
+ if cond_hidden_states is not None:
434
+ use_condition = True
435
+ else:
436
+ use_condition = False
437
+
438
+ if joint_attention_kwargs is not None:
439
+ joint_attention_kwargs = joint_attention_kwargs.copy()
440
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
441
+ else:
442
+ lora_scale = 1.0
443
+
444
+ if USE_PEFT_BACKEND:
445
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
446
+ scale_lora_layers(self, lora_scale)
447
+ else:
448
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
449
+ logger.warning(
450
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
451
+ )
452
+ hidden_states = self.x_embedder(hidden_states)
453
+ if cond_hidden_states is not None:
454
+ if cond_hidden_states.shape[-1] == self.x_embedder.in_features:
455
+ cond_hidden_states = self.x_embedder(cond_hidden_states)
456
+ elif cond_hidden_states.shape[-1] == 64:
457
+                # Use only the first 64 input columns of the weight matrix (plus the bias)
458
+ weight = self.x_embedder.weight[:, :64] # [inner_dim, 64]
459
+ bias = self.x_embedder.bias
460
+ cond_hidden_states = torch.nn.functional.linear(cond_hidden_states, weight, bias)
461
+
462
+ timestep = timestep.to(hidden_states.dtype) * 1000
463
+ if guidance is not None:
464
+ guidance = guidance.to(hidden_states.dtype) * 1000
465
+ else:
466
+ guidance = None
467
+
468
+ temb = (
469
+ self.time_text_embed(timestep, pooled_projections)
470
+ if guidance is None
471
+ else self.time_text_embed(timestep, guidance, pooled_projections)
472
+ )
473
+
474
+ cond_temb = (
475
+ self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
476
+ if guidance is None
477
+ else self.time_text_embed(
478
+ torch.ones_like(timestep) * 0, guidance, pooled_projections
479
+ )
480
+ )
481
+
482
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
483
+
484
+
485
+ if txt_ids.ndim == 3:
486
+ logger.warning(
487
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
488
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
489
+ )
490
+ txt_ids = txt_ids[0]
491
+ if img_ids.ndim == 3:
492
+ logger.warning(
493
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
494
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
495
+ )
496
+ img_ids = img_ids[0]
497
+
498
+ ids = torch.cat((txt_ids, img_ids), dim=0)
499
+ image_rotary_emb = self.pos_embed(ids)
500
+
501
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
502
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
503
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
504
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
505
+
506
+ for index_block, block in enumerate(self.transformer_blocks):
507
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
508
+
509
+ def create_custom_forward(module, return_dict=None):
510
+ def custom_forward(*inputs):
511
+ if return_dict is not None:
512
+ return module(*inputs, return_dict=return_dict)
513
+ else:
514
+ return module(*inputs)
515
+
516
+ return custom_forward
517
+
518
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
519
+                encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
520
+                    create_custom_forward(block),
521
+                    hidden_states,
522
+                    cond_hidden_states if use_condition else None,
523
+                    encoder_hidden_states,
524
+                    temb,
525
+                    cond_temb if use_condition else None,
526
+                    image_rotary_emb,
527
+                    **ckpt_kwargs,
528
+                )
529
+
530
+ else:
531
+ encoder_hidden_states, hidden_states, cond_hidden_states = block(
532
+ hidden_states=hidden_states,
533
+ encoder_hidden_states=encoder_hidden_states,
534
+ cond_hidden_states=cond_hidden_states if use_condition else None,
535
+ temb=temb,
536
+ cond_temb=cond_temb if use_condition else None,
537
+ image_rotary_emb=image_rotary_emb,
538
+ joint_attention_kwargs=joint_attention_kwargs,
539
+ )
540
+
541
+ # controlnet residual
542
+ if controlnet_block_samples is not None:
543
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
544
+ interval_control = int(np.ceil(interval_control))
545
+ # For Xlabs ControlNet.
546
+ if controlnet_blocks_repeat:
547
+ hidden_states = (
548
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
549
+ )
550
+ else:
551
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
552
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
553
+
554
+ for index_block, block in enumerate(self.single_transformer_blocks):
555
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
556
+
557
+ def create_custom_forward(module, return_dict=None):
558
+ def custom_forward(*inputs):
559
+ if return_dict is not None:
560
+ return module(*inputs, return_dict=return_dict)
561
+ else:
562
+ return module(*inputs)
563
+
564
+ return custom_forward
565
+
566
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
567
+                hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
568
+                    create_custom_forward(block),
569
+                    hidden_states,
570
+                    cond_hidden_states if use_condition else None,
571
+                    temb,
572
+                    cond_temb if use_condition else None,
573
+                    image_rotary_emb,
574
+                    **ckpt_kwargs,
575
+                )
576
+
577
+ else:
578
+ hidden_states, cond_hidden_states = block(
579
+ hidden_states=hidden_states,
580
+ cond_hidden_states=cond_hidden_states if use_condition else None,
581
+ temb=temb,
582
+ cond_temb=cond_temb if use_condition else None,
583
+ image_rotary_emb=image_rotary_emb,
584
+ joint_attention_kwargs=joint_attention_kwargs,
585
+ )
586
+
587
+ # controlnet residual
588
+ if controlnet_single_block_samples is not None:
589
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
590
+ interval_control = int(np.ceil(interval_control))
591
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
592
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
593
+ + controlnet_single_block_samples[index_block // interval_control]
594
+ )
595
+
596
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
597
+
598
+ hidden_states = self.norm_out(hidden_states, temb)
599
+ output = self.proj_out(hidden_states)
600
+
601
+ if USE_PEFT_BACKEND:
602
+ # remove `lora_scale` from each PEFT layer
603
+ unscale_lora_layers(self, lora_scale)
604
+
605
+ if not return_dict:
606
+ return (output,)
607
+
608
+ return Transformer2DModelOutput(sample=output)
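
A small smoke-test sketch of the modified transformer, run without the condition branch; the tiny config values are arbitrary and only illustrate the I/O contract (packed image tokens in, a per-token prediction out). Passing `cond_hidden_states` additionally requires attention processors that accept the `use_cond` keyword and return a separate condition output (the K/V-caching processors this repo installs via `set_attn_processor`); the stock `FluxAttnProcessor2_0` does not.

import torch
from src.transformer_flux import FluxTransformer2DModel  # assumes the src package is importable

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=16,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,       # inner_dim = 2 * 8 = 16
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=(4, 2, 2),    # entries must be even and sum to attention_head_dim
)

img_seq, txt_seq = 64, 8
out = model(
    hidden_states=torch.randn(1, img_seq, 16),          # packed image latents
    encoder_hidden_states=torch.randn(1, txt_seq, 32),  # text encoder features
    pooled_projections=torch.randn(1, 16),              # pooled text embedding
    timestep=torch.tensor([1.0]),
    img_ids=torch.zeros(img_seq, 3),
    txt_ids=torch.zeros(txt_seq, 3),
)
print(out.sample.shape)  # torch.Size([1, 64, 16]) -- prediction for the image tokens only
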
train/default_config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ main_process_port: 14121
5
+ downcast_bf16: 'no'
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: fp16
10
+ num_machines: 1
11
+ num_processes: 8
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
train/src/__init__.py ADDED
File without changes
train/src/condition/edge_extraction.py ADDED
@@ -0,0 +1,356 @@
1
+ import warnings
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ import os
9
+
10
+ from einops import rearrange
11
+
12
+ from .util import HWC3, nms, safe_step, resize_image_with_pad, common_input_validate, get_intensity_mask, combine_layers
13
+
14
+ from .pidi import pidinet
15
+ from .ted import TED
16
+ from .lineart import Generator as LineartGenerator
17
+ from .informative_drawing import Generator
18
+ from .hed import ControlNetHED_Apache2
19
+
20
+ from pathlib import Path
21
+
22
+ from skimage import morphology
23
+ import argparse
24
+ from tqdm import tqdm
25
+
26
+
27
+ PREPROCESSORS_ROOT = os.getenv("PREPROCESSORS_ROOT", os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), "models/preprocessors"))
28
+
29
+
30
+ class HEDDetector:
31
+ def __init__(self, netNetwork):
32
+ self.netNetwork = netNetwork
33
+ self.device = "cpu"
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, filename="ControlNetHED.pth"):
37
+ model_path = os.path.join(PREPROCESSORS_ROOT, filename)
38
+
39
+ netNetwork = ControlNetHED_Apache2()
40
+ netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
41
+ netNetwork.float().eval()
42
+
43
+ return cls(netNetwork)
44
+
45
+ def to(self, device):
46
+ self.netNetwork.to(device)
47
+ self.device = device
48
+ return self
49
+
50
+
51
+ def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, upscale_method="INTER_CUBIC", **kwargs):
52
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
53
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
54
+
55
+ assert input_image.ndim == 3
56
+ H, W, C = input_image.shape
57
+ with torch.no_grad():
58
+ image_hed = torch.from_numpy(input_image).float().to(self.device)
59
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
60
+ edges = self.netNetwork(image_hed)
61
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
62
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
63
+ edges = np.stack(edges, axis=2)
64
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
65
+ if safe:
66
+ edge = safe_step(edge)
67
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
68
+
69
+ detected_map = edge
70
+
71
+ if scribble:
72
+ detected_map = nms(detected_map, 127, 3.0)
73
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
74
+ detected_map[detected_map > 4] = 255
75
+ detected_map[detected_map < 255] = 0
76
+
77
+ detected_map = HWC3(remove_pad(detected_map))
78
+
79
+ if output_type == "pil":
80
+ detected_map = Image.fromarray(detected_map)
81
+
82
+ return detected_map
83
+
84
+
85
+ class CannyDetector:
86
+ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
87
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
88
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
89
+ detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
90
+ detected_map = HWC3(remove_pad(detected_map))
91
+
92
+ if output_type == "pil":
93
+ detected_map = Image.fromarray(detected_map)
94
+
95
+ return detected_map
96
+
97
+ class PidiNetDetector:
98
+ def __init__(self, netNetwork):
99
+ self.netNetwork = netNetwork
100
+ self.device = "cpu"
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, filename="table5_pidinet.pth"):
104
+ model_path = os.path.join(PREPROCESSORS_ROOT, filename)
105
+
106
+ netNetwork = pidinet()
107
+ netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()})
108
+ netNetwork.eval()
109
+
110
+ return cls(netNetwork)
111
+
112
+ def to(self, device):
113
+ self.netNetwork.to(device)
114
+ self.device = device
115
+ return self
116
+
117
+ def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, apply_filter=False, upscale_method="INTER_CUBIC", **kwargs):
118
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
119
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
120
+
121
+ detected_map = detected_map[:, :, ::-1].copy()
122
+ with torch.no_grad():
123
+ image_pidi = torch.from_numpy(detected_map).float().to(self.device)
124
+ image_pidi = image_pidi / 255.0
125
+ image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
126
+ edge = self.netNetwork(image_pidi)[-1]
127
+ edge = edge.cpu().numpy()
128
+ if apply_filter:
129
+ edge = edge > 0.5
130
+ if safe:
131
+ edge = safe_step(edge)
132
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
133
+
134
+ detected_map = edge[0, 0]
135
+
136
+ if scribble:
137
+ detected_map = nms(detected_map, 127, 3.0)
138
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
139
+ detected_map[detected_map > 4] = 255
140
+ detected_map[detected_map < 255] = 0
141
+
142
+ detected_map = HWC3(remove_pad(detected_map))
143
+
144
+ if output_type == "pil":
145
+ detected_map = Image.fromarray(detected_map)
146
+
147
+ return detected_map
148
+
149
+ class TEDDetector:
150
+ def __init__(self, model):
151
+ self.model = model
152
+ self.device = "cpu"
153
+
154
+ @classmethod
155
+ def from_pretrained(cls, filename="7_model.pth"):
156
+ model_path = os.path.join(PREPROCESSORS_ROOT, filename)
157
+ model = TED()
158
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
159
+ model.eval()
160
+ return cls(model)
161
+
162
+ def to(self, device):
163
+ self.model.to(device)
164
+ self.device = device
165
+ return self
166
+
167
+ def __call__(self, input_image, detect_resolution=512, safe_steps=2, upscale_method="INTER_CUBIC", output_type=None, **kwargs):
168
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
169
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
170
+
171
+ H, W, _ = input_image.shape
172
+ with torch.no_grad():
173
+ image_teed = torch.from_numpy(input_image.copy()).float().to(self.device)
174
+ image_teed = rearrange(image_teed, 'h w c -> 1 c h w')
175
+ edges = self.model(image_teed)
176
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
177
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
178
+ edges = np.stack(edges, axis=2)
179
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
180
+ if safe_steps != 0:
181
+ edge = safe_step(edge, safe_steps)
182
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
183
+
184
+ detected_map = remove_pad(HWC3(edge))
185
+ if output_type == "pil":
186
+ detected_map = Image.fromarray(detected_map[..., :3])
187
+
188
+ return detected_map
189
+
190
+ class LineartStandardDetector:
191
+ def __call__(self, input_image=None, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
192
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
193
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
194
+
195
+ x = input_image.astype(np.float32)
196
+ g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
197
+ intensity = np.min(g - x, axis=2).clip(0, 255)
198
+ intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
199
+ intensity *= 127
200
+ detected_map = intensity.clip(0, 255).astype(np.uint8)
201
+
202
+ detected_map = HWC3(remove_pad(detected_map))
203
+ if output_type == "pil":
204
+ detected_map = Image.fromarray(detected_map)
205
+ return detected_map
206
+
207
+ class AnyLinePreprocessor:
208
+ def __init__(self, mteed_model, lineart_standard_detector):
209
+ self.device = "cpu"
210
+ self.mteed_model = mteed_model
211
+ self.lineart_standard_detector = lineart_standard_detector
212
+
213
+ @classmethod
214
+ def from_pretrained(cls, mteed_filename="MTEED.pth"):
215
+ mteed_model = TEDDetector.from_pretrained(filename=mteed_filename)
216
+ lineart_standard_detector = LineartStandardDetector()
217
+ return cls(mteed_model, lineart_standard_detector)
218
+
219
+ def to(self, device):
220
+ self.mteed_model.to(device)
221
+ self.device = device
222
+ return self
223
+
224
+ def __call__(self, image, resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
225
+ # Process the image with MTEED model
226
+ mteed_result = self.mteed_model(image, detect_resolution=resolution)
227
+
228
+ # Process the image with the lineart standard preprocessor
229
+ lineart_result = self.lineart_standard_detector(image, guassian_sigma=2, intensity_threshold=3, resolution=resolution)
230
+
231
+ _lineart_result = get_intensity_mask(lineart_result, lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
232
+ _cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
233
+ _lineart_result = _lineart_result * _cleaned
234
+ _mteed_result = mteed_result
235
+
236
+ result = combine_layers(_mteed_result, _lineart_result)
237
+ # print(result.shape)
238
+ return result
239
+
240
+ class LineartDetector:
241
+ def __init__(self, model, coarse_model):
242
+ self.model = model
243
+ self.model_coarse = coarse_model
244
+ self.device = "cpu"
245
+
246
+ @classmethod
247
+ def from_pretrained(cls, filename="sk_model.pth", coarse_filename="sk_model2.pth"):
248
+ model_path = os.path.join(PREPROCESSORS_ROOT, filename)
249
+ coarse_model_path = os.path.join(PREPROCESSORS_ROOT, coarse_filename)
250
+
251
+ model = LineartGenerator(3, 1, 3)
252
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
253
+ model.eval()
254
+
255
+ coarse_model = LineartGenerator(3, 1, 3)
256
+ coarse_model.load_state_dict(torch.load(coarse_model_path, map_location="cpu"))
257
+ coarse_model.eval()
258
+
259
+ return cls(model, coarse_model)
260
+
261
+ def to(self, device):
262
+ self.model.to(device)
263
+ self.model_coarse.to(device)
264
+ self.device = device
265
+ return self
266
+
267
+ def __call__(self, input_image, coarse=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
268
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
269
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
270
+
271
+ model = self.model_coarse if coarse else self.model
272
+ assert detected_map.ndim == 3
273
+ with torch.no_grad():
274
+ image = torch.from_numpy(detected_map).float().to(self.device)
275
+ image = image / 255.0
276
+ image = rearrange(image, 'h w c -> 1 c h w')
277
+ line = model(image)[0][0]
278
+
279
+ line = line.cpu().numpy()
280
+ line = (line * 255.0).clip(0, 255).astype(np.uint8)
281
+
282
+ detected_map = HWC3(line)
283
+ detected_map = remove_pad(255 - detected_map)
284
+
285
+ if output_type == "pil":
286
+ detected_map = Image.fromarray(detected_map)
287
+
288
+ return detected_map
289
+
290
+
291
+ class InformativeDetector:
292
+ def __init__(self, anime_model, contour_model):
293
+ self.anime_model = anime_model
294
+ self.contour_model = contour_model
295
+ self.device = "cpu"
296
+
297
+ @classmethod
298
+ def from_pretrained(cls, anime_filename="anime_style.pth", contour_filename="contour_style.pth"):
299
+ anime_model_path = os.path.join(PREPROCESSORS_ROOT, anime_filename)
300
+ contour_model_path = os.path.join(PREPROCESSORS_ROOT, contour_filename)
301
+
302
+        # Create the two Generator models (anime style and contour style)
303
+ anime_model = Generator(3, 1, 3) # input_nc=3, output_nc=1, n_blocks=3
304
+ anime_model.load_state_dict(torch.load(anime_model_path, map_location="cpu"))
305
+ anime_model.eval()
306
+
307
+ contour_model = Generator(3, 1, 3) # input_nc=3, output_nc=1, n_blocks=3
308
+ contour_model.load_state_dict(torch.load(contour_model_path, map_location="cpu"))
309
+ contour_model.eval()
310
+
311
+ return cls(anime_model, contour_model)
312
+
313
+ def to(self, device):
314
+ self.anime_model.to(device)
315
+ self.contour_model.to(device)
316
+ self.device = device
317
+ return self
318
+
319
+ def __call__(self, input_image, style="anime", detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
320
+ """
321
+        Extract a sketch from the input image.
322
+
323
+        Args:
324
+            input_image: input image
325
+            style: "anime" or "contour"
326
+            detect_resolution: detection resolution
327
+            output_type: output type
328
+            upscale_method: upsampling method
329
+ """
330
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
331
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
332
+
333
+        # Select the model for the requested style
334
+ model = self.anime_model if style == "anime" else self.contour_model
335
+
336
+ assert detected_map.ndim == 3
337
+ with torch.no_grad():
338
+ image = torch.from_numpy(detected_map).float().to(self.device)
339
+ image = image / 255.0
340
+            # Rearrange dimensions (h, w, c) -> (1, c, h, w)
341
+ image = image.permute(2, 0, 1).unsqueeze(0)
342
+
343
+            # Generate the sketch
344
+ sketch = model(image)
345
+            sketch = sketch[0][0]  # take the first channel of the first batch element
346
+
347
+ sketch = sketch.cpu().numpy()
348
+ sketch = (sketch * 255.0).clip(0, 255).astype(np.uint8)
349
+
350
+ detected_map = HWC3(sketch)
351
+        detected_map = remove_pad(255 - detected_map)  # invert the colors
352
+
353
+ if output_type == "pil":
354
+ detected_map = Image.fromarray(detected_map)
355
+
356
+ return detected_map
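
A usage sketch for the detectors above. Checkpoint filenames follow the defaults of each `from_pretrained()` and are expected under `PREPROCESSORS_ROOT` (overridable via the environment variable of the same name); the input path, the CUDA device, and the import path are assumptions.

import numpy as np
from PIL import Image
from train.src.condition.edge_extraction import AnyLinePreprocessor, CannyDetector, HEDDetector

image = np.array(Image.open("photo.png").convert("RGB"))

# Canny needs no weights; the thresholds control edge sensitivity.
canny_map = CannyDetector()(image, low_threshold=100, high_threshold=200, output_type="pil")

# HED loads ControlNetHED.pth and can post-process soft edges into a scribble map.
hed = HEDDetector.from_pretrained().to("cuda")
scribble = hed(image, detect_resolution=512, scribble=True, output_type="pil")

# AnyLine fuses the MTEED edge map with the cleaned-up standard lineart map.
anyline = AnyLinePreprocessor.from_pretrained().to("cuda")
line_map = anyline(image, resolution=512)  # numpy array
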
train/src/condition/hed.py ADDED
@@ -0,0 +1,56 @@
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import warnings
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ from einops import rearrange
15
+ from PIL import Image
16
+
17
+ from .util import HWC3, nms, resize_image_with_pad, safe_step, common_input_validate
18
+
19
+
20
+ class DoubleConvBlock(torch.nn.Module):
21
+ def __init__(self, input_channel, output_channel, layer_number):
22
+ super().__init__()
23
+ self.convs = torch.nn.Sequential()
24
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
25
+ for i in range(1, layer_number):
26
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
27
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
28
+
29
+ def __call__(self, x, down_sampling=False):
30
+ h = x
31
+ if down_sampling:
32
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
33
+ for conv in self.convs:
34
+ h = conv(h)
35
+ h = torch.nn.functional.relu(h)
36
+ return h, self.projection(h)
37
+
38
+
39
+ class ControlNetHED_Apache2(torch.nn.Module):
40
+ def __init__(self):
41
+ super().__init__()
42
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
43
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
44
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
45
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
46
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
47
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
48
+
49
+ def __call__(self, x):
50
+ h = x - self.norm
51
+ h, projection1 = self.block1(h)
52
+ h, projection2 = self.block2(h, down_sampling=True)
53
+ h, projection3 = self.block3(h, down_sampling=True)
54
+ h, projection4 = self.block4(h, down_sampling=True)
55
+ h, projection5 = self.block5(h, down_sampling=True)
56
+ return projection1, projection2, projection3, projection4, projection5
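
A quick shape check (a sketch) of the five side outputs: `block1` keeps the input resolution and every later block halves it with max-pooling before its convolutions. `HEDDetector` in edge_extraction.py resizes these maps back to the input size, averages them, and passes the mean through a sigmoid.

import torch
from train.src.condition.hed import ControlNetHED_Apache2  # assumes the train/src packages are importable

net = ControlNetHED_Apache2().eval()
x = torch.zeros(1, 3, 512, 512)  # raw RGB pixels in [0, 255]; HEDDetector feeds them unnormalized
with torch.no_grad():
    p1, p2, p3, p4, p5 = net(x)
print([tuple(p.shape[-2:]) for p in (p1, p2, p3, p4, p5)])
# [(512, 512), (256, 256), (128, 128), (64, 64), (32, 32)] -- each projection is single-channel
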
train/src/condition/informative_drawing.py ADDED
@@ -0,0 +1,279 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import functools
5
+ from torchvision import models
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ import math
9
+
10
+ norm_layer = nn.InstanceNorm2d
11
+
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, in_features):
14
+ super(ResidualBlock, self).__init__()
15
+
16
+ conv_block = [ nn.ReflectionPad2d(1),
17
+ nn.Conv2d(in_features, in_features, 3),
18
+ norm_layer(in_features),
19
+ nn.ReLU(inplace=True),
20
+ nn.ReflectionPad2d(1),
21
+ nn.Conv2d(in_features, in_features, 3),
22
+ norm_layer(in_features)
23
+ ]
24
+
25
+ self.conv_block = nn.Sequential(*conv_block)
26
+
27
+ def forward(self, x):
28
+ return x + self.conv_block(x)
29
+
30
+
31
+ class Generator(nn.Module):
32
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
33
+ super(Generator, self).__init__()
34
+
35
+ # Initial convolution block
36
+ model0 = [ nn.ReflectionPad2d(3),
37
+ nn.Conv2d(input_nc, 64, 7),
38
+ norm_layer(64),
39
+ nn.ReLU(inplace=True) ]
40
+ self.model0 = nn.Sequential(*model0)
41
+
42
+ # Downsampling
43
+ model1 = []
44
+ in_features = 64
45
+ out_features = in_features*2
46
+ for _ in range(2):
47
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
48
+ norm_layer(out_features),
49
+ nn.ReLU(inplace=True) ]
50
+ in_features = out_features
51
+ out_features = in_features*2
52
+ self.model1 = nn.Sequential(*model1)
53
+
54
+ model2 = []
55
+ # Residual blocks
56
+ for _ in range(n_residual_blocks):
57
+ model2 += [ResidualBlock(in_features)]
58
+ self.model2 = nn.Sequential(*model2)
59
+
60
+ # Upsampling
61
+ model3 = []
62
+ out_features = in_features//2
63
+ for _ in range(2):
64
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
65
+ norm_layer(out_features),
66
+ nn.ReLU(inplace=True) ]
67
+ in_features = out_features
68
+ out_features = in_features//2
69
+ self.model3 = nn.Sequential(*model3)
70
+
71
+ # Output layer
72
+ model4 = [ nn.ReflectionPad2d(3),
73
+ nn.Conv2d(64, output_nc, 7)]
74
+ if sigmoid:
75
+ model4 += [nn.Sigmoid()]
76
+
77
+ self.model4 = nn.Sequential(*model4)
78
+
79
+ def forward(self, x, cond=None):
80
+ out = self.model0(x)
81
+ out = self.model1(out)
82
+ out = self.model2(out)
83
+ out = self.model3(out)
84
+ out = self.model4(out)
85
+
86
+ return out
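+
+ # Usage sketch (assuming a 3-channel input, a 1-channel line-drawing output and a
+ # 256x256 image; n_residual_blocks keeps its default of 9):
+ #   net = Generator(input_nc=3, output_nc=1)
+ #   lines = net(torch.rand(1, 3, 256, 256))  # -> (1, 1, 256, 256), sigmoid output in [0, 1]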
87
+
88
+ # Define a resnet block
89
+ class ResnetBlock(nn.Module):
90
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
91
+ super(ResnetBlock, self).__init__()
92
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
93
+
94
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
95
+ conv_block = []
96
+ p = 0
97
+ if padding_type == 'reflect':
98
+ conv_block += [nn.ReflectionPad2d(1)]
99
+ elif padding_type == 'replicate':
100
+ conv_block += [nn.ReplicationPad2d(1)]
101
+ elif padding_type == 'zero':
102
+ p = 1
103
+ else:
104
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
105
+
106
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
107
+ norm_layer(dim),
108
+ activation]
109
+ if use_dropout:
110
+ conv_block += [nn.Dropout(0.5)]
111
+
112
+ p = 0
113
+ if padding_type == 'reflect':
114
+ conv_block += [nn.ReflectionPad2d(1)]
115
+ elif padding_type == 'replicate':
116
+ conv_block += [nn.ReplicationPad2d(1)]
117
+ elif padding_type == 'zero':
118
+ p = 1
119
+ else:
120
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
121
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
122
+ norm_layer(dim)]
123
+
124
+ return nn.Sequential(*conv_block)
125
+
126
+ def forward(self, x):
127
+ out = x + self.conv_block(x)
128
+ return out
129
+
130
+ class GlobalGenerator2(nn.Module):
131
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
132
+ padding_type='reflect', use_sig=False, n_UPsampling=0):
133
+ assert(n_blocks >= 0)
134
+ super(GlobalGenerator2, self).__init__()
135
+ activation = nn.ReLU(True)
136
+
137
+ mult = 8
138
+ model = [nn.ReflectionPad2d(4), nn.Conv2d(input_nc, ngf*mult, kernel_size=7, padding=0), norm_layer(ngf*mult), activation]
139
+
140
+ ### upsample spatially while halving channels (labelled "downsample" in the original code)
141
+ for i in range(n_downsampling):
142
+ model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=4, stride=2, padding=1),
143
+ norm_layer(ngf * mult // 2), activation]
144
+ mult = mult // 2
145
+
146
+ if n_UPsampling <= 0:
147
+ n_UPsampling = n_downsampling
148
+
149
+ ### resnet blocks
150
+ for i in range(n_blocks):
151
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
152
+
153
+ ### upsample
154
+ for i in range(n_UPsampling):
155
+ next_mult = mult // 2
156
+ if next_mult == 0:
157
+ next_mult = 1
158
+ mult = 1
159
+
160
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * next_mult), kernel_size=3, stride=2, padding=1, output_padding=1),
161
+ norm_layer(int(ngf * next_mult)), activation]
162
+ mult = next_mult
163
+
164
+ if use_sig:
165
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
166
+ else:
167
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
168
+ self.model = nn.Sequential(*model)
169
+
170
+ def forward(self, input, cond=None):
171
+ return self.model(input)
172
+
173
+
174
+ class InceptionV3(nn.Module): #avg pool
175
+ def __init__(self, num_classes, isTrain, use_aux=True, pretrain=False, freeze=True, every_feat=False):
176
+ super(InceptionV3, self).__init__()
177
+ """ Inception v3 expects (299,299) sized images for training and has auxiliary output
178
+ """
179
+
180
+ self.every_feat = every_feat
181
+
182
+ self.model_ft = models.inception_v3(pretrained=pretrain)
183
+ stop = 0
184
+ if freeze and pretrain:
185
+ for child in self.model_ft.children():
186
+ if stop < 17:
187
+ for param in child.parameters():
188
+ param.requires_grad = False
189
+ stop += 1
190
+
191
+ num_ftrs = self.model_ft.AuxLogits.fc.in_features #768
192
+ self.model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
193
+
194
+ # Handle the primary net
195
+ num_ftrs = self.model_ft.fc.in_features #2048
196
+ self.model_ft.fc = nn.Linear(num_ftrs,num_classes)
197
+
198
+ self.model_ft.input_size = 299
199
+
200
+ self.isTrain = isTrain
201
+ self.use_aux = use_aux
202
+
203
+ if self.isTrain:
204
+ self.model_ft.train()
205
+ else:
206
+ self.model_ft.eval()
207
+
208
+
209
+ def forward(self, x, cond=None, catch_gates=False):
210
+ # N x 3 x 299 x 299
211
+ x = self.model_ft.Conv2d_1a_3x3(x)
212
+
213
+ # N x 32 x 149 x 149
214
+ x = self.model_ft.Conv2d_2a_3x3(x)
215
+ # N x 32 x 147 x 147
216
+ x = self.model_ft.Conv2d_2b_3x3(x)
217
+ # N x 64 x 147 x 147
218
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
219
+ # N x 64 x 73 x 73
220
+ x = self.model_ft.Conv2d_3b_1x1(x)
221
+ # N x 80 x 73 x 73
222
+ x = self.model_ft.Conv2d_4a_3x3(x)
223
+
224
+ # N x 192 x 71 x 71
225
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
226
+ # N x 192 x 35 x 35
227
+ x = self.model_ft.Mixed_5b(x)
228
+ feat1 = x
229
+ # N x 256 x 35 x 35
230
+ x = self.model_ft.Mixed_5c(x)
231
+ feat11 = x
232
+ # N x 288 x 35 x 35
233
+ x = self.model_ft.Mixed_5d(x)
234
+ feat12 = x
235
+ # N x 288 x 35 x 35
236
+ x = self.model_ft.Mixed_6a(x)
237
+ feat2 = x
238
+ # N x 768 x 17 x 17
239
+ x = self.model_ft.Mixed_6b(x)
240
+ feat21 = x
241
+ # N x 768 x 17 x 17
242
+ x = self.model_ft.Mixed_6c(x)
243
+ feat22 = x
244
+ # N x 768 x 17 x 17
245
+ x = self.model_ft.Mixed_6d(x)
246
+ feat23 = x
247
+ # N x 768 x 17 x 17
248
+ x = self.model_ft.Mixed_6e(x)
249
+
250
+ feat3 = x
251
+
252
+ # N x 768 x 17 x 17
253
+ aux_defined = self.isTrain and self.use_aux
254
+ if aux_defined:
255
+ aux = self.model_ft.AuxLogits(x)
256
+ else:
257
+ aux = None
258
+ # N x 768 x 17 x 17
259
+ x = self.model_ft.Mixed_7a(x)
260
+ # N x 1280 x 8 x 8
261
+ x = self.model_ft.Mixed_7b(x)
262
+ # N x 2048 x 8 x 8
263
+ x = self.model_ft.Mixed_7c(x)
264
+ # N x 2048 x 8 x 8
265
+ # Adaptive average pooling
266
+ x = F.adaptive_avg_pool2d(x, (1, 1))
267
+ # N x 2048 x 1 x 1
268
+ feats = F.dropout(x, training=self.isTrain)
269
+ # N x 2048 x 1 x 1
270
+ x = torch.flatten(feats, 1)
271
+ # N x 2048
272
+ x = self.model_ft.fc(x)
273
+ # N x 1000 (num_classes)
274
+
275
+ if self.every_feat:
276
+ # return feat21, feats, x
277
+ return x, feat21
278
+
279
+ return x, aux
train/src/condition/lineart.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import functools
5
+ from torchvision import models
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ import math
9
+
10
+ norm_layer = nn.InstanceNorm2d
11
+
12
+ class ResidualBlock(nn.Module):
13
+ def __init__(self, in_features):
14
+ super(ResidualBlock, self).__init__()
15
+
16
+ conv_block = [ nn.ReflectionPad2d(1),
17
+ nn.Conv2d(in_features, in_features, 3),
18
+ norm_layer(in_features),
19
+ nn.ReLU(inplace=True),
20
+ nn.ReflectionPad2d(1),
21
+ nn.Conv2d(in_features, in_features, 3),
22
+ norm_layer(in_features)
23
+ ]
24
+
25
+ self.conv_block = nn.Sequential(*conv_block)
26
+
27
+ def forward(self, x):
28
+ return x + self.conv_block(x)
29
+
30
+
31
+ class Generator(nn.Module):
32
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
33
+ super(Generator, self).__init__()
34
+
35
+ # Initial convolution block
36
+ model0 = [ nn.ReflectionPad2d(3),
37
+ nn.Conv2d(input_nc, 64, 7),
38
+ norm_layer(64),
39
+ nn.ReLU(inplace=True) ]
40
+ self.model0 = nn.Sequential(*model0)
41
+
42
+ # Downsampling
43
+ model1 = []
44
+ in_features = 64
45
+ out_features = in_features*2
46
+ for _ in range(2):
47
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
48
+ norm_layer(out_features),
49
+ nn.ReLU(inplace=True) ]
50
+ in_features = out_features
51
+ out_features = in_features*2
52
+ self.model1 = nn.Sequential(*model1)
53
+
54
+ model2 = []
55
+ # Residual blocks
56
+ for _ in range(n_residual_blocks):
57
+ model2 += [ResidualBlock(in_features)]
58
+ self.model2 = nn.Sequential(*model2)
59
+
60
+ # Upsampling
61
+ model3 = []
62
+ out_features = in_features//2
63
+ for _ in range(2):
64
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
65
+ norm_layer(out_features),
66
+ nn.ReLU(inplace=True) ]
67
+ in_features = out_features
68
+ out_features = in_features//2
69
+ self.model3 = nn.Sequential(*model3)
70
+
71
+ # Output layer
72
+ model4 = [ nn.ReflectionPad2d(3),
73
+ nn.Conv2d(64, output_nc, 7)]
74
+ if sigmoid:
75
+ model4 += [nn.Sigmoid()]
76
+
77
+ self.model4 = nn.Sequential(*model4)
78
+
79
+ def forward(self, x, cond=None):
80
+ out = self.model0(x)
81
+ out = self.model1(out)
82
+ out = self.model2(out)
83
+ out = self.model3(out)
84
+ out = self.model4(out)
85
+
86
+ return out
train/src/condition/pidi.py ADDED
@@ -0,0 +1,681 @@
1
+ """
2
+ Author: Zhuo Su, Wenzhe Liu
3
+ Date: Feb 18, 2021
4
+ """
5
+
6
+ import math
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
16
+ """Numpy array to tensor.
17
+
18
+ Args:
19
+ imgs (list[ndarray] | ndarray): Input images.
20
+ bgr2rgb (bool): Whether to change bgr to rgb.
21
+ float32 (bool): Whether to change to float32.
22
+
23
+ Returns:
24
+ list[tensor] | tensor: Tensor images. If returned results only have
25
+ one element, just return tensor.
26
+ """
27
+
28
+ def _totensor(img, bgr2rgb, float32):
29
+ if img.shape[2] == 3 and bgr2rgb:
30
+ if img.dtype == 'float64':
31
+ img = img.astype('float32')
32
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
+ img = torch.from_numpy(img.transpose(2, 0, 1))
34
+ if float32:
35
+ img = img.float()
36
+ return img
37
+
38
+ if isinstance(imgs, list):
39
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
40
+ else:
41
+ return _totensor(imgs, bgr2rgb, float32)
42
+
43
+ nets = {
44
+ 'baseline': {
45
+ 'layer0': 'cv',
46
+ 'layer1': 'cv',
47
+ 'layer2': 'cv',
48
+ 'layer3': 'cv',
49
+ 'layer4': 'cv',
50
+ 'layer5': 'cv',
51
+ 'layer6': 'cv',
52
+ 'layer7': 'cv',
53
+ 'layer8': 'cv',
54
+ 'layer9': 'cv',
55
+ 'layer10': 'cv',
56
+ 'layer11': 'cv',
57
+ 'layer12': 'cv',
58
+ 'layer13': 'cv',
59
+ 'layer14': 'cv',
60
+ 'layer15': 'cv',
61
+ },
62
+ 'c-v15': {
63
+ 'layer0': 'cd',
64
+ 'layer1': 'cv',
65
+ 'layer2': 'cv',
66
+ 'layer3': 'cv',
67
+ 'layer4': 'cv',
68
+ 'layer5': 'cv',
69
+ 'layer6': 'cv',
70
+ 'layer7': 'cv',
71
+ 'layer8': 'cv',
72
+ 'layer9': 'cv',
73
+ 'layer10': 'cv',
74
+ 'layer11': 'cv',
75
+ 'layer12': 'cv',
76
+ 'layer13': 'cv',
77
+ 'layer14': 'cv',
78
+ 'layer15': 'cv',
79
+ },
80
+ 'a-v15': {
81
+ 'layer0': 'ad',
82
+ 'layer1': 'cv',
83
+ 'layer2': 'cv',
84
+ 'layer3': 'cv',
85
+ 'layer4': 'cv',
86
+ 'layer5': 'cv',
87
+ 'layer6': 'cv',
88
+ 'layer7': 'cv',
89
+ 'layer8': 'cv',
90
+ 'layer9': 'cv',
91
+ 'layer10': 'cv',
92
+ 'layer11': 'cv',
93
+ 'layer12': 'cv',
94
+ 'layer13': 'cv',
95
+ 'layer14': 'cv',
96
+ 'layer15': 'cv',
97
+ },
98
+ 'r-v15': {
99
+ 'layer0': 'rd',
100
+ 'layer1': 'cv',
101
+ 'layer2': 'cv',
102
+ 'layer3': 'cv',
103
+ 'layer4': 'cv',
104
+ 'layer5': 'cv',
105
+ 'layer6': 'cv',
106
+ 'layer7': 'cv',
107
+ 'layer8': 'cv',
108
+ 'layer9': 'cv',
109
+ 'layer10': 'cv',
110
+ 'layer11': 'cv',
111
+ 'layer12': 'cv',
112
+ 'layer13': 'cv',
113
+ 'layer14': 'cv',
114
+ 'layer15': 'cv',
115
+ },
116
+ 'cvvv4': {
117
+ 'layer0': 'cd',
118
+ 'layer1': 'cv',
119
+ 'layer2': 'cv',
120
+ 'layer3': 'cv',
121
+ 'layer4': 'cd',
122
+ 'layer5': 'cv',
123
+ 'layer6': 'cv',
124
+ 'layer7': 'cv',
125
+ 'layer8': 'cd',
126
+ 'layer9': 'cv',
127
+ 'layer10': 'cv',
128
+ 'layer11': 'cv',
129
+ 'layer12': 'cd',
130
+ 'layer13': 'cv',
131
+ 'layer14': 'cv',
132
+ 'layer15': 'cv',
133
+ },
134
+ 'avvv4': {
135
+ 'layer0': 'ad',
136
+ 'layer1': 'cv',
137
+ 'layer2': 'cv',
138
+ 'layer3': 'cv',
139
+ 'layer4': 'ad',
140
+ 'layer5': 'cv',
141
+ 'layer6': 'cv',
142
+ 'layer7': 'cv',
143
+ 'layer8': 'ad',
144
+ 'layer9': 'cv',
145
+ 'layer10': 'cv',
146
+ 'layer11': 'cv',
147
+ 'layer12': 'ad',
148
+ 'layer13': 'cv',
149
+ 'layer14': 'cv',
150
+ 'layer15': 'cv',
151
+ },
152
+ 'rvvv4': {
153
+ 'layer0': 'rd',
154
+ 'layer1': 'cv',
155
+ 'layer2': 'cv',
156
+ 'layer3': 'cv',
157
+ 'layer4': 'rd',
158
+ 'layer5': 'cv',
159
+ 'layer6': 'cv',
160
+ 'layer7': 'cv',
161
+ 'layer8': 'rd',
162
+ 'layer9': 'cv',
163
+ 'layer10': 'cv',
164
+ 'layer11': 'cv',
165
+ 'layer12': 'rd',
166
+ 'layer13': 'cv',
167
+ 'layer14': 'cv',
168
+ 'layer15': 'cv',
169
+ },
170
+ 'cccv4': {
171
+ 'layer0': 'cd',
172
+ 'layer1': 'cd',
173
+ 'layer2': 'cd',
174
+ 'layer3': 'cv',
175
+ 'layer4': 'cd',
176
+ 'layer5': 'cd',
177
+ 'layer6': 'cd',
178
+ 'layer7': 'cv',
179
+ 'layer8': 'cd',
180
+ 'layer9': 'cd',
181
+ 'layer10': 'cd',
182
+ 'layer11': 'cv',
183
+ 'layer12': 'cd',
184
+ 'layer13': 'cd',
185
+ 'layer14': 'cd',
186
+ 'layer15': 'cv',
187
+ },
188
+ 'aaav4': {
189
+ 'layer0': 'ad',
190
+ 'layer1': 'ad',
191
+ 'layer2': 'ad',
192
+ 'layer3': 'cv',
193
+ 'layer4': 'ad',
194
+ 'layer5': 'ad',
195
+ 'layer6': 'ad',
196
+ 'layer7': 'cv',
197
+ 'layer8': 'ad',
198
+ 'layer9': 'ad',
199
+ 'layer10': 'ad',
200
+ 'layer11': 'cv',
201
+ 'layer12': 'ad',
202
+ 'layer13': 'ad',
203
+ 'layer14': 'ad',
204
+ 'layer15': 'cv',
205
+ },
206
+ 'rrrv4': {
207
+ 'layer0': 'rd',
208
+ 'layer1': 'rd',
209
+ 'layer2': 'rd',
210
+ 'layer3': 'cv',
211
+ 'layer4': 'rd',
212
+ 'layer5': 'rd',
213
+ 'layer6': 'rd',
214
+ 'layer7': 'cv',
215
+ 'layer8': 'rd',
216
+ 'layer9': 'rd',
217
+ 'layer10': 'rd',
218
+ 'layer11': 'cv',
219
+ 'layer12': 'rd',
220
+ 'layer13': 'rd',
221
+ 'layer14': 'rd',
222
+ 'layer15': 'cv',
223
+ },
224
+ 'c16': {
225
+ 'layer0': 'cd',
226
+ 'layer1': 'cd',
227
+ 'layer2': 'cd',
228
+ 'layer3': 'cd',
229
+ 'layer4': 'cd',
230
+ 'layer5': 'cd',
231
+ 'layer6': 'cd',
232
+ 'layer7': 'cd',
233
+ 'layer8': 'cd',
234
+ 'layer9': 'cd',
235
+ 'layer10': 'cd',
236
+ 'layer11': 'cd',
237
+ 'layer12': 'cd',
238
+ 'layer13': 'cd',
239
+ 'layer14': 'cd',
240
+ 'layer15': 'cd',
241
+ },
242
+ 'a16': {
243
+ 'layer0': 'ad',
244
+ 'layer1': 'ad',
245
+ 'layer2': 'ad',
246
+ 'layer3': 'ad',
247
+ 'layer4': 'ad',
248
+ 'layer5': 'ad',
249
+ 'layer6': 'ad',
250
+ 'layer7': 'ad',
251
+ 'layer8': 'ad',
252
+ 'layer9': 'ad',
253
+ 'layer10': 'ad',
254
+ 'layer11': 'ad',
255
+ 'layer12': 'ad',
256
+ 'layer13': 'ad',
257
+ 'layer14': 'ad',
258
+ 'layer15': 'ad',
259
+ },
260
+ 'r16': {
261
+ 'layer0': 'rd',
262
+ 'layer1': 'rd',
263
+ 'layer2': 'rd',
264
+ 'layer3': 'rd',
265
+ 'layer4': 'rd',
266
+ 'layer5': 'rd',
267
+ 'layer6': 'rd',
268
+ 'layer7': 'rd',
269
+ 'layer8': 'rd',
270
+ 'layer9': 'rd',
271
+ 'layer10': 'rd',
272
+ 'layer11': 'rd',
273
+ 'layer12': 'rd',
274
+ 'layer13': 'rd',
275
+ 'layer14': 'rd',
276
+ 'layer15': 'rd',
277
+ },
278
+ 'carv4': {
279
+ 'layer0': 'cd',
280
+ 'layer1': 'ad',
281
+ 'layer2': 'rd',
282
+ 'layer3': 'cv',
283
+ 'layer4': 'cd',
284
+ 'layer5': 'ad',
285
+ 'layer6': 'rd',
286
+ 'layer7': 'cv',
287
+ 'layer8': 'cd',
288
+ 'layer9': 'ad',
289
+ 'layer10': 'rd',
290
+ 'layer11': 'cv',
291
+ 'layer12': 'cd',
292
+ 'layer13': 'ad',
293
+ 'layer14': 'rd',
294
+ 'layer15': 'cv',
295
+ },
296
+ }
297
+
298
+ def createConvFunc(op_type):
299
+ assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
300
+ if op_type == 'cv':
301
+ return F.conv2d
302
+
303
+ if op_type == 'cd':
304
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
305
+ assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
306
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
307
+ assert padding == dilation, 'padding for cd_conv set wrong'
308
+
309
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
310
+ yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
311
+ y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
312
+ return y - yc
313
+ return func
314
+ elif op_type == 'ad':
315
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
316
+ assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
317
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
318
+ assert padding == dilation, 'padding for ad_conv set wrong'
319
+
320
+ shape = weights.shape
321
+ weights = weights.view(shape[0], shape[1], -1)
322
+ weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
323
+ y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
324
+ return y
325
+ return func
326
+ elif op_type == 'rd':
327
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
328
+ assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
329
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
330
+ padding = 2 * dilation
331
+
332
+ shape = weights.shape
333
+ if weights.is_cuda:
334
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
335
+ else:
336
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device)
337
+ weights = weights.view(shape[0], shape[1], -1)
338
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
339
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
340
+ buffer[:, :, 12] = 0
341
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
342
+ y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
343
+ return y
344
+ return func
345
+ else:
346
+ print('impossible to be here unless you force that')
347
+ return None
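+
+ # In short: 'cv' is a vanilla F.conv2d; 'cd' subtracts the response of the kernel-sum
+ # (a 1x1 conv), so a constant patch maps to a (nearly) zero response; 'ad' convolves with
+ # the kernel minus a rotated copy of itself; 'rd' scatters the eight outer 3x3 weights
+ # onto a 5x5 ring with their negatives on the inner ring. Every returned function keeps
+ # the F.conv2d signature expected by the Conv2d wrapper below.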
348
+
349
+ class Conv2d(nn.Module):
350
+ def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
351
+ super(Conv2d, self).__init__()
352
+ if in_channels % groups != 0:
353
+ raise ValueError('in_channels must be divisible by groups')
354
+ if out_channels % groups != 0:
355
+ raise ValueError('out_channels must be divisible by groups')
356
+ self.in_channels = in_channels
357
+ self.out_channels = out_channels
358
+ self.kernel_size = kernel_size
359
+ self.stride = stride
360
+ self.padding = padding
361
+ self.dilation = dilation
362
+ self.groups = groups
363
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
364
+ if bias:
365
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
366
+ else:
367
+ self.register_parameter('bias', None)
368
+ self.reset_parameters()
369
+ self.pdc = pdc
370
+
371
+ def reset_parameters(self):
372
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
373
+ if self.bias is not None:
374
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
375
+ bound = 1 / math.sqrt(fan_in)
376
+ nn.init.uniform_(self.bias, -bound, bound)
377
+
378
+ def forward(self, input):
379
+
380
+ return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
381
+
382
+ class CSAM(nn.Module):
383
+ """
384
+ Compact Spatial Attention Module
385
+ """
386
+ def __init__(self, channels):
387
+ super(CSAM, self).__init__()
388
+
389
+ mid_channels = 4
390
+ self.relu1 = nn.ReLU()
391
+ self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
392
+ self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
393
+ self.sigmoid = nn.Sigmoid()
394
+ nn.init.constant_(self.conv1.bias, 0)
395
+
396
+ def forward(self, x):
397
+ y = self.relu1(x)
398
+ y = self.conv1(y)
399
+ y = self.conv2(y)
400
+ y = self.sigmoid(y)
401
+
402
+ return x * y
403
+
404
+ class CDCM(nn.Module):
405
+ """
406
+ Compact Dilation Convolution based Module
407
+ """
408
+ def __init__(self, in_channels, out_channels):
409
+ super(CDCM, self).__init__()
410
+
411
+ self.relu1 = nn.ReLU()
412
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
413
+ self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
414
+ self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
415
+ self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
416
+ self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
417
+ nn.init.constant_(self.conv1.bias, 0)
418
+
419
+ def forward(self, x):
420
+ x = self.relu1(x)
421
+ x = self.conv1(x)
422
+ x1 = self.conv2_1(x)
423
+ x2 = self.conv2_2(x)
424
+ x3 = self.conv2_3(x)
425
+ x4 = self.conv2_4(x)
426
+ return x1 + x2 + x3 + x4
427
+
428
+
429
+ class MapReduce(nn.Module):
430
+ """
431
+ Reduce feature maps into a single edge map
432
+ """
433
+ def __init__(self, channels):
434
+ super(MapReduce, self).__init__()
435
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
436
+ nn.init.constant_(self.conv.bias, 0)
437
+
438
+ def forward(self, x):
439
+ return self.conv(x)
440
+
441
+
442
+ class PDCBlock(nn.Module):
443
+ def __init__(self, pdc, inplane, ouplane, stride=1):
444
+ super(PDCBlock, self).__init__()
445
+ self.stride=stride
446
+
448
+ if self.stride > 1:
449
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
450
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
451
+ self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
452
+ self.relu2 = nn.ReLU()
453
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
454
+
455
+ def forward(self, x):
456
+ if self.stride > 1:
457
+ x = self.pool(x)
458
+ y = self.conv1(x)
459
+ y = self.relu2(y)
460
+ y = self.conv2(y)
461
+ if self.stride > 1:
462
+ x = self.shortcut(x)
463
+ y = y + x
464
+ return y
465
+
466
+ class PDCBlock_converted(nn.Module):
467
+ """
468
+ CPDC, APDC can be converted to vanilla 3x3 convolution
469
+ RPDC can be converted to vanilla 5x5 convolution
470
+ """
471
+ def __init__(self, pdc, inplane, ouplane, stride=1):
472
+ super(PDCBlock_converted, self).__init__()
473
+ self.stride=stride
474
+
475
+ if self.stride > 1:
476
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
477
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
478
+ if pdc == 'rd':
479
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
480
+ else:
481
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
482
+ self.relu2 = nn.ReLU()
483
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
484
+
485
+ def forward(self, x):
486
+ if self.stride > 1:
487
+ x = self.pool(x)
488
+ y = self.conv1(x)
489
+ y = self.relu2(y)
490
+ y = self.conv2(y)
491
+ if self.stride > 1:
492
+ x = self.shortcut(x)
493
+ y = y + x
494
+ return y
495
+
496
+ class PiDiNet(nn.Module):
497
+ def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
498
+ super(PiDiNet, self).__init__()
499
+ self.sa = sa
500
+ if dil is not None:
501
+ assert isinstance(dil, int), 'dil should be an int'
502
+ self.dil = dil
503
+
504
+ self.fuseplanes = []
505
+
506
+ self.inplane = inplane
507
+ if convert:
508
+ if pdcs[0] == 'rd':
509
+ init_kernel_size = 5
510
+ init_padding = 2
511
+ else:
512
+ init_kernel_size = 3
513
+ init_padding = 1
514
+ self.init_block = nn.Conv2d(3, self.inplane,
515
+ kernel_size=init_kernel_size, padding=init_padding, bias=False)
516
+ block_class = PDCBlock_converted
517
+ else:
518
+ self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
519
+ block_class = PDCBlock
520
+
521
+ self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
522
+ self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
523
+ self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
524
+ self.fuseplanes.append(self.inplane) # C
525
+
526
+ inplane = self.inplane
527
+ self.inplane = self.inplane * 2
528
+ self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
529
+ self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
530
+ self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
531
+ self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
532
+ self.fuseplanes.append(self.inplane) # 2C
533
+
534
+ inplane = self.inplane
535
+ self.inplane = self.inplane * 2
536
+ self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
537
+ self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
538
+ self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
539
+ self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
540
+ self.fuseplanes.append(self.inplane) # 4C
541
+
542
+ self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
543
+ self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
544
+ self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
545
+ self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
546
+ self.fuseplanes.append(self.inplane) # 4C
547
+
548
+ self.conv_reduces = nn.ModuleList()
549
+ if self.sa and self.dil is not None:
550
+ self.attentions = nn.ModuleList()
551
+ self.dilations = nn.ModuleList()
552
+ for i in range(4):
553
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
554
+ self.attentions.append(CSAM(self.dil))
555
+ self.conv_reduces.append(MapReduce(self.dil))
556
+ elif self.sa:
557
+ self.attentions = nn.ModuleList()
558
+ for i in range(4):
559
+ self.attentions.append(CSAM(self.fuseplanes[i]))
560
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
561
+ elif self.dil is not None:
562
+ self.dilations = nn.ModuleList()
563
+ for i in range(4):
564
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
565
+ self.conv_reduces.append(MapReduce(self.dil))
566
+ else:
567
+ for i in range(4):
568
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
569
+
570
+ self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
571
+ nn.init.constant_(self.classifier.weight, 0.25)
572
+ nn.init.constant_(self.classifier.bias, 0)
573
+
574
+ # print('initialization done')
575
+
576
+ def get_weights(self):
577
+ conv_weights = []
578
+ bn_weights = []
579
+ relu_weights = []
580
+ for pname, p in self.named_parameters():
581
+ if 'bn' in pname:
582
+ bn_weights.append(p)
583
+ elif 'relu' in pname:
584
+ relu_weights.append(p)
585
+ else:
586
+ conv_weights.append(p)
587
+
588
+ return conv_weights, bn_weights, relu_weights
589
+
590
+ def forward(self, x):
591
+ H, W = x.size()[2:]
592
+
593
+ x = self.init_block(x)
594
+
595
+ x1 = self.block1_1(x)
596
+ x1 = self.block1_2(x1)
597
+ x1 = self.block1_3(x1)
598
+
599
+ x2 = self.block2_1(x1)
600
+ x2 = self.block2_2(x2)
601
+ x2 = self.block2_3(x2)
602
+ x2 = self.block2_4(x2)
603
+
604
+ x3 = self.block3_1(x2)
605
+ x3 = self.block3_2(x3)
606
+ x3 = self.block3_3(x3)
607
+ x3 = self.block3_4(x3)
608
+
609
+ x4 = self.block4_1(x3)
610
+ x4 = self.block4_2(x4)
611
+ x4 = self.block4_3(x4)
612
+ x4 = self.block4_4(x4)
613
+
614
+ x_fuses = []
615
+ if self.sa and self.dil is not None:
616
+ for i, xi in enumerate([x1, x2, x3, x4]):
617
+ x_fuses.append(self.attentions[i](self.dilations[i](xi)))
618
+ elif self.sa:
619
+ for i, xi in enumerate([x1, x2, x3, x4]):
620
+ x_fuses.append(self.attentions[i](xi))
621
+ elif self.dil is not None:
622
+ for i, xi in enumerate([x1, x2, x3, x4]):
623
+ x_fuses.append(self.dilations[i](xi))
624
+ else:
625
+ x_fuses = [x1, x2, x3, x4]
626
+
627
+ e1 = self.conv_reduces[0](x_fuses[0])
628
+ e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
629
+
630
+ e2 = self.conv_reduces[1](x_fuses[1])
631
+ e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
632
+
633
+ e3 = self.conv_reduces[2](x_fuses[2])
634
+ e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
635
+
636
+ e4 = self.conv_reduces[3](x_fuses[3])
637
+ e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
638
+
639
+ outputs = [e1, e2, e3, e4]
640
+
641
+ output = self.classifier(torch.cat(outputs, dim=1))
642
+ #if not self.training:
643
+ # return torch.sigmoid(output)
644
+
645
+ outputs.append(output)
646
+ outputs = [torch.sigmoid(r) for r in outputs]
647
+ return outputs
648
+
649
+ def config_model(model):
650
+ model_options = list(nets.keys())
651
+ assert model in model_options, \
652
+ 'unrecognized model, please choose from %s' % str(model_options)
653
+
654
+ # print(str(nets[model]))
655
+
656
+ pdcs = []
657
+ for i in range(16):
658
+ layer_name = 'layer%d' % i
659
+ op = nets[model][layer_name]
660
+ pdcs.append(createConvFunc(op))
661
+
662
+ return pdcs
663
+
664
+ def pidinet():
665
+ pdcs = config_model('carv4')
666
+ dil = 24 #if args.dil else None
667
+ return PiDiNet(60, pdcs, dil=dil, sa=True)
668
+
669
+
670
+ if __name__ == '__main__':
671
+ model = pidinet()
672
+ ckp = torch.load('table5_pidinet.pth')['state_dict']
673
+ model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
674
+ im = cv2.imread('examples/test_my/cat_v4.png')
675
+ im = img2tensor(im).unsqueeze(0)/255.
676
+ res = model(im)[-1]
677
+ res = res>0.5
678
+ res = res.float()
679
+ res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
680
+ print(res.shape)
681
+ cv2.imwrite('edge.png', res)
train/src/condition/ted.py ADDED
@@ -0,0 +1,296 @@
1
+ # TEED is a Tiny but Efficient Edge Detector; it comes from LDC-B3
2
+ # with a slight modification
3
+ # LDC parameters:
4
+ # 155665
5
+ # TED > 58K
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .util import smish as Fsmish
12
+ from .util import Smish
13
+
14
+
15
+ def weight_init(m):
16
+ if isinstance(m, (nn.Conv2d,)):
17
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
18
+
19
+ if m.bias is not None:
20
+ torch.nn.init.zeros_(m.bias)
21
+
22
+ # for fusion layer
23
+ if isinstance(m, (nn.ConvTranspose2d,)):
24
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
25
+ if m.bias is not None:
26
+ torch.nn.init.zeros_(m.bias)
27
+
28
+ class CoFusion(nn.Module):
29
+ # from LDC
30
+
31
+ def __init__(self, in_ch, out_ch):
32
+ super(CoFusion, self).__init__()
33
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
34
+ stride=1, padding=1) # before 64
35
+ self.conv3= nn.Conv2d(32, out_ch, kernel_size=3,
36
+ stride=1, padding=1)# before 64 instead of 32
37
+ self.relu = nn.ReLU()
38
+ self.norm_layer1 = nn.GroupNorm(4, 32) # before 64
39
+
40
+ def forward(self, x):
41
+ # fusecat = torch.cat(x, dim=1)
42
+ attn = self.relu(self.norm_layer1(self.conv1(x)))
43
+ attn = F.softmax(self.conv3(attn), dim=1)
44
+ return ((x * attn).sum(1)).unsqueeze(1)
45
+
46
+
47
+ class CoFusion2(nn.Module):
48
+ # TEDv14-3
49
+ def __init__(self, in_ch, out_ch):
50
+ super(CoFusion2, self).__init__()
51
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
52
+ stride=1, padding=1) # before 64
53
+ # self.conv2 = nn.Conv2d(32, 32, kernel_size=3,
54
+ # stride=1, padding=1)# before 64
55
+ self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
56
+ stride=1, padding=1)# before 64 instead of 32
57
+ self.smish= Smish()#nn.ReLU(inplace=True)
58
+
59
+
60
+ def forward(self, x):
61
+ # fusecat = torch.cat(x, dim=1)
62
+ attn = self.conv1(self.smish(x))
63
+ attn = self.conv3(self.smish(attn)) # before , )dim=1)
64
+
65
+ # return ((fusecat * attn).sum(1)).unsqueeze(1)
66
+ return ((x * attn).sum(1)).unsqueeze(1)
67
+
68
+ class DoubleFusion(nn.Module):
69
+ # TED fusion before the final edge map prediction
70
+ def __init__(self, in_ch, out_ch):
71
+ super(DoubleFusion, self).__init__()
72
+ self.DWconv1 = nn.Conv2d(in_ch, in_ch*8, kernel_size=3,
73
+ stride=1, padding=1, groups=in_ch) # before 64
74
+ self.PSconv1 = nn.PixelShuffle(1)
75
+
76
+ self.DWconv2 = nn.Conv2d(24, 24*1, kernel_size=3,
77
+ stride=1, padding=1,groups=24)# before 64 instead of 32
78
+
79
+ self.AF= Smish()#XAF() #nn.Tanh()# XAF() # # Smish()#
80
+
81
+
82
+ def forward(self, x):
83
+ # fusecat = torch.cat(x, dim=1)
84
+ attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352]
85
+
86
+ attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352]
87
+
88
+ return Fsmish(((attn2 +attn).sum(1)).unsqueeze(1)) #TED best res
89
+
90
+ class _DenseLayer(nn.Sequential):
91
+ def __init__(self, input_features, out_features):
92
+ super(_DenseLayer, self).__init__()
93
+
94
+ self.add_module('conv1', nn.Conv2d(input_features, out_features,
95
+ kernel_size=3, stride=1, padding=2, bias=True)),
96
+ self.add_module('smish1', Smish()),
97
+ self.add_module('conv2', nn.Conv2d(out_features, out_features,
98
+ kernel_size=3, stride=1, bias=True))
99
+ def forward(self, x):
100
+ x1, x2 = x
101
+
102
+ new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu()
103
+
104
+ return 0.5 * (new_features + x2), x2
105
+
106
+
107
+ class _DenseBlock(nn.Sequential):
108
+ def __init__(self, num_layers, input_features, out_features):
109
+ super(_DenseBlock, self).__init__()
110
+ for i in range(num_layers):
111
+ layer = _DenseLayer(input_features, out_features)
112
+ self.add_module('denselayer%d' % (i + 1), layer)
113
+ input_features = out_features
114
+
115
+
116
+ class UpConvBlock(nn.Module):
117
+ def __init__(self, in_features, up_scale):
118
+ super(UpConvBlock, self).__init__()
119
+ self.up_factor = 2
120
+ self.constant_features = 16
121
+
122
+ layers = self.make_deconv_layers(in_features, up_scale)
123
+ assert layers is not None, layers
124
+ self.features = nn.Sequential(*layers)
125
+
126
+ def make_deconv_layers(self, in_features, up_scale):
127
+ layers = []
128
+ all_pads=[0,0,1,3,7]
129
+ for i in range(up_scale):
130
+ kernel_size = 2 ** up_scale
131
+ pad = all_pads[up_scale] # kernel_size-1
132
+ out_features = self.compute_out_features(i, up_scale)
133
+ layers.append(nn.Conv2d(in_features, out_features, 1))
134
+ layers.append(Smish())
135
+ layers.append(nn.ConvTranspose2d(
136
+ out_features, out_features, kernel_size, stride=2, padding=pad))
137
+ in_features = out_features
138
+ return layers
139
+
140
+ def compute_out_features(self, idx, up_scale):
141
+ return 1 if idx == up_scale - 1 else self.constant_features
142
+
143
+ def forward(self, x):
144
+ return self.features(x)
145
+
146
+
147
+ class SingleConvBlock(nn.Module):
148
+ def __init__(self, in_features, out_features, stride, use_ac=False):
149
+ super(SingleConvBlock, self).__init__()
150
+ # self.use_bn = use_bs
151
+ self.use_ac=use_ac
152
+ self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
153
+ bias=True)
154
+ if self.use_ac:
155
+ self.smish = Smish()
156
+
157
+ def forward(self, x):
158
+ x = self.conv(x)
159
+ if self.use_ac:
160
+ return self.smish(x)
161
+ else:
162
+ return x
163
+
164
+ class DoubleConvBlock(nn.Module):
165
+ def __init__(self, in_features, mid_features,
166
+ out_features=None,
167
+ stride=1,
168
+ use_act=True):
169
+ super(DoubleConvBlock, self).__init__()
170
+
171
+ self.use_act = use_act
172
+ if out_features is None:
173
+ out_features = mid_features
174
+ self.conv1 = nn.Conv2d(in_features, mid_features,
175
+ 3, padding=1, stride=stride)
176
+ self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
177
+ self.smish= Smish()#nn.ReLU(inplace=True)
178
+
179
+ def forward(self, x):
180
+ x = self.conv1(x)
181
+ x = self.smish(x)
182
+ x = self.conv2(x)
183
+ if self.use_act:
184
+ x = self.smish(x)
185
+ return x
186
+
187
+
188
+ class TED(nn.Module):
189
+ """ Definition of Tiny and Efficient Edge Detector
190
+ model
191
+ """
192
+
193
+ def __init__(self):
194
+ super(TED, self).__init__()
195
+ self.block_1 = DoubleConvBlock(3, 16, 16, stride=2,)
196
+ self.block_2 = DoubleConvBlock(16, 32, use_act=False)
197
+ self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64)
198
+
199
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
200
+
201
+ # skip1 connection, see fig. 2
202
+ self.side_1 = SingleConvBlock(16, 32, 2)
203
+
204
+ # skip2 connection, see fig. 2
205
+ self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1)
206
+
207
+ # USNet
208
+ self.up_block_1 = UpConvBlock(16, 1)
209
+ self.up_block_2 = UpConvBlock(32, 1)
210
+ self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1)
211
+
212
+ self.block_cat = DoubleFusion(3,3) # TEED: DoubleFusion
213
+
214
+ self.apply(weight_init)
215
+
216
+ def slice(self, tensor, slice_shape):
217
+ t_shape = tensor.shape
218
+ img_h, img_w = slice_shape
219
+ if img_w!=t_shape[-1] or img_h!=t_shape[2]:
220
+ new_tensor = F.interpolate(
221
+ tensor, size=(img_h, img_w), mode='bicubic',align_corners=False)
222
+
223
+ else:
224
+ new_tensor=tensor
225
+ # tensor[..., :height, :width]
226
+ return new_tensor
227
+ def resize_input(self,tensor):
228
+ t_shape = tensor.shape
229
+ if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
230
+ img_w= ((t_shape[3]// 8) + 1) * 8
231
+ img_h = ((t_shape[2] // 8) + 1) * 8
232
+ new_tensor = F.interpolate(
233
+ tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
234
+ else:
235
+ new_tensor = tensor
236
+ return new_tensor
237
+
238
+ def crop_bdcn(data1, h, w, crop_h, crop_w):
239
+ # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
240
+ _, _, h1, w1 = data1.size()
241
+ assert (h <= h1 and w <= w1)
242
+ data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w]
243
+ return data
244
+
245
+
246
+ def forward(self, x, single_test=False):
247
+ assert x.ndim == 4, x.shape
248
+ # suppose the image size is 352x352
249
+
250
+ # Block 1
251
+ block_1 = self.block_1(x) # [8,16,176,176]
252
+ block_1_side = self.side_1(block_1) # 16 [8,32,88,88]
253
+
254
+ # Block 2
255
+ block_2 = self.block_2(block_1) # 32 # [8,32,176,176]
256
+ block_2_down = self.maxpool(block_2) # [8,32,88,88]
257
+ block_2_add = block_2_down + block_1_side # [8,32,88,88]
258
+
259
+ # Block 3
260
+ block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection
261
+ block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88]
262
+
263
+ # upsampling blocks
264
+ out_1 = self.up_block_1(block_1)
265
+ out_2 = self.up_block_2(block_2)
266
+ out_3 = self.up_block_3(block_3)
267
+
268
+ results = [out_1, out_2, out_3]
269
+
270
+ # concatenate multiscale outputs
271
+ block_cat = torch.cat(results, dim=1) # Bx6xHxW
272
+ block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion
273
+
274
+ results.append(block_cat)
275
+ return results
276
+
277
+
278
+ if __name__ == '__main__':
279
+ batch_size = 8
280
+ img_height = 352
281
+ img_width = 352
282
+
283
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
284
+ device = "cpu"
285
+ input = torch.rand(batch_size, 3, img_height, img_width).to(device)
286
+ # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
287
+ print(f"input shape: {input.shape}")
288
+ model = TED().to(device)
289
+ output = model(input)
290
+ print(f"output shapes: {[t.shape for t in output]}")
291
+
292
+ # for i in range(20000):
293
+ # print(i)
294
+ # output = model(input)
295
+ # loss = nn.MSELoss()(output[-1], target)
296
+ # loss.backward()
train/src/condition/util.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import random
3
+ import tempfile
4
+ import warnings
5
+ from contextlib import suppress
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from huggingface_hub import constants, hf_hub_download
12
+ from torch.hub import get_dir, download_url_to_file
13
+ from ast import literal_eval
14
+
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+
18
+ def safe_step(x, step=2):
19
+ y = x.astype(np.float32) * float(step + 1)
20
+ y = y.astype(np.int32).astype(np.float32) / float(step)
21
+ return y
22
+
23
+ def nms(x, t, s):
24
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
25
+
26
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
27
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
28
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
29
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
30
+
31
+ y = np.zeros_like(x)
32
+
33
+ for f in [f1, f2, f3, f4]:
34
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
35
+
36
+ z = np.zeros_like(y, dtype=np.uint8)
37
+ z[y > t] = 255
38
+ return z
39
+
40
+
41
+ def safer_memory(x):
42
+ # Fix many MAC/AMD problems
43
+ return np.ascontiguousarray(x.copy()).copy()
44
+
45
+ UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
46
+ def get_upscale_method(method_str):
47
+ assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
48
+ return getattr(cv2, method_str)
49
+
50
+ def pad64(x):
51
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
52
+
53
+ def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'):
54
+ if skip_hwc3:
55
+ img = input_image
56
+ else:
57
+ img = HWC3(input_image)
58
+ H_raw, W_raw, _ = img.shape
59
+ if resolution == 0:
60
+ return img, lambda x: x
61
+ k = float(resolution) / float(min(H_raw, W_raw))
62
+ H_target = int(np.round(float(H_raw) * k))
63
+ W_target = int(np.round(float(W_raw) * k))
64
+ img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
65
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
66
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
67
+
68
+ def remove_pad(x):
69
+ return safer_memory(x[:H_target, :W_target, ...])
70
+
71
+ return safer_memory(img_padded), remove_pad
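+
+ # Usage sketch (hypothetical sizes): scale the short side to `resolution`, pad to a
+ # multiple of 64, then undo the padding on the result:
+ #   img = np.zeros((480, 640, 3), dtype=np.uint8)
+ #   padded, remove_pad = resize_image_with_pad(img, 512, "INTER_CUBIC")
+ #   restored = remove_pad(padded)  # resized but unpadded array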
72
+
73
+ def common_input_validate(input_image, output_type, **kwargs):
74
+ if "img" in kwargs:
75
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
76
+ input_image = kwargs.pop("img")
77
+
78
+ if "return_pil" in kwargs:
79
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
80
+ output_type = "pil" if kwargs["return_pil"] else "np"
81
+
82
+ if type(output_type) is bool:
83
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
84
+ if output_type:
85
+ output_type = "pil"
86
+
87
+ if input_image is None:
88
+ raise ValueError("input_image must be defined.")
89
+
90
+ if not isinstance(input_image, np.ndarray):
91
+ input_image = np.array(input_image, dtype=np.uint8)
92
+ output_type = output_type or "pil"
93
+ else:
94
+ output_type = output_type or "np"
95
+
96
+ return (input_image, output_type)
97
+
98
+ def HWC3(x):
99
+ assert x.dtype == np.uint8
100
+ if x.ndim == 2:
101
+ x = x[:, :, None]
102
+ assert x.ndim == 3
103
+ H, W, C = x.shape
104
+ assert C == 1 or C == 3 or C == 4
105
+ if C == 3:
106
+ return x
107
+ if C == 1:
108
+ return np.concatenate([x, x, x], axis=2)
109
+ if C == 4:
110
+ color = x[:, :, 0:3].astype(np.float32)
111
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
112
+ y = color * alpha + 255.0 * (1.0 - alpha)
113
+ y = y.clip(0, 255).astype(np.uint8)
114
+ return y
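+
+ # In short: grayscale (H, W) or (H, W, 1) inputs are tiled to three channels, RGB is
+ # returned unchanged, and RGBA is alpha-composited over a white background.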
115
+
116
+ def get_intensity_mask(image_array, lower_bound, upper_bound):
117
+ mask = image_array[:, :, 0]
118
+ mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
119
+ mask = np.expand_dims(mask, 2).repeat(3, axis=2)
120
+ return mask
121
+
122
+ def combine_layers(base_layer, top_layer):
123
+ mask = top_layer.astype(bool)
124
+ temp = 1 - (1 - top_layer) * (1 - base_layer)
125
+ result = base_layer * (~mask) + temp * mask
126
+ return result
127
+
128
+ @torch.jit.script
129
+ def mish(input):
130
+ """
131
+ Applies the mish function element-wise:
132
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
133
+ See additional documentation for mish class.
134
+ """
135
+ return input * torch.tanh(F.softplus(input))
136
+
137
+ @torch.jit.script
138
+ def smish(input):
139
+ """
140
+ Applies the smish function element-wise:
141
+ smish(x) = x * tanh(ln(1 + sigmoid(x)))
142
+ See additional documentation for the Smish class.
143
+ """
144
+ return input * torch.tanh(torch.log(1+torch.sigmoid(input)))
145
+
146
+
147
+ class Mish(nn.Module):
148
+ """
149
+ Applies the mish function element-wise:
150
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
151
+ Shape:
152
+ - Input: (N, *) where * means, any number of additional
153
+ dimensions
154
+ - Output: (N, *), same shape as the input
155
+ Examples:
156
+ >>> m = Mish()
157
+ >>> input = torch.randn(2)
158
+ >>> output = m(input)
159
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
160
+ """
161
+
162
+ def __init__(self):
163
+ """
164
+ Init method.
165
+ """
166
+ super().__init__()
167
+
168
+ def forward(self, input):
169
+ """
170
+ Forward pass of the function.
171
+ """
172
+ if torch.__version__ >= "1.9":
173
+ return F.mish(input)
174
+ else:
175
+ return mish(input)
176
+
177
+ class Smish(nn.Module):
178
+ """
179
+ Applies the smish function element-wise:
180
+ smish(x) = x * tanh(ln(1 + sigmoid(x)))
181
+ Shape:
182
+ - Input: (N, *) where * means, any number of additional
183
+ dimensions
184
+ - Output: (N, *), same shape as the input
185
+ Examples:
186
+ >>> m = Smish()
187
+ >>> input = torch.randn(2)
188
+ >>> output = m(input)
189
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
190
+ """
191
+
192
+ def __init__(self):
193
+ """
194
+ Init method.
195
+ """
196
+ super().__init__()
197
+
198
+ def forward(self, input):
199
+ """
200
+ Forward pass of the function.
201
+ """
202
+ return smish(input)
train/src/generate_diff_mask.py ADDED
@@ -0,0 +1,301 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone script: Given two images, generate a final difference mask using the
4
+ same pipeline as visualize_mask_diff (without any visualization output).
5
+
6
+ Pipeline:
7
+ 1) Align images to a preferred resolution/crop so they share the same size.
8
+ 2) Pixel-diff screening across parameter combinations; skip if any hull ratio is
9
+ outside [hull_min_allowed, hull_max_allowed].
10
+ 3) Color-diff to produce the final mask; remove small areas and re-check hull
11
+ ratio. Save final mask to output path.
12
+ """
13
+
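+ # Example invocation (hypothetical file names), single-pair mode handled in main() below:
+ #   python generate_diff_mask.py --source before.png --target after.png \
+ #       --output mask.png --roll_radius 15 --roll_iters 5
+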
14
+ import os
15
+ import json
16
+ import argparse
17
+ from typing import Tuple, Optional
18
+
19
+ import numpy as np
20
+ from PIL import Image
21
+ import cv2
22
+
23
+
24
+ PREFERRED_KONTEXT_RESOLUTIONS = [
25
+ (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
26
+ (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
27
+ (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
28
+ (1504, 688), (1568, 672),
29
+ ]
30
+
31
+
32
+ def choose_preferred_resolution(image_width: int, image_height: int) -> Tuple[int, int]:
33
+ aspect_ratio = image_width / max(1, image_height)
34
+ best = min(((abs(aspect_ratio - (w / h)), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS), key=lambda x: x[0])
35
+ _, w_best, h_best = best
36
+ return int(w_best), int(h_best)
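+
+ # Worked example: a 768x1024 input has aspect ratio 0.75; the closest preferred bucket
+ # is (880, 1184) with ratio ~0.743, so this helper returns (880, 1184).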
37
+
38
+
39
+ def align_images(source_path: str, target_path: str) -> Tuple[Image.Image, Image.Image]:
40
+ source_img = Image.open(source_path).convert("RGB")
41
+ target_img = Image.open(target_path).convert("RGB")
42
+
43
+ pref_w, pref_h = choose_preferred_resolution(source_img.width, source_img.height)
44
+ source_resized = source_img.resize((pref_w, pref_h), Image.Resampling.LANCZOS)
45
+
46
+ tgt_w, tgt_h = target_img.width, target_img.height
47
+ crop_w = min(source_resized.width, tgt_w)
48
+ crop_h = min(source_resized.height, tgt_h)
49
+
50
+ source_aligned = source_resized.crop((0, 0, crop_w, crop_h))
51
+ target_aligned = target_img.crop((0, 0, crop_w, crop_h))
52
+ return source_aligned, target_aligned
53
+
54
+
55
+ def pil_to_cv_gray(img: Image.Image) -> np.ndarray:
56
+ bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
57
+ gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
58
+ return gray
59
+
60
+
61
+ def generate_pixel_diff_mask(img1: Image.Image, img2: Image.Image, threshold: Optional[int] = None, clean_kernel_size: Optional[int] = 11) -> np.ndarray:
62
+ img1_gray = pil_to_cv_gray(img1)
63
+ img2_gray = pil_to_cv_gray(img2)
64
+ diff = cv2.absdiff(img1_gray, img2_gray)
65
+ if threshold is None:
66
+ mask = cv2.threshold(diff, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
67
+ else:
68
+ mask = cv2.threshold(diff, int(threshold), 255, cv2.THRESH_BINARY)[1]
69
+ if clean_kernel_size and clean_kernel_size > 0:
70
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
71
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
72
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
73
+ return mask
74
+
75
+
76
+ def generate_color_diff_mask(img1: Image.Image, img2: Image.Image, threshold: Optional[int] = None, clean_kernel_size: Optional[int] = 21) -> np.ndarray:
77
+ bgr1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR)
78
+ bgr2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR)
79
+ lab1 = cv2.cvtColor(bgr1, cv2.COLOR_BGR2LAB).astype("float32")
80
+ lab2 = cv2.cvtColor(bgr2, cv2.COLOR_BGR2LAB).astype("float32")
81
+ diff = lab1 - lab2
82
+ dist = np.sqrt(np.sum(diff * diff, axis=2))
83
+ dist_u8 = cv2.normalize(dist, None, 0, 255, cv2.NORM_MINMAX).astype("uint8")
84
+ if threshold is None:
85
+ mask = cv2.threshold(dist_u8, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
86
+ else:
87
+ mask = cv2.threshold(dist_u8, int(threshold), 255, cv2.THRESH_BINARY)[1]
88
+ if clean_kernel_size and clean_kernel_size > 0:
89
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
90
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
91
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
92
+ return mask
93
+
94
+
95
+ def compute_unified_contour(mask_bin: np.ndarray, contours: list, min_area: int = 40, method: str = "morph", morph_kernel: int = 15, morph_iters: int = 1, approx_epsilon_ratio: float = 0.01):
96
+ valid_cnts = []
97
+ for c in contours:
98
+ if cv2.contourArea(c) >= max(1, min_area):
99
+ valid_cnts.append(c)
100
+ if not valid_cnts:
101
+ return None
102
+ if method == "convex_hull":
103
+ all_points = np.vstack(valid_cnts)
104
+ hull = cv2.convexHull(all_points)
105
+ epsilon = approx_epsilon_ratio * cv2.arcLength(hull, True)
106
+ unified = cv2.approxPolyDP(hull, epsilon, True)
107
+ return unified
108
+ union = np.zeros_like(mask_bin)
109
+ cv2.drawContours(union, valid_cnts, -1, 255, thickness=-1)
110
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel))
111
+ union_closed = union.copy()
112
+ for _ in range(max(1, morph_iters)):
113
+ union_closed = cv2.morphologyEx(union_closed, cv2.MORPH_CLOSE, kernel)
114
+ ext = cv2.findContours(union_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
115
+ ext = ext[0] if len(ext) == 2 else ext[1]
116
+ if not ext:
117
+ return None
118
+ largest = max(ext, key=cv2.contourArea)
119
+ epsilon = approx_epsilon_ratio * cv2.arcLength(largest, True)
120
+ unified = cv2.approxPolyDP(largest, epsilon, True)
121
+ return unified
122
+
123
+
124
+ def compute_hull_area_ratio(mask: np.ndarray, min_area: int = 40) -> float:
125
+ mask_bin = (mask > 0).astype("uint8") * 255
126
+ cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
127
+ cnts = cnts[0] if len(cnts) == 2 else cnts[1]
128
+ if not cnts:
129
+ return 0.0
130
+ hull_cnt = compute_unified_contour(mask_bin, cnts, min_area=min_area, method="convex_hull", morph_kernel=15, morph_iters=1)
131
+ if hull_cnt is None or len(hull_cnt) < 3:
132
+ return 0.0
133
+ hull_area = float(cv2.contourArea(hull_cnt))
134
+ img_area = float(mask_bin.shape[0] * mask_bin.shape[1])
135
+ return hull_area / max(1.0, img_area)
136
+
137
+
138
+ def clean_and_fill_mask(mask: np.ndarray, min_area: int = 40) -> np.ndarray:
139
+ mask_bin = (mask > 0).astype("uint8") * 255
140
+ cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
141
+ cnts = cnts[0] if len(cnts) == 2 else cnts[1]
142
+ cleaned = np.zeros_like(mask_bin)
143
+ for c in cnts:
144
+ if cv2.contourArea(c) >= max(1, min_area):
145
+ cv2.drawContours(cleaned, [c], 0, 255, -1)
146
+ return cleaned
147
+
148
+
149
+ def generate_final_difference_mask(source_path: str,
150
+ target_path: str,
151
+ hull_min_allowed: float = 0.001,
152
+ hull_max_allowed: float = 0.75,
153
+ pixel_parameters: Optional[list] = None,
154
+ pixel_clean_kernel_default: int = 11,
155
+ color_clean_kernel: int = 3,
156
+ roll_radius: int = 0,
157
+ roll_iters: int = 1) -> Optional[np.ndarray]:
158
+ if pixel_parameters is None:
159
+ # Mirrors the tuned combinations used in the visualization script
160
+ pixel_parameters = [(None, 5), (None, 11), (50, 5)]
161
+
162
+ src_img, tgt_img = align_images(source_path, target_path)
163
+
164
+ # Pixel screening across parameter combinations
165
+ violation = False
166
+ for thr, ksize in pixel_parameters:
167
+ pm = generate_pixel_diff_mask(src_img, tgt_img, threshold=thr, clean_kernel_size=ksize)
168
+ r = compute_hull_area_ratio(pm, min_area=40)
169
+ if r < hull_min_allowed or r > hull_max_allowed:
170
+ violation = True
171
+ break
172
+ if violation:
173
+ # Failure: do not produce any mask
174
+ return None
175
+
176
+ # Color-based final mask → cleaned small areas
177
+ color_mask = generate_color_diff_mask(src_img, tgt_img, threshold=None, clean_kernel_size=color_clean_kernel)
178
+ cleaned = clean_and_fill_mask(color_mask, min_area=40)
179
+
180
+ # Produce binary mask from the convex hull contour of the cleaned mask
181
+ mask_bin = (cleaned > 0).astype("uint8") * 255
182
+ cnts = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
183
+ cnts = cnts[0] if len(cnts) == 2 else cnts[1]
184
+ hull_cnt = compute_unified_contour(mask_bin, cnts, min_area=40, method="convex_hull", morph_kernel=15, morph_iters=1)
185
+ if hull_cnt is None or len(hull_cnt) < 3:
186
+ return None
187
+
188
+ h_mask = np.zeros_like(mask_bin)
189
+ cv2.drawContours(h_mask, [hull_cnt], -1, 255, thickness=-1)
190
+
191
+ # Rolling-circle smoothing: closing then opening with a disk of radius R
192
+ if roll_radius and roll_radius > 0 and roll_iters and roll_iters > 0:
193
+ ksize = max(1, 2 * int(roll_radius) + 1)
194
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
195
+ for _ in range(max(1, roll_iters)):
196
+ h_mask = cv2.morphologyEx(h_mask, cv2.MORPH_CLOSE, kernel)
197
+ h_mask = cv2.morphologyEx(h_mask, cv2.MORPH_OPEN, kernel)
198
+
199
+ # Final hull ratio check on the hull-filled binary mask
200
+ r_final = compute_hull_area_ratio(h_mask, min_area=40)
201
+ if r_final > hull_max_allowed or r_final < hull_min_allowed:
202
+ return None
203
+
204
+ return h_mask
205
+
206
+
207
+ def main():
208
+ parser = argparse.ArgumentParser(description="Generate final difference mask (single pair or whole dataset)")
209
+ # Single-pair mode (optional): if provided, runs single pair; otherwise runs dataset mode
210
+ parser.add_argument("--source", help="Path to source image")
211
+ parser.add_argument("--target", help="Path to target image")
212
+ parser.add_argument("--output", help="Path to write the final mask (PNG)")
213
+ # Dataset mode (defaults to user's dataset paths)
214
+ parser.add_argument("--dataset_dir", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset", help="Base dataset dir with source_images/ and target_images/")
215
+ parser.add_argument("--dataset_output_dir", default="/home/lzc/KontextFill/visualizations_masks/inference_masks_smoothing", help="Output directory for batch masks")
216
+ parser.add_argument("--json_path", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset/extracted_data.json", help="Dataset JSON mapping with fields 'source_image' and 'target_image'")
217
+ # Common params
218
+ parser.add_argument("--hull_min_allowed", type=float, default=0.001)
219
+ parser.add_argument("--hull_max_allowed", type=float, default=0.75)
220
+ parser.add_argument("--color_clean_kernel", type=int, default=3)
221
+ parser.add_argument("--roll_radius", type=int, default=15, help="Rolling-circle smoothing radius (pixels); 0 disables")
222
+ parser.add_argument("--roll_iters", type=int, default=5, help="Rolling smoothing iterations")
223
+
224
+ args = parser.parse_args()
225
+
226
+ pixel_parameters = [(None, 5), (None, 11), (50, 5)]
227
+
228
+ # Decide mode: single or dataset
229
+ if args.source and args.target and args.output:
230
+ mask = generate_final_difference_mask(
231
+ source_path=args.source,
232
+ target_path=args.target,
233
+ hull_min_allowed=args.hull_min_allowed,
234
+ hull_max_allowed=args.hull_max_allowed,
235
+ pixel_parameters=pixel_parameters,
236
+ color_clean_kernel=args.color_clean_kernel,
237
+ roll_radius=args.roll_radius,
238
+ roll_iters=args.roll_iters,
239
+ )
240
+ if mask is None:
241
+ print("Single-pair inference failed; no output saved.")
242
+ return
243
+ os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
244
+ cv2.imwrite(args.output, mask)
245
+ return
246
+
247
+ # Dataset mode using JSON mapping
248
+ out_dir = args.dataset_output_dir
249
+ os.makedirs(out_dir, exist_ok=True)
250
+
251
+ processed = 0
252
+ skipped = 0
253
+ failed = 0
254
+ missing_files = 0
255
+ try:
256
+ with open(args.json_path, "r", encoding="utf-8") as f:
257
+ entries = json.load(f)
258
+ except Exception as e:
259
+ print(f"Failed to read JSON mapping at {args.json_path}: {e}")
260
+ entries = []
261
+
262
+ for item in entries:
263
+ try:
264
+ src_rel = item.get("source_image")
265
+ tgt_rel = item.get("target_image")
266
+ edit_id = item.get("id")
267
+ if not src_rel or not tgt_rel:
268
+ skipped += 1
269
+ continue
270
+ s = os.path.join(args.dataset_dir, src_rel)
271
+ t = os.path.join(args.dataset_dir, tgt_rel)
272
+ if not (os.path.exists(s) and os.path.exists(t)):
273
+ missing_files += 1
274
+ continue
275
+ mask = generate_final_difference_mask(
276
+ source_path=s,
277
+ target_path=t,
278
+ hull_min_allowed=args.hull_min_allowed,
279
+ hull_max_allowed=args.hull_max_allowed,
280
+ pixel_parameters=pixel_parameters,
281
+ color_clean_kernel=args.color_clean_kernel,
282
+ roll_radius=args.roll_radius,
283
+ roll_iters=args.roll_iters,
284
+ )
285
+ if mask is None:
286
+ failed += 1
287
+ continue
288
+ name = f"edit_{int(edit_id):04d}" if isinstance(edit_id, int) or (isinstance(edit_id, str) and edit_id.isdigit()) else os.path.splitext(os.path.basename(src_rel))[0]
289
+ out_path = os.path.join(out_dir, f"{name}.png")
290
+ cv2.imwrite(out_path, mask)
291
+ processed += 1
292
+ except Exception as e:
293
+ skipped += 1
294
+ continue
295
+ print(f"Batch done. Processed={processed}, Failed={failed}, Skipped={skipped}, MissingFiles={missing_files}, OutputDir={out_dir}")
296
+
297
+
298
+ if __name__ == "__main__":
299
+ main()
300
+
301
+
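A minimal single-pair usage sketch of the mask generator above (hedged: the import assumes train/src is on PYTHONPATH so the script resolves as a module, and every path below is a placeholder):

# Sketch only: placeholder paths; the import path is an assumption.
import cv2
from generate_diff_mask import generate_final_difference_mask

mask = generate_final_difference_mask(
    source_path="pair_source.png",   # placeholder
    target_path="pair_target.png",   # placeholder
    hull_min_allowed=0.001,
    hull_max_allowed=0.75,
    roll_radius=15,                  # rolling-circle smoothing, matching the CLI defaults
    roll_iters=5,
)
if mask is None:
    print("Pair rejected by the hull-area screening; no mask produced.")
else:
    cv2.imwrite("pair_mask.png", mask)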
train/src/jsonl_datasets_kontext_color.py ADDED
@@ -0,0 +1,166 @@
1
+ from PIL import Image
2
+ from datasets import load_dataset
3
+ from torchvision import transforms
4
+ import random
5
+ import torch
6
+ import os
7
+ from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
8
+ import numpy as np
9
+ from .jsonl_datasets_colorization import FlexibleColorDetector
10
+
11
+ Image.MAX_IMAGE_PIXELS = None
12
+
13
+ def multiple_16(num: float):
14
+ return int(round(num / 16) * 16)
15
+
16
+ def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
17
+ image_path = os.path.join(root, image_path)
18
+ try:
19
+ image = Image.open(image_path).convert("RGB")
20
+ return image
21
+ except Exception as e:
22
+ print("file error: "+image_path)
23
+ with open("failed_images.txt", "a") as f:
24
+ f.write(f"{image_path}\n")
25
+ return Image.new("RGB", (size, size), (255, 255, 255))
26
+
27
+ def choose_kontext_resolution_from_wh(width: int, height: int):
28
+ aspect_ratio = width / max(1, height)
29
+ _, best_w, best_h = min(
30
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
31
+ )
32
+ return best_w, best_h
33
+
34
+ color_detector = FlexibleColorDetector()
35
+
36
+ def collate_fn(examples):
37
+ if examples[0].get("cond_pixel_values") is not None:
38
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
39
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
40
+ else:
41
+ cond_pixel_values = None
42
+ # source_pixel_values has been removed; return None to stay backward compatible
43
+ source_pixel_values = None
44
+
45
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
46
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
47
+ token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
48
+ token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
49
+
50
+ return {
51
+ "cond_pixel_values": cond_pixel_values,
52
+ "source_pixel_values": source_pixel_values,
53
+ "pixel_values": target_pixel_values,
54
+ "text_ids_1": token_ids_clip,
55
+ "text_ids_2": token_ids_t5,
56
+ }
57
+
58
+
59
+ def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
60
+ # Load the CSV dataset: three columns, where column 0 is the relative image path and column 2 is the caption
61
+ if args.train_data_dir is not None:
62
+ dataset = load_dataset('csv', data_files=args.train_data_dir)
63
+
64
+ # Column-name compatibility: use column 0 as the image path and column 2 as the caption
65
+ column_names = dataset["train"].column_names
66
+ image_col = column_names[0]
67
+ caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]
68
+
69
+ size = args.cond_size
70
+
71
+ # Device setup (interface kept for later use if needed)
72
+ if accelerator is not None:
73
+ device = accelerator.device
74
+ else:
75
+ device = "cpu"
76
+
77
+ # Transforms
78
+ to_tensor_and_norm = transforms.Compose([
79
+ transforms.ToTensor(),
80
+ transforms.Normalize([0.5], [0.5]),
81
+ ])
82
+
83
+ # Keep cond consistent with the colorization dataset: CenterCrop -> ToTensor -> Normalize
84
+ cond_train_transforms = transforms.Compose([
85
+ transforms.CenterCrop((size, size)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize([0.5], [0.5]),
88
+ ])
89
+
90
+ tokenizer_clip = tokenizers[0]
91
+ tokenizer_t5 = tokenizers[1]
92
+
93
+ def tokenize_prompt_clip_t5(examples):
94
+ captions_raw = examples[caption_col]
95
+ captions = []
96
+ for c in captions_raw:
97
+ if isinstance(c, str):
98
+ if random.random() < 0.25:
99
+ captions.append("")
100
+ else:
101
+ captions.append(c)
102
+ else:
103
+ captions.append("")
104
+
105
+ text_inputs_clip = tokenizer_clip(
106
+ captions,
107
+ padding="max_length",
108
+ max_length=77,
109
+ truncation=True,
110
+ return_length=False,
111
+ return_overflowing_tokens=False,
112
+ return_tensors="pt",
113
+ )
114
+ text_input_ids_1 = text_inputs_clip.input_ids
115
+
116
+ text_inputs_t5 = tokenizer_t5(
117
+ captions,
118
+ padding="max_length",
119
+ max_length=128,
120
+ truncation=True,
121
+ return_length=False,
122
+ return_overflowing_tokens=False,
123
+ return_tensors="pt",
124
+ )
125
+ text_input_ids_2 = text_inputs_t5.input_ids
126
+ return text_input_ids_1, text_input_ids_2
127
+
128
+ def preprocess_train(examples):
129
+ batch = {}
130
+
131
+ img_paths = examples[image_col]
132
+
133
+ target_tensors = []
134
+ cond_tensors = []
135
+
136
+ for p in img_paths:
137
+ # Load image by joining with root in load_image_safely
138
+ img = load_image_safely(p, size)
139
+ img = img.convert("RGB")
140
+
141
+ # Resize to Kontext preferred resolution for target
142
+ w, h = img.size
143
+ best_w, best_h = choose_kontext_resolution_from_wh(w, h)
144
+ img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
145
+ target_tensor = to_tensor_and_norm(img_rs)
146
+
147
+ # Build color block condition
148
+ color_blocks = color_detector(input_image=img, block_size=32, output_size=size)
149
+ edge_tensor = cond_train_transforms(color_blocks)
150
+
151
+ target_tensors.append(target_tensor)
152
+ cond_tensors.append(edge_tensor)
153
+
154
+ batch["pixel_values"] = target_tensors
155
+ batch["cond_pixel_values"] = cond_tensors
156
+
157
+ batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
158
+ return batch
159
+
160
+ if accelerator is not None:
161
+ with accelerator.main_process_first():
162
+ train_dataset = dataset["train"].with_transform(preprocess_train)
163
+ else:
164
+ train_dataset = dataset["train"].with_transform(preprocess_train)
165
+
166
+ return train_dataset
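For orientation, a minimal sketch of wiring this dataset into a DataLoader (hedged: `tokenizer_clip` and `tokenizer_t5` are assumed to be loaded elsewhere in the training script, and the CSV path and sizes are placeholders):

# Sketch only: the args fields mirror what make_train_dataset_inpaint_mask reads.
from types import SimpleNamespace
from torch.utils.data import DataLoader

args = SimpleNamespace(train_data_dir="train_pairs.csv", cond_size=512)  # placeholder values
train_dataset = make_train_dataset_inpaint_mask(args, [tokenizer_clip, tokenizer_t5])
loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
# collate_fn returns: cond_pixel_values, source_pixel_values (None here),
# pixel_values, text_ids_1, text_ids_2.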
train/src/jsonl_datasets_kontext_complete_lora.py ADDED
@@ -0,0 +1,363 @@
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+ import torchvision.transforms.functional as TF
4
+ import random
5
+ import torch
6
+ import os
7
+ from datasets import load_dataset
8
+ import numpy as np
9
+ import json
10
+
11
+ Image.MAX_IMAGE_PIXELS = None
12
+
13
+ def _prepend_caption(description: str, obj_name: str) -> str:
14
+ """Build the completion instruction with a stochastic OBJECT choice.
15
+
16
+ OBJECT choice (equal probability):
17
+ - literal string "object"
18
+ - JSON field `object` with '_' replaced by space
19
+ - JSON field `description`
20
+ """
21
+ # Prepare options for OBJECT slot
22
+ cleaned_obj = (obj_name or "object").replace("_", " ").strip() or "object"
23
+ desc_opt = (description or "object").strip() or "object"
24
+ object_slot = random.choice(["object", cleaned_obj, desc_opt])
25
+
26
+ instruction = f"Complete the {object_slot}'s missing parts if necessary. White Background;"
27
+
28
+ return instruction
29
+
30
+ def collate_fn(examples):
31
+ if examples[0].get("cond_pixel_values") is not None:
32
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
33
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
34
+ else:
35
+ cond_pixel_values = None
36
+
37
+ if examples[0].get("source_pixel_values") is not None:
38
+ source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
39
+ source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
40
+ else:
41
+ source_pixel_values = None
42
+
43
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
44
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
45
+ token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
46
+ token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
47
+
48
+ mask_values = None
49
+ if examples[0].get("mask_values") is not None:
50
+ mask_values = torch.stack([example["mask_values"] for example in examples])
51
+ mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
52
+
53
+ return {
54
+ "cond_pixel_values": cond_pixel_values,
55
+ "source_pixel_values": source_pixel_values,
56
+ "pixel_values": target_pixel_values,
57
+ "text_ids_1": token_ids_clip,
58
+ "text_ids_2": token_ids_t5,
59
+ "mask_values": mask_values,
60
+ }
61
+
62
+
63
+ def _resolve_jsonl(path_str: str):
64
+ if path_str is None or str(path_str).strip() == "":
65
+ raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
66
+ if os.path.isdir(path_str):
67
+ files = [
68
+ os.path.join(path_str, f)
69
+ for f in os.listdir(path_str)
70
+ if f.lower().endswith((".jsonl", ".json"))
71
+ ]
72
+ if not files:
73
+ raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
74
+ return {"train": sorted(files)}
75
+ if not os.path.exists(path_str):
76
+ raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
77
+ return {"train": [path_str]}
78
+
79
+
80
+ def _tokenize(tokenizers, caption: str):
81
+ tokenizer_clip = tokenizers[0]
82
+ tokenizer_t5 = tokenizers[1]
83
+ text_inputs_clip = tokenizer_clip(
84
+ [caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
85
+ )
86
+ text_inputs_t5 = tokenizer_t5(
87
+ [caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
88
+ )
89
+ return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]
90
+
91
+ def _apply_white_brushstrokes(image_np: np.ndarray, mask_bin: np.ndarray = None) -> np.ndarray:
92
+ """Draw random white brushstrokes on the RGB image array and return modified array.
93
+ Strokes preferentially start within mask_bin if provided.
94
+ """
95
+ import cv2
96
+ h, w = image_np.shape[:2]
97
+ rng = random.Random()
98
+
99
+ # Determine stroke counts and sizes based on image size
100
+ ref = max(1, min(h, w))
101
+ num_strokes = rng.randint(1, 5)
102
+ max_offset = max(5, ref // 40)
103
+ min_th = max(2, ref // 40)
104
+ max_th = max(min_th + 1, ref // 5)
105
+
106
+ out = image_np.copy()
107
+ prefer_mask_p = 0.33 if mask_bin is not None and mask_bin.any() else 0.0
108
+
109
+ def rand_point_inside_mask():
110
+ ys, xs = np.where(mask_bin > 0)
111
+ if len(xs) == 0:
112
+ return rng.randrange(w), rng.randrange(h)
113
+ i = rng.randrange(len(xs))
114
+ return int(xs[i]), int(ys[i])
115
+
116
+ def rand_point_any():
117
+ return rng.randrange(w), rng.randrange(h)
118
+
119
+ for _ in range(num_strokes):
120
+ if rng.random() < prefer_mask_p:
121
+ px, py = rand_point_inside_mask()
122
+ else:
123
+ px, py = rand_point_any()
124
+
126
+ # Polyline with several jittered segments
127
+ segments = rng.randint(40, 80)
128
+ thickness = rng.randint(min_th, max_th)
129
+ for _ in range(segments):
130
+ dx = rng.randint(-max_offset, max_offset)
131
+ dy = rng.randint(-max_offset, max_offset)
132
+ nx = int(np.clip(px + dx, 0, w - 1))
133
+ ny = int(np.clip(py + dy, 0, h - 1))
134
+ cv2.line(out, (px, py), (nx, ny), (255, 255, 255), thickness)
135
+ px, py = nx, ny
136
+
137
+ return out
138
+
139
+
140
+ def make_train_dataset_subjects(args, tokenizers, accelerator=None):
141
+ """
142
+ Dataset for JSONL with fields (one JSON object per line):
143
+ - white_image_path: absolute path to base image used for both pixel_values and source_pixel_values
144
+ - mask_path: absolute path to mask image (grayscale)
145
+ - img_width: target width
146
+ - img_height: target height
147
+ - description: caption text
148
+
149
+ Behavior:
150
+ - pixel_values = white_image_path resized to (img_width, img_height)
151
+ - source_pixel_values = same image but with random white brushstrokes overlaid
152
+ - mask_values = binarized mask from mask_path resized with nearest neighbor
153
+ - captions tokenized from description
154
+ """
155
+ data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
156
+ file_paths = data_files.get("train", [])
157
+ records = []
158
+ for p in file_paths:
159
+ with open(p, "r", encoding="utf-8") as f:
160
+ for line in f:
161
+ line = line.strip()
162
+ if not line:
163
+ continue
164
+ try:
165
+ obj = json.loads(line)
166
+ except Exception:
167
+ # Best-effort: strip any trailing commas and retry
168
+ try:
169
+ obj = json.loads(line.rstrip(","))
170
+ except Exception:
171
+ continue
172
+ # Keep only fields we need for this dataset schema
173
+ pruned = {
174
+ "white_image_path": obj.get("white_image_path"),
175
+ "mask_path": obj.get("mask_path"),
176
+ "img_width": obj.get("img_width"),
177
+ "img_height": obj.get("img_height"),
178
+ "description": obj.get("description"),
179
+ "object": obj.get("object"),
180
+ }
181
+ records.append(pruned)
182
+
183
+ size = int(getattr(args, "cond_size", 512))
184
+
185
+ to_tensor_and_norm = transforms.Compose([
186
+ transforms.ToTensor(),
187
+ transforms.Normalize([0.5], [0.5]),
188
+ ])
189
+
190
+ # Repeat each record with independent random brushstrokes
191
+ REPEATS_PER_IMAGE = 5
192
+
193
+ class SubjectsDataset(torch.utils.data.Dataset):
194
+ def __init__(self, hf_ds):
195
+ self.ds = hf_ds
196
+ self.repeats = REPEATS_PER_IMAGE
197
+ def __len__(self):
198
+ if self.repeats and self.repeats > 1:
199
+ return len(self.ds) * self.repeats
200
+ return len(self.ds)
201
+ def __getitem__(self, idx):
202
+ if self.repeats and self.repeats > 1:
203
+ base_idx = idx % len(self.ds)
204
+ else:
205
+ base_idx = idx
206
+ rec = self.ds[base_idx]
207
+
208
+ white_p = rec.get("white_image_path", "") or ""
209
+ mask_p = rec.get("mask_path", "") or ""
210
+
211
+ if not os.path.isabs(white_p):
212
+ # Require absolute paths to avoid ambiguity
213
+ raise ValueError("white_image_path must be absolute")
214
+ if not os.path.isabs(mask_p):
215
+ raise ValueError("mask_path must be absolute")
216
+
217
+ import cv2
218
+ mask_loaded = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
219
+ if mask_loaded is None:
220
+ raise ValueError(f"Failed to read mask: {mask_p}")
221
+
222
+ base_img = Image.open(white_p).convert("RGB")
223
+
224
+ # Desired output size
225
+ fw = int(rec.get("img_width") or base_img.width)
226
+ fh = int(rec.get("img_height") or base_img.height)
227
+ base_img = base_img.resize((fw, fh), resample=Image.BILINEAR)
228
+ mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
229
+
230
+ # Tensors: target is the clean white image
231
+ target_tensor = to_tensor_and_norm(base_img)
232
+
233
+ # Binary mask at final_size
234
+ mask_np = np.array(mask_img)
235
+ mask_bin = (mask_np > 127).astype(np.uint8)
236
+
237
+ # Build source by drawing random white brushstrokes on top of the white image
238
+ base_np = np.array(base_img).astype(np.uint8)
239
+ stroked_np = _apply_white_brushstrokes(base_np, mask_bin)
240
+
241
+ # Build tensors
242
+ source_tensor = to_tensor_and_norm(Image.fromarray(stroked_np.astype(np.uint8)))
243
+ mask_tensor = torch.from_numpy(mask_bin.astype(np.float32)).unsqueeze(0)
244
+
245
+ # Caption: build instruction using description and object
246
+ description = rec.get("description", "")
247
+ obj_name = rec.get("object", "")
248
+ cap = _prepend_caption(description, obj_name)
249
+ ids1, ids2 = _tokenize(tokenizers, cap)
250
+
251
+ return {
252
+ "source_pixel_values": source_tensor,
253
+ "pixel_values": target_tensor,
254
+ "token_ids_clip": ids1,
255
+ "token_ids_t5": ids2,
256
+ "mask_values": mask_tensor,
257
+ }
258
+
259
+ return SubjectsDataset(records)
260
+
261
+
262
+
263
+
264
+ def _run_test_mode(test_jsonl: str, output_dir: str, num_samples: int = 50):
265
+ """Utility to visualize augmentation: saves pairs of (target, source) images.
266
+ Reads the JSONL directly, applies the same logic as dataset to produce
267
+ pixel_values (target) and source_pixel_values (with white strokes),
268
+ then writes them to output_dir for manual inspection.
269
+ """
270
+ os.makedirs(output_dir, exist_ok=True)
271
+ to_tensor_and_norm = transforms.Compose([
272
+ transforms.ToTensor(),
273
+ transforms.Normalize([0.5], [0.5]),
274
+ ])
275
+
276
+ # Minimal tokenizers shim to reuse dataset tokenization pipeline
277
+ class _NoOpTokenizer:
278
+ def __call__(self, texts, padding=None, max_length=None, truncation=None, return_tensors=None):
279
+ return type("T", (), {"input_ids": torch.zeros((1, 1), dtype=torch.long)})()
280
+
281
+ tokenizers = [_NoOpTokenizer(), _NoOpTokenizer()]
282
+
283
+ saved = 0
284
+ line_idx = 0
285
+ import cv2
286
+ with open(test_jsonl, "r", encoding="utf-8") as f:
287
+ for raw in f:
288
+ if saved >= num_samples:
289
+ break
290
+ raw = raw.strip()
291
+ if not raw:
292
+ continue
293
+ try:
294
+ obj = json.loads(raw)
295
+ except Exception:
296
+ try:
297
+ obj = json.loads(raw.rstrip(","))
298
+ except Exception:
299
+ continue
300
+
301
+ rec = {
302
+ "white_image_path": obj.get("white_image_path"),
303
+ "mask_path": obj.get("mask_path"),
304
+ "img_width": obj.get("img_width"),
305
+ "img_height": obj.get("img_height"),
306
+ "description": obj.get("description"),
307
+ }
308
+
309
+ white_p = rec.get("white_image_path", "") or ""
310
+ mask_p = rec.get("mask_path", "") or ""
311
+ if not white_p or not mask_p:
312
+ continue
313
+ if not (os.path.isabs(white_p) and os.path.isabs(mask_p)):
314
+ continue
315
+
316
+ mask_loaded = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
317
+ if mask_loaded is None:
318
+ continue
319
+
320
+ try:
321
+ base_img = Image.open(white_p).convert("RGB")
322
+ except Exception:
323
+ continue
324
+
325
+ fw = int(rec.get("img_width") or base_img.width)
326
+ fh = int(rec.get("img_height") or base_img.height)
327
+ base_img = base_img.resize((fw, fh), resample=Image.BILINEAR)
328
+ mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
329
+
330
+ mask_np = np.array(mask_img)
331
+ mask_bin = (mask_np > 127).astype(np.uint8)
332
+
333
+ base_np = np.array(base_img).astype(np.uint8)
334
+ stroked_np = _apply_white_brushstrokes(base_np, mask_bin)
335
+
336
+ # Save images
337
+ idx_str = f"{line_idx:05d}"
338
+ try:
339
+ Image.fromarray(base_np).save(os.path.join(output_dir, f"{idx_str}_target.jpg"))
340
+ Image.fromarray(stroked_np).save(os.path.join(output_dir, f"{idx_str}_source.jpg"))
341
+ Image.fromarray((mask_bin * 255).astype(np.uint8)).save(os.path.join(output_dir, f"{idx_str}_mask.png"))
342
+ saved += 1
343
+ except Exception:
344
+ pass
345
+ line_idx += 1
346
+
347
+
348
+ def _parse_test_args():
349
+ import argparse
350
+ parser = argparse.ArgumentParser(description="Test visualization for Kontext complete dataset")
351
+ parser.add_argument("--test_jsonl", type=str, default="/robby/share/Editing/lzc/subject_completion/white_bg_picked/results_picked_filtered.jsonl", help="Path to JSONL to preview")
352
+ parser.add_argument("--output_dir", type=str, default="/robby/share/Editing/lzc/subject_completion/train_test", help="Output directory to save pairs")
353
+ parser.add_argument("--num_samples", type=int, default=50, help="Number of pairs to save")
354
+ return parser.parse_args()
355
+
356
+
357
+ if __name__ == "__main__":
358
+ try:
359
+ args = _parse_test_args()
360
+ _run_test_mode(args.test_jsonl, args.output_dir, args.num_samples)
361
+ except SystemExit:
362
+ # Allow import usage without triggering test mode
363
+ pass
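For reference, an illustrative record for the JSONL consumed by make_train_dataset_subjects above; the field names come from the docstring, while every value here is a placeholder:

# Sketch only: one JSON object like this per line in the training JSONL.
import json

record = {
    "white_image_path": "/abs/path/object_on_white.png",  # placeholder absolute path
    "mask_path": "/abs/path/object_mask.png",             # placeholder absolute path
    "img_width": 1024,
    "img_height": 1024,
    "description": "a red ceramic mug",
    "object": "coffee_mug",  # underscores are replaced by spaces when building the instruction
}
print(json.dumps(record))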
train/src/jsonl_datasets_kontext_edge.py ADDED
@@ -0,0 +1,225 @@
1
+ from PIL import Image
2
+ from datasets import load_dataset
3
+ from torchvision import transforms
4
+ import random
5
+ import torch
6
+ import os
7
+ from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
8
+ import numpy as np
9
+ from src.condition.edge_extraction import (
10
+ CannyDetector, PidiNetDetector, TEDDetector, LineartStandardDetector, HEDdetector,
11
+ AnyLinePreprocessor, LineartDetector, InformativeDetector
12
+ )
13
+
14
+ Image.MAX_IMAGE_PIXELS = None
15
+
16
+ def multiple_16(num: float):
17
+ return int(round(num / 16) * 16)
18
+
19
+ def load_image_safely(image_path, size, root="/mnt/robby-b1/common/datasets/"):
20
+ image_path = os.path.join(root, image_path)
21
+ try:
22
+ image = Image.open(image_path).convert("RGB")
23
+ return image
24
+ except Exception as e:
25
+ print("file error: "+image_path)
26
+ with open("failed_images.txt", "a") as f:
27
+ f.write(f"{image_path}\n")
28
+ return Image.new("RGB", (size, size), (255, 255, 255))
29
+
30
+ def choose_kontext_resolution_from_wh(width: int, height: int):
31
+ aspect_ratio = width / max(1, height)
32
+ _, best_w, best_h = min(
33
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
34
+ )
35
+ return best_w, best_h
36
+
37
+ class EdgeExtractorManager:
38
+ _instance = None
39
+ _initialized = False
40
+
41
+ def __new__(cls):
42
+ if cls._instance is None:
43
+ cls._instance = super(EdgeExtractorManager, cls).__new__(cls)
44
+ return cls._instance
45
+
46
+ def __init__(self):
47
+ if not self._initialized:
48
+ self.edge_extractors = None
49
+ self.device = None
50
+ self._initialized = True
51
+
52
+ def set_device(self, device):
53
+ self.device = device
54
+
55
+ def get_edge_extractors(self, device=None):
56
+ # Force initialization on CPU to avoid triggering CUDA init inside DataLoader worker processes
57
+ current_device = "cpu"
58
+ if device is not None:
59
+ self.set_device(current_device)
60
+
61
+ if self.edge_extractors is None or len(self.edge_extractors) == 0:
62
+ self.edge_extractors = [
63
+ ("canny", CannyDetector()),
64
+ ("pidinet", PidiNetDetector.from_pretrained().to(current_device)),
65
+ ("ted", TEDDetector.from_pretrained().to(current_device)),
66
+ # ("lineart_standard", LineartStandardDetector()),
67
+ ("hed", HEDdetector.from_pretrained().to(current_device)),
68
+ ("anyline", AnyLinePreprocessor.from_pretrained().to(current_device)),
69
+ # ("lineart", LineartDetector.from_pretrained().to(current_device)),
70
+ ("informative", InformativeDetector.from_pretrained().to(current_device)),
71
+ ]
72
+
73
+ return self.edge_extractors
74
+
75
+ edge_extractor_manager = EdgeExtractorManager()
76
+
77
+ def collate_fn(examples):
78
+ if examples[0].get("cond_pixel_values") is not None:
79
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
80
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
81
+ else:
82
+ cond_pixel_values = None
83
+ source_pixel_values = None
84
+
85
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
86
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
87
+ token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
88
+ token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
89
+
90
+ return {
91
+ "cond_pixel_values": cond_pixel_values,
92
+ "source_pixel_values": source_pixel_values,
93
+ "pixel_values": target_pixel_values,
94
+ "text_ids_1": token_ids_clip,
95
+ "text_ids_2": token_ids_t5,
96
+ }
97
+
98
+
99
+ def make_train_dataset_inpaint_mask(args, tokenizers, accelerator=None):
100
+ # Load the CSV dataset: three columns, where column 0 is the relative image path and column 2 is the caption
101
+ if args.train_data_dir is not None:
102
+ dataset = load_dataset('csv', data_files=args.train_data_dir)
103
+
104
+ # Column-name compatibility: use column 0 as the image path and column 2 as the caption
105
+ column_names = dataset["train"].column_names
106
+ image_col = column_names[0]
107
+ caption_col = column_names[2] if len(column_names) >= 3 else column_names[-1]
108
+
109
+ size = args.cond_size
110
+
111
+ # Device setup (used to place some detectors on the matching GPU in distributed runs)
112
+ if accelerator is not None:
113
+ device = accelerator.device
114
+ edge_extractor_manager.set_device(device)
115
+ else:
116
+ device = "cpu"
117
+
118
+ # Transforms
119
+ to_tensor_and_norm = transforms.Compose([
120
+ transforms.ToTensor(),
121
+ transforms.Normalize([0.5], [0.5]),
122
+ ])
123
+
124
+ # Keep consistent with jsonl_datasets_edge.py: Resize -> CenterCrop -> ToTensor -> Normalize
125
+ cond_train_transforms = transforms.Compose([
126
+ transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
127
+ transforms.CenterCrop((size, size)),
128
+ transforms.ToTensor(),
129
+ transforms.Normalize([0.5], [0.5]),
130
+ ])
131
+
132
+ tokenizer_clip = tokenizers[0]
133
+ tokenizer_t5 = tokenizers[1]
134
+
135
+ def tokenize_prompt_clip_t5(examples):
136
+ captions_raw = examples[caption_col]
137
+ captions = []
138
+ for c in captions_raw:
139
+ if isinstance(c, str):
140
+ if random.random() < 0.25:
141
+ captions.append("")
142
+ else:
143
+ captions.append(c)
144
+ else:
145
+ captions.append("")
146
+
147
+ text_inputs_clip = tokenizer_clip(
148
+ captions,
149
+ padding="max_length",
150
+ max_length=77,
151
+ truncation=True,
152
+ return_length=False,
153
+ return_overflowing_tokens=False,
154
+ return_tensors="pt",
155
+ )
156
+ text_input_ids_1 = text_inputs_clip.input_ids
157
+
158
+ text_inputs_t5 = tokenizer_t5(
159
+ captions,
160
+ padding="max_length",
161
+ max_length=128,
162
+ truncation=True,
163
+ return_length=False,
164
+ return_overflowing_tokens=False,
165
+ return_tensors="pt",
166
+ )
167
+ text_input_ids_2 = text_inputs_t5.input_ids
168
+ return text_input_ids_1, text_input_ids_2
169
+
170
+ def preprocess_train(examples):
171
+ batch = {}
172
+
173
+ img_paths = examples[image_col]
174
+
175
+ target_tensors = []
176
+ cond_tensors = []
177
+
178
+ for p in img_paths:
179
+ # Load image by joining with root in load_image_safely
180
+ img = load_image_safely(p, size)
181
+ img = img.convert("RGB")
182
+
183
+ # Resize to Kontext preferred resolution for target
184
+ w, h = img.size
185
+ best_w, best_h = choose_kontext_resolution_from_wh(w, h)
186
+ img_rs = img.resize((best_w, best_h), resample=Image.BILINEAR)
187
+ target_tensor = to_tensor_and_norm(img_rs)
188
+
189
+ # Build edge condition
190
+ extractor_name, extractor = random.choice(edge_extractor_manager.get_edge_extractors())
191
+ img_np = np.array(img)
192
+ if extractor_name == "informative":
193
+ edge = extractor(img_np, style="contour")
194
+ else:
195
+ edge = extractor(img_np)
196
+
197
+ if extractor_name == "ted":
198
+ th = 128
199
+ else:
200
+ th = 32
201
+
202
+ edge_np = np.array(edge) if isinstance(edge, Image.Image) else edge
203
+ if edge_np.ndim == 3:
204
+ edge_np = edge_np[..., 0]
205
+ edge_bin = (edge_np > th).astype(np.float32)
206
+ edge_pil = Image.fromarray((edge_bin * 255).astype(np.uint8))
207
+ edge_tensor = cond_train_transforms(edge_pil)
208
+ edge_tensor = edge_tensor.repeat(3, 1, 1)
209
+
210
+ target_tensors.append(target_tensor)
211
+ cond_tensors.append(edge_tensor)
212
+
213
+ batch["pixel_values"] = target_tensors
214
+ batch["cond_pixel_values"] = cond_tensors
215
+
216
+ batch["token_ids_clip"], batch["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
217
+ return batch
218
+
219
+ if accelerator is not None:
220
+ with accelerator.main_process_first():
221
+ train_dataset = dataset["train"].with_transform(preprocess_train)
222
+ else:
223
+ train_dataset = dataset["train"].with_transform(preprocess_train)
224
+
225
+ return train_dataset
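As a small aside, a sketch of how the EdgeExtractorManager singleton above behaves (illustration only; the first get_edge_extractors() call instantiates the detectors on CPU, which assumes their pretrained weights are reachable in this environment):

# Sketch only: demonstrates the singleton and the lazily built extractor list.
m1 = EdgeExtractorManager()
m2 = EdgeExtractorManager()
assert m1 is m2                                  # __new__ always returns the same instance

extractors = m1.get_edge_extractors()            # built once, on CPU
assert m1.get_edge_extractors() is extractors    # later calls reuse the same list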
train/src/jsonl_datasets_kontext_interactive_lora.py ADDED
@@ -0,0 +1,1332 @@
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+ import torchvision.transforms.functional as TF
4
+ import random
5
+ import torch
6
+ import os
7
+ from datasets import load_dataset
8
+ import numpy as np
9
+ import json
10
+
11
+ Image.MAX_IMAGE_PIXELS = None
12
+
13
+
14
+ def collate_fn(examples):
15
+ if examples[0].get("cond_pixel_values") is not None:
16
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
17
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
18
+ else:
19
+ cond_pixel_values = None
20
+
21
+ if examples[0].get("source_pixel_values") is not None:
22
+ source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
23
+ source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
24
+ else:
25
+ source_pixel_values = None
26
+
27
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
28
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
29
+ token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
30
+ token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
31
+
32
+ mask_values = None
33
+ if examples[0].get("mask_values") is not None:
34
+ mask_values = torch.stack([example["mask_values"] for example in examples])
35
+ mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
36
+
37
+ return {
38
+ "cond_pixel_values": cond_pixel_values,
39
+ "source_pixel_values": source_pixel_values,
40
+ "pixel_values": target_pixel_values,
41
+ "text_ids_1": token_ids_clip,
42
+ "text_ids_2": token_ids_t5,
43
+ "mask_values": mask_values,
44
+ }
45
+
46
+
47
+ def _resolve_jsonl(path_str: str):
48
+ if path_str is None or str(path_str).strip() == "":
49
+ raise ValueError("train_data_jsonl is empty. Please set --train_data_jsonl to a JSON/JSONL file or a folder.")
50
+ if os.path.isdir(path_str):
51
+ files = [
52
+ os.path.join(path_str, f)
53
+ for f in os.listdir(path_str)
54
+ if f.lower().endswith((".jsonl", ".json"))
55
+ ]
56
+ if not files:
57
+ raise ValueError(f"No .json or .jsonl files found under directory: {path_str}")
58
+ return {"train": sorted(files)}
59
+ if not os.path.exists(path_str):
60
+ raise FileNotFoundError(f"train_data_jsonl not found: {path_str}")
61
+ return {"train": [path_str]}
62
+
63
+
64
+ def _tokenize(tokenizers, caption: str):
65
+ tokenizer_clip = tokenizers[0]
66
+ tokenizer_t5 = tokenizers[1]
67
+ text_inputs_clip = tokenizer_clip(
68
+ [caption], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
69
+ )
70
+ text_inputs_t5 = tokenizer_t5(
71
+ [caption], padding="max_length", max_length=128, truncation=True, return_tensors="pt"
72
+ )
73
+ return text_inputs_clip.input_ids[0], text_inputs_t5.input_ids[0]
74
+
75
+
76
+ def _prepend_caption(caption: str) -> str:
77
+ """Prepend instruction and keep only instruction with 20% prob."""
78
+ instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
79
+ if random.random() < 0.2:
80
+ return instruction
81
+ caption = caption or ""
82
+ if caption.strip():
83
+ return f"{instruction} {caption.strip()}"
84
+ return instruction
85
+
86
+
87
+ def _color_augment(pil_img: Image.Image) -> Image.Image:
88
+ brightness = random.uniform(0.8, 1.2)
89
+ contrast = random.uniform(0.8, 1.2)
90
+ saturation = random.uniform(0.8, 1.2)
91
+ hue = random.uniform(-0.05, 0.05)
92
+ img = TF.adjust_brightness(pil_img, brightness)
93
+ img = TF.adjust_contrast(img, contrast)
94
+ img = TF.adjust_saturation(img, saturation)
95
+ img = TF.adjust_hue(img, hue)
96
+ return img
97
+
98
+
99
+ def _dilate_mask(mask_bin: np.ndarray, min_px: int = 5, max_px: int = 100) -> np.ndarray:
100
+ """Grow binary mask by a random radius in [min_px, max_px]. Expects values {0,1}."""
101
+ import cv2
102
+ radius = int(max(min_px, min(max_px, random.randint(min_px, max_px))))
103
+ if radius <= 0:
104
+ return mask_bin.astype(np.uint8)
105
+ ksize = 2 * radius + 1
106
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
107
+ grown = cv2.dilate(mask_bin.astype(np.uint8), kernel, iterations=1)
108
+ return (grown > 0).astype(np.uint8)
109
+
110
+
111
+ def _random_point_inside_mask(mask_bin: np.ndarray) -> tuple:
112
+ ys, xs = np.where(mask_bin > 0)
113
+ if len(xs) == 0:
114
+ h, w = mask_bin.shape
115
+ return w // 2, h // 2
116
+ idx = random.randrange(len(xs))
117
+ return int(xs[idx]), int(ys[idx])
118
+
119
+
120
+ def _bbox_containing_mask(mask_bin: np.ndarray, img_w: int, img_h: int) -> tuple:
121
+ ys, xs = np.where(mask_bin > 0)
122
+ if len(xs) == 0:
123
+ return 0, 0, img_w - 1, img_h - 1
124
+ x1, x2 = int(xs.min()), int(xs.max())
125
+ y1, y2 = int(ys.min()), int(ys.max())
126
+ # Random padding
127
+ max_pad = int(0.25 * min(img_w, img_h))
128
+ pad_x1 = random.randint(0, max_pad)
129
+ pad_x2 = random.randint(0, max_pad)
130
+ pad_y1 = random.randint(0, max_pad)
131
+ pad_y2 = random.randint(0, max_pad)
132
+ x1 = max(0, x1 - pad_x1)
133
+ y1 = max(0, y1 - pad_y1)
134
+ x2 = min(img_w - 1, x2 + pad_x2)
135
+ y2 = min(img_h - 1, y2 + pad_y2)
136
+ return x1, y1, x2, y2
137
+
138
+
139
+ def _constrained_random_mask(mask_bin: np.ndarray, image_h: int, image_w: int, aug_prob: float = 0.7) -> np.ndarray:
140
+ """Generate a random mask whose boxes contain or start inside the object mask, and whose brush strokes start inside it.
141
+ Returns binary 0/1 array of shape (H,W).
142
+ """
143
+ import cv2
144
+ if random.random() >= aug_prob:
145
+ return np.zeros((image_h, image_w), dtype=np.uint8)
146
+
147
+ # Scale similar to reference
148
+ ref_size = 1024
149
+ scale_factor = max(1.0, min(image_h, image_w) / float(ref_size))
150
+
151
+ out = np.zeros((image_h, image_w), dtype=np.uint8)
152
+
153
+ # Choose exactly one augmentation: bbox OR stroke
154
+ if random.random() < 0.2:
155
+ # BBox augmentation: draw N boxes (randomized), first box often contains mask
156
+ num_boxes = random.randint(1, 6)
157
+ for b in range(num_boxes):
158
+ if b == 0 and random.random() < 0.5:
159
+ x1, y1, x2, y2 = _bbox_containing_mask(mask_bin, image_w, image_h)
160
+ else:
161
+ sx, sy = _random_point_inside_mask(mask_bin)
162
+ max_w = int(500 * scale_factor)
163
+ min_w = int(100 * scale_factor)
164
+ bw = random.randint(max(1, min_w), max(2, max_w))
165
+ bh = random.randint(max(1, min_w), max(2, max_w))
166
+ x1 = max(0, sx - random.randint(0, bw))
167
+ y1 = max(0, sy - random.randint(0, bh))
168
+ x2 = min(image_w - 1, x1 + bw)
169
+ y2 = min(image_h - 1, y1 + bh)
170
+ out[y1:y2 + 1, x1:x2 + 1] = 1
171
+ else:
172
+ # Stroke augmentation: draw N strokes starting inside mask
173
+ num_strokes = random.randint(1, 6)
174
+ for _ in range(num_strokes):
175
+ num_points = random.randint(10, 30)
176
+ stroke_width = random.randint(max(1, int(100 * scale_factor)), max(2, int(400 * scale_factor)))
177
+ max_offset = max(1, int(100 * scale_factor))
178
+ start_x, start_y = _random_point_inside_mask(mask_bin)
179
+ px, py = start_x, start_y
180
+ for _ in range(num_points):
181
+ dx = random.randint(-max_offset, max_offset)
182
+ dy = random.randint(-max_offset, max_offset)
183
+ nx = int(np.clip(px + dx, 0, image_w - 1))
184
+ ny = int(np.clip(py + dy, 0, image_h - 1))
185
+ cv2.line(out, (px, py), (nx, ny), 1, stroke_width)
186
+ px, py = nx, ny
187
+
188
+ return (out > 0).astype(np.uint8)
189
+
190
+
191
+ def make_placement_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
192
+ """
193
+ Dataset for JSONL with fields:
194
+ - generated_image_path: relative to base_dir (target image with object)
195
+ - mask_path: relative to base_dir (mask of object)
196
+ - generated_width, generated_height: image dimensions
197
+ - final_prompt: caption
198
+ - relight_images: list of {mode, path} for relighted versions
199
+
200
+ source image construction:
201
+ - background is target_image with a hole punched by grown mask
202
+ - foreground is randomly selected from relight_images with weights
203
+ - includes perspective transformation (moved from interactive dataset)
204
+
205
+ Args:
206
+ base_dir: Base directory for resolving relative paths. If None, uses args.placement_base_dir.
207
+ """
208
+ if base_dir is None:
209
+ base_dir = getattr(args, "placement_base_dir")
210
+
211
+ data_files = _resolve_jsonl(getattr(args, "placement_data_jsonl", None))
212
+ file_paths = data_files.get("train", [])
213
+ records = []
214
+ for p in file_paths:
215
+ with open(p, "r", encoding="utf-8") as f:
216
+ for line in f:
217
+ line = line.strip()
218
+ if not line:
219
+ continue
220
+ try:
221
+ obj = json.loads(line)
222
+ except Exception:
223
+ try:
224
+ obj = json.loads(line.rstrip(","))
225
+ except Exception:
226
+ continue
227
+ # Keep only fields we need
228
+ pruned = {
229
+ "generated_image_path": obj.get("generated_image_path"),
230
+ "mask_path": obj.get("mask_path"),
231
+ "generated_width": obj.get("generated_width"),
232
+ "generated_height": obj.get("generated_height"),
233
+ "final_prompt": obj.get("final_prompt"),
234
+ "relight_images": obj.get("relight_images"),
235
+ }
236
+ records.append(pruned)
237
+
238
+ size = int(getattr(args, "cond_size", 512))
239
+
240
+ to_tensor_and_norm = transforms.Compose([
241
+ transforms.ToTensor(),
242
+ transforms.Normalize([0.5], [0.5]),
243
+ ])
244
+
245
+ class PlacementDataset(torch.utils.data.Dataset):
246
+ def __init__(self, hf_ds, base_dir):
247
+ self.ds = hf_ds
248
+ self.base_dir = base_dir
249
+ def __len__(self):
250
+ # One sample per record
251
+ return len(self.ds)
252
+ def __getitem__(self, idx):
253
+ rec = self.ds[idx % len(self.ds)]
254
+
255
+ t_rel = rec.get("generated_image_path", "")
256
+ m_rel = rec.get("mask_path", "")
257
+
258
+ # Both are relative paths
259
+ t_p = os.path.join(self.base_dir, t_rel)
260
+ m_p = os.path.join(self.base_dir, m_rel)
261
+
262
+ import cv2
263
+ mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
264
+ if mask_loaded is None:
265
+ raise ValueError(f"Failed to read mask: {m_p}")
266
+
267
+ tgt_img = Image.open(t_p).convert("RGB")
268
+
269
+ fw = int(rec.get("generated_width", tgt_img.width))
270
+ fh = int(rec.get("generated_height", tgt_img.height))
271
+ tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
272
+ mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
273
+
274
+ target_tensor = to_tensor_and_norm(tgt_img)
275
+
276
+ # Binary mask at final_size
277
+ mask_np = np.array(mask_img)
278
+ mask_bin = (mask_np > 127).astype(np.uint8)
279
+
280
+ # 1) Grow mask by a random margin of 50-200 pixels
281
+ grown_mask = _dilate_mask(mask_bin, 50, 200)
282
+
283
+ # 2) Optional random augmentation mask constrained by mask
284
+ rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
285
+
286
+ # 3) Final union mask
287
+ union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
288
+ tgt_np = np.array(tgt_img)
289
+
290
+ # Helper: choose relighted image from relight_images with weights
291
+ def _choose_relight_image(rec, width, height):
292
+ relight_list = rec.get("relight_images") or []
293
+ # Build map mode -> path
294
+ mode_to_path = {}
295
+ for it in relight_list:
296
+ try:
297
+ mode = str(it.get("mode", "")).strip().lower()
298
+ path = it.get("path")
299
+ except Exception:
300
+ continue
301
+ if not mode or not path:
302
+ continue
303
+ mode_to_path[mode] = path
304
+
305
+ weighted_order = [
306
+ ("grayscale", 0.5),
307
+ ("low", 0.3),
308
+ ("high", 0.2),
309
+ ]
310
+
311
+ # Filter to available
312
+ available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
313
+ chosen_path = None
314
+ if available:
315
+ rnd = random.random()
316
+ cum = 0.0
317
+ total_w = sum(w for _, w in available)
318
+ for m, w in available:
319
+ cum += w / total_w
320
+ if rnd <= cum:
321
+ chosen_path = mode_to_path.get(m)
322
+ break
323
+ if chosen_path is None:
324
+ chosen_path = mode_to_path.get(available[-1][0])
325
+ else:
326
+ # Fallback to any provided path
327
+ if mode_to_path:
328
+ chosen_path = next(iter(mode_to_path.values()))
329
+
330
+ # Open chosen image
331
+ if chosen_path is not None:
332
+ try:
333
+ # Paths are relative to base_dir
334
+ open_path = os.path.join(self.base_dir, chosen_path)
335
+ img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
336
+ return img
337
+ except Exception:
338
+ pass
339
+
340
+ # Fallback: return target image
341
+ return Image.open(t_p).convert("RGB").resize((width, height), resample=Image.BILINEAR)
342
+
343
+ # Choose base image with probabilities:
344
+ # 20%: original target, 20%: color augment(target), 60%: relight augment
345
+ rsel = random.random()
346
+ if rsel < 0.2:
347
+ base_img = tgt_img
348
+ elif rsel < 0.4:
349
+ base_img = _color_augment(tgt_img)
350
+ else:
351
+ base_img = _choose_relight_image(rec, fw, fh)
352
+ base_np = np.array(base_img)
353
+ fore_np = base_np.copy()
354
+
355
+ # Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
356
+ perspective_applied = False
357
+ roi_update = None
358
+ paste_mask_bool = mask_bin.astype(bool)
359
+ if random.random() < 0.5:
360
+ try:
361
+ import cv2
362
+ ys, xs = np.where(mask_bin > 0)
363
+ if len(xs) > 0 and len(ys) > 0:
364
+ x1, x2 = int(xs.min()), int(xs.max())
365
+ y1, y2 = int(ys.min()), int(ys.max())
366
+ if x2 > x1 and y2 > y1:
367
+ roi = base_np[y1:y2 + 1, x1:x2 + 1]
368
+ roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
369
+ bh, bw = roi.shape[:2]
370
+ # Random perturbation relative to ROI size
371
+ max_ratio = random.uniform(0.1, 0.3)
372
+ dx = bw * max_ratio
373
+ dy = bh * max_ratio
374
+ src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
375
+ dst = np.float32([
376
+ [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
377
+ [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
378
+ [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
379
+ [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
380
+ ])
381
+ M = cv2.getPerspectiveTransform(src, dst)
382
+ warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
383
+ warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
384
+ # Build a fresh foreground canvas
385
+ fore_np = np.zeros_like(base_np)
386
+ h_warp, w_warp = warped_mask_roi.shape
387
+ y2c = y1 + h_warp
388
+ x2c = x1 + w_warp
389
+ fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
390
+ paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
391
+ paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
392
+ roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
393
+ perspective_applied = True
394
+ except Exception:
395
+ perspective_applied = False
396
+ paste_mask_bool = mask_bin.astype(bool)
397
+ fore_np = base_np
398
+
399
+ # Optional: simulate resolution artifacts
400
+ if random.random() < 0.7:
401
+ ys, xs = np.where(paste_mask_bool)
402
+ if len(xs) > 0 and len(ys) > 0:
403
+ x1, x2 = int(xs.min()), int(xs.max())
404
+ y1, y2 = int(ys.min()), int(ys.max())
405
+ if x2 > x1 and y2 > y1:
406
+ crop = fore_np[y1:y2 + 1, x1:x2 + 1]
407
+ ch, cw = crop.shape[:2]
408
+ scale = random.uniform(0.15, 0.9)
409
+ dw = max(1, int(cw * scale))
410
+ dh = max(1, int(ch * scale))
411
+ try:
412
+ small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
413
+ back = small.resize((cw, ch), Image.BICUBIC)
414
+ crop_blurred = np.array(back).astype(np.uint8)
415
+ fore_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
416
+ except Exception:
417
+ pass
418
+
419
+ # Build masked target and compose
420
+ union_mask_for_target = union_mask.copy()
421
+ if roi_update is not None:
422
+ rx, ry, rh, rw, warped_mask_roi = roi_update
423
+ um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
424
+ union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
425
+ masked_t_np = tgt_np.copy()
426
+ masked_t_np[union_mask_for_target.astype(bool)] = 255
427
+ composed_np = masked_t_np.copy()
428
+ m_fore = paste_mask_bool
429
+ composed_np[m_fore] = fore_np[m_fore]
430
+
431
+ # Build tensors
432
+ source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
433
+ mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
434
+
435
+ # Caption: prepend instruction
436
+ cap_orig = rec.get("final_prompt", "") or ""
437
+ # Handle list format in final_prompt
438
+ if isinstance(cap_orig, list) and len(cap_orig) > 0:
439
+ cap_orig = cap_orig[0] if isinstance(cap_orig[0], str) else str(cap_orig[0])
440
+ cap = _prepend_caption(cap_orig)
441
+ if perspective_applied:
442
+ cap = f"{cap} Fix the perspective if necessary."
443
+ ids1, ids2 = _tokenize(tokenizers, cap)
444
+
445
+ return {
446
+ "source_pixel_values": source_tensor,
447
+ "pixel_values": target_tensor,
448
+ "token_ids_clip": ids1,
449
+ "token_ids_t5": ids2,
450
+ "mask_values": mask_tensor,
451
+ }
452
+
453
+ return PlacementDataset(records, base_dir)
454
+
455
+
456
+ def make_interactive_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
457
+ """
458
+ Dataset for JSONL with fields:
459
+ - input_path: relative to base_dir (target image)
460
+ - output_path: absolute path to image with foreground
461
+ - mask_after_completion: absolute path to mask
462
+ - img_width, img_height: resize dimensions
463
+ - prompt: caption
464
+
465
+ source image construction:
466
+ - background is target_image with a hole punched by grown `mask_after_completion`
467
+ - foreground is from `output_path` image, pasted using original `mask_after_completion`
468
+ - 50% chance to color augment the foreground source
469
+ - NO perspective transform (moved to placement dataset)
470
+
471
+ Args:
472
+ base_dir: Base directory for resolving relative paths. If None, uses args.interactive_base_dir.
473
+ """
474
+ if base_dir is None:
475
+ base_dir = getattr(args, "interactive_base_dir")
476
+
477
+ data_files = _resolve_jsonl(getattr(args, "train_data_jsonl", None))
478
+ file_paths = data_files.get("train", [])
479
+ records = []
480
+ for p in file_paths:
481
+ with open(p, "r", encoding="utf-8") as f:
482
+ for line in f:
483
+ line = line.strip()
484
+ if not line:
485
+ continue
486
+ try:
487
+ obj = json.loads(line)
488
+ except Exception:
489
+ # Best-effort: strip any trailing commas and retry
490
+ try:
491
+ obj = json.loads(line.rstrip(","))
492
+ except Exception:
493
+ continue
494
+ # Keep only fields we actually need to avoid schema issues
495
+ pruned = {
496
+ "input_path": obj.get("input_path"),
497
+ "output_path": obj.get("output_path"),
498
+ "mask_after_completion": obj.get("mask_after_completion"),
499
+ "img_width": obj.get("img_width"),
500
+ "img_height": obj.get("img_height"),
501
+ "prompt": obj.get("prompt"),
502
+ # New optional fields
503
+ "generated_images": obj.get("generated_images"),
504
+ "positive_prompt_used": obj.get("positive_prompt_used"),
505
+ "negative_caption_used": obj.get("negative_caption_used"),
506
+ }
507
+ records.append(pruned)
508
+
509
+ size = int(getattr(args, "cond_size", 512))
510
+
511
+ to_tensor_and_norm = transforms.Compose([
512
+ transforms.ToTensor(),
513
+ transforms.Normalize([0.5], [0.5]),
514
+ ])
515
+
516
+ class SubjectsDataset(torch.utils.data.Dataset):
517
+ def __init__(self, hf_ds, base_dir):
518
+ self.ds = hf_ds
519
+ self.base_dir = base_dir
520
+ def __len__(self):
521
+ # One sample per record
522
+ return len(self.ds)
523
+ def __getitem__(self, idx):
524
+ rec = self.ds[idx % len(self.ds)]
525
+
526
+ t_rel = rec.get("input_path", "")
527
+ foreground_p = rec.get("output_path", "")
528
+ m_abs = rec.get("mask_after_completion", "")
529
+
530
+ if not os.path.isabs(m_abs):
531
+ raise ValueError("mask_after_completion must be absolute")
532
+ if not os.path.isabs(foreground_p):
533
+ raise ValueError("output_path must be absolute")
534
+
535
+ t_p = os.path.join(self.base_dir, t_rel)
536
+ m_p = m_abs
537
+
538
+ import cv2
539
+ mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
540
+ if mask_loaded is None:
541
+ raise ValueError(f"Failed to read mask: {m_p}")
542
+
543
+ tgt_img = Image.open(t_p).convert("RGB")
544
+ foreground_source_img = Image.open(foreground_p).convert("RGB")
545
+
546
+ fw = int(rec.get("img_width", tgt_img.width))
547
+ fh = int(rec.get("img_height", tgt_img.height))
548
+ tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
549
+ foreground_source_img = foreground_source_img.resize((fw, fh), resample=Image.BILINEAR)
550
+ mask_img = Image.fromarray(mask_loaded.astype(np.uint8)).convert("L").resize((fw, fh), Image.NEAREST)
551
+
552
+ # Convert the resized target image to a normalized tensor
553
+ target_tensor = to_tensor_and_norm(tgt_img)
554
+
555
+ # Binary mask at final_size
556
+ mask_np = np.array(mask_img)
557
+ mask_bin = (mask_np > 127).astype(np.uint8)
558
+
559
+ # 1) Grow mask_after_completion by a random 50-200 pixels
560
+ grown_mask = _dilate_mask(mask_bin, 50, 200)
561
+
562
+ # 2) Optional random augmentation mask constrained by m_p
563
+ rand_mask = _constrained_random_mask(mask_bin, fh, fw, aug_prob=0.7)
564
+
565
+ # 3) Final union mask
566
+ union_mask = np.clip(grown_mask | rand_mask, 0, 1).astype(np.uint8)
567
+ tgt_np = np.array(tgt_img)
568
+
569
+ # Helper: choose relighted image from generated_images with weights
570
+ def _choose_relight_image(rec, default_img, width, height):
571
+ gen_list = rec.get("generated_images") or []
572
+ # Build map mode -> path
573
+ mode_to_path = {}
574
+ for it in gen_list:
575
+ try:
576
+ mode = str(it.get("mode", "")).strip().lower()
577
+ path = it.get("path")
578
+ except Exception:
579
+ continue
580
+ if not mode or not path:
581
+ continue
582
+ mode_to_path[mode] = path
583
+
584
+ # Weighted selection among available modes
585
+ weighted_order = [
586
+ ("grayscale", 0.5),
587
+ ("low", 0.3),
588
+ ("high", 0.2),
589
+ ]
590
+
591
+ # Filter to available
592
+ available = [(m, w) for (m, w) in weighted_order if m in mode_to_path]
593
+ chosen_path = None
594
+ if available:
595
+ rnd = random.random()
596
+ cum = 0.0
597
+ total_w = sum(w for _, w in available)
598
+ for m, w in available:
599
+ cum += w / total_w
600
+ if rnd <= cum:
601
+ chosen_path = mode_to_path.get(m)
602
+ break
603
+ if chosen_path is None:
604
+ chosen_path = mode_to_path.get(available[-1][0])
605
+ else:
606
+ # Fallback to any provided path
607
+ if mode_to_path:
608
+ chosen_path = next(iter(mode_to_path.values()))
609
+
610
+ # Open chosen image
611
+ if chosen_path is not None:
612
+ try:
613
+ open_path = chosen_path
614
+ # generated paths are typically absolute; if not, use as-is
615
+ img = Image.open(open_path).convert("RGB").resize((width, height), resample=Image.BILINEAR)
616
+ return img
617
+ except Exception:
618
+ pass
619
+
620
+ return default_img
621
+
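The cumulative-weight loop in `_choose_relight_image` is a weighted draw over whichever relight modes happen to exist for a record. A rough, self-contained restatement (illustrative only; the mode availability below is made up):

```python
import random

# Weighted draw over available relight modes, mirroring _choose_relight_image:
# grayscale=0.5, low=0.3, high=0.2, renormalized over what is actually present.
mode_to_path = {"grayscale": "gray.png", "low": "low.png"}   # hypothetical availability
weighted_order = [("grayscale", 0.5), ("low", 0.3), ("high", 0.2)]
available = [(m, w) for m, w in weighted_order if m in mode_to_path]
if available:
    modes, weights = zip(*available)
    chosen_path = mode_to_path[random.choices(modes, weights=weights, k=1)[0]]
```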
622
+ # 5) Choose base image with probabilities:
623
+ # 20%: original, 20%: color augment(original), 60%: relight augment
624
+ rsel = random.random()
625
+ if rsel < 0.2:
626
+ base_img = foreground_source_img
627
+ elif rsel < 0.4:
628
+ base_img = _color_augment(foreground_source_img)
629
+ else:
630
+ base_img = _choose_relight_image(rec, foreground_source_img, fw, fh)
631
+ base_np = np.array(base_img)
632
+
633
+ # 5.1) Random perspective augmentation (50%): apply to foreground ROI (mask bbox) and its mask only
634
+ perspective_applied = False
635
+ roi_update = None
636
+ paste_mask_bool = mask_bin.astype(bool)
637
+ if random.random() < 0.5:
638
+ try:
639
+ import cv2
640
+ ys, xs = np.where(mask_bin > 0)
641
+ if len(xs) > 0 and len(ys) > 0:
642
+ x1, x2 = int(xs.min()), int(xs.max())
643
+ y1, y2 = int(ys.min()), int(ys.max())
644
+ if x2 > x1 and y2 > y1:
645
+ roi = base_np[y1:y2 + 1, x1:x2 + 1]
646
+ roi_mask = mask_bin[y1:y2 + 1, x1:x2 + 1]
647
+ bh, bw = roi.shape[:2]
648
+ # Random perturbation relative to ROI size
649
+ max_ratio = random.uniform(0.1, 0.3)
650
+ dx = bw * max_ratio
651
+ dy = bh * max_ratio
652
+ src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
653
+ dst = np.float32([
654
+ [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
655
+ [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(random.uniform(-dy, dy), 0, bh - 1)],
656
+ [np.clip(bw - 1 + random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
657
+ [np.clip(random.uniform(-dx, dx), 0, bw - 1), np.clip(bh - 1 + random.uniform(-dy, dy), 0, bh - 1)],
658
+ ])
659
+ M = cv2.getPerspectiveTransform(src, dst)
660
+ warped_roi = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT101)
661
+ warped_mask_roi = cv2.warpPerspective((roi_mask.astype(np.uint8) * 255), M, (bw, bh), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0) > 127
662
+ # Build a fresh foreground canvas
663
+ fore_np = np.zeros_like(base_np)
664
+ h_warp, w_warp = warped_mask_roi.shape
665
+ y2c = y1 + h_warp
666
+ x2c = x1 + w_warp
667
+ fore_np[y1:y2c, x1:x2c][warped_mask_roi] = warped_roi[warped_mask_roi]
668
+ paste_mask_bool = np.zeros_like(mask_bin, dtype=bool)
669
+ paste_mask_bool[y1:y2c, x1:x2c] = warped_mask_roi
670
+ roi_update = (x1, y1, h_warp, w_warp, warped_mask_roi)
671
+ perspective_applied = True
672
+ base_np = fore_np
673
+ except Exception:
674
+ perspective_applied = False
675
+ paste_mask_bool = mask_bin.astype(bool)
676
+
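A minimal standalone sketch of the ROI perspective jitter used above, on a dummy crop; it assumes OpenCV and NumPy are available and keeps the same interpolation/border choices:

```python
import random
import numpy as np
import cv2

# Jitter the four corners of an 80x100 crop by up to 10-30% of its size and warp it,
# as done for the foreground ROI above (INTER_LINEAR for pixels, NEAREST for masks).
roi = np.random.randint(0, 255, (80, 100, 3), dtype=np.uint8)
bh, bw = roi.shape[:2]
ratio = random.uniform(0.1, 0.3)
dx, dy = bw * ratio, bh * ratio
src = np.float32([[0, 0], [bw - 1, 0], [bw - 1, bh - 1], [0, bh - 1]])
dst = np.float32([[np.clip(x + random.uniform(-dx, dx), 0, bw - 1),
                   np.clip(y + random.uniform(-dy, dy), 0, bh - 1)] for x, y in src])
M = cv2.getPerspectiveTransform(src, dst)
warped = cv2.warpPerspective(roi, M, (bw, bh), flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_REFLECT101)
```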
677
+ # Optional: simulate cut-out foregrounds coming from different resolutions by
678
+ # downscaling the masked foreground region and upscaling back to original size.
679
+ # This introduces realistic blur/aliasing seen in real inpaint workflows.
680
+ if random.random() < 0.7:
681
+ ys, xs = np.where(mask_bin > 0)
682
+ if len(xs) > 0 and len(ys) > 0:
683
+ x1, x2 = int(xs.min()), int(xs.max())
684
+ y1, y2 = int(ys.min()), int(ys.max())
685
+ # Ensure valid box
686
+ if x2 > x1 and y2 > y1:
687
+ crop = base_np[y1:y2 + 1, x1:x2 + 1]
688
+ ch, cw = crop.shape[:2]
689
+ scale = random.uniform(0.2, 0.9)
690
+ dw = max(1, int(cw * scale))
691
+ dh = max(1, int(ch * scale))
692
+ try:
693
+ small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
694
+ back = small.resize((cw, ch), Image.BICUBIC)
695
+ crop_blurred = np.array(back).astype(np.uint8)
696
+ base_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
697
+ except Exception:
698
+ # Fallback: skip if resize fails
699
+ pass
700
+
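The down/up-scale trick above in isolation (PIL only; crop size and scale range are arbitrary):

```python
import random
import numpy as np
from PIL import Image

# Down-scale a crop and resize it back to simulate a foreground that came from a
# lower-resolution source, which is what the masked-region branch above does.
crop = np.random.randint(0, 255, (120, 160, 3), dtype=np.uint8)
scale = random.uniform(0.2, 0.9)
dw, dh = max(1, int(160 * scale)), max(1, int(120 * scale))
small = Image.fromarray(crop).resize((dw, dh), Image.BICUBIC)
degraded = np.array(small.resize((160, 120), Image.BICUBIC))   # same shape, softer detail
```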
701
+ # 6) Build masked target using (possibly) updated union mask; then paste
702
+ union_mask_for_target = union_mask.copy()
703
+ if roi_update is not None:
704
+ rx, ry, rh, rw, warped_mask_roi = roi_update
705
+ # Ensure union covers the warped foreground area inside ROI using warped shape
706
+ um_roi = union_mask_for_target[ry:ry + rh, rx:rx + rw]
707
+ union_mask_for_target[ry:ry + rh, rx:rx + rw] = np.clip(um_roi | warped_mask_roi.astype(np.uint8), 0, 1)
708
+ masked_t_np = tgt_np.copy()
709
+ masked_t_np[union_mask_for_target.astype(bool)] = 255
710
+ composed_np = masked_t_np.copy()
711
+ m_fore = paste_mask_bool
712
+ composed_np[m_fore] = base_np[m_fore]
713
+
714
+ # 7) Build tensors
715
+ source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
716
+ mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
717
+
718
+ # 8) Caption: prepend instruction, 20% keep only instruction
719
+ cap_orig = rec.get("prompt", "") or ""
720
+ cap = _prepend_caption(cap_orig)
721
+ if perspective_applied:
722
+ cap = f"{cap} Fix the perspective if necessary."
723
+ ids1, ids2 = _tokenize(tokenizers, cap)
724
+
725
+ return {
726
+ "source_pixel_values": source_tensor,
727
+ "pixel_values": target_tensor,
728
+ "token_ids_clip": ids1,
729
+ "token_ids_t5": ids2,
730
+ "mask_values": mask_tensor,
731
+ }
732
+
733
+ return SubjectsDataset(records, base_dir)
734
+
735
+
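A hypothetical example of one JSONL record consumed by `make_interactive_dataset_subjects`, plus the assumed wiring; all paths, sizes, and the prompt below are illustrative, not taken from the real dataset:

```python
example_record = {
    "input_path": "images/000123.jpg",                      # relative to interactive_base_dir (target)
    "output_path": "/abs/path/foregrounds/000123.png",      # absolute, image containing the foreground
    "mask_after_completion": "/abs/path/masks/000123.png",  # absolute, binary foreground mask
    "img_width": 1024,
    "img_height": 768,
    "prompt": "A person holding a red umbrella.",
    "generated_images": [
        {"mode": "grayscale", "path": "/abs/path/relight/000123_gray.png"},
        {"mode": "low", "path": "/abs/path/relight/000123_low.png"},
    ],
}

# Assumed usage, with `args` carrying train_data_jsonl / interactive_base_dir / cond_size
# and `tokenizers` being the (CLIP, T5) pair used elsewhere in this repo:
# dataset = make_interactive_dataset_subjects(args, tokenizers)
# sample = dataset[0]   # source_pixel_values, pixel_values, token_ids_clip/t5, mask_values
```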
736
+ def make_pexels_dataset_subjects(args, tokenizers, accelerator=None, base_dir=None):
737
+ """
738
+ Dataset for JSONL with fields:
739
+ - input_path: relative to base_dir (target image)
740
+ - output_path: relative to relight_base_dir (relighted image)
741
+ - final_size: {width, height} resize applied
742
+ - caption: text caption
743
+
744
+ Modified to use segmentation maps instead of raw_mask_path.
745
+ Randomly selects 2-5 segments from segmentation map, applies augmentation to each, and takes union.
746
+ This simulates multiple foreground objects being placed like a puzzle.
747
+
748
+ Each segment independently uses: 20% original, 80% color-augmented (the relighted variant is currently disabled in code). An example record is sketched after this function.
749
+
750
+ Args:
751
+ base_dir: Base directory for resolving relative paths. If None, uses args.pexels_base_dir.
752
+ """
753
+ if base_dir is None:
754
+ base_dir = getattr(args, "pexels_base_dir", "/mnt/robby-b1/common/datasets")
755
+
756
+ relight_base_dir = getattr(args, "pexels_relight_base_dir", "/robby/share/Editing/lzc/data/relight_outputs")
757
+ seg_base_dir = getattr(args, "seg_base_dir", "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182")
758
+
759
+ data_files = _resolve_jsonl(getattr(args, "pexels_data_jsonl", None))
760
+ file_paths = data_files.get("train", [])
761
+ records = []
762
+ for p in file_paths:
763
+ with open(p, "r", encoding="utf-8") as f:
764
+ for line in f:
765
+ line = line.strip()
766
+ if not line:
767
+ continue
768
+ try:
769
+ obj = json.loads(line)
770
+ except Exception:
771
+ try:
772
+ obj = json.loads(line.rstrip(","))
773
+ except Exception:
774
+ continue
775
+ pruned = {
776
+ "input_path": obj.get("input_path"),
777
+ "output_path": obj.get("output_path"),
778
+ "final_size": obj.get("final_size"),
779
+ "caption": obj.get("caption"),
780
+ }
781
+ records.append(pruned)
782
+
783
+ to_tensor_and_norm = transforms.Compose([
784
+ transforms.ToTensor(),
785
+ transforms.Normalize([0.5], [0.5]),
786
+ ])
787
+
788
+ class PexelsDataset(torch.utils.data.Dataset):
789
+ def __init__(self, hf_ds, base_dir, relight_base_dir, seg_base_dir):
790
+ self.ds = hf_ds
791
+ self.base_dir = base_dir
792
+ self.relight_base_dir = relight_base_dir
793
+ self.seg_base_dir = seg_base_dir
794
+
795
+ def __len__(self):
796
+ return len(self.ds)
797
+
798
+ def _extract_hash_from_filename(self, filename: str) -> str:
799
+ """Extract hash from input filename for segmentation map lookup."""
800
+ stem = os.path.splitext(os.path.basename(filename))[0]
801
+ if '_' in stem:
802
+ parts = stem.split('_')
803
+ return parts[-1]
804
+ return stem
805
+
806
+ def _build_segmap_path(self, input_filename: str) -> str:
807
+ """Build path to segmentation map from input filename."""
808
+ hash_id = self._extract_hash_from_filename(input_filename)
809
+ return os.path.join(self.seg_base_dir, f"{hash_id}_map.png")
810
+
811
+ def _load_segmap_uint32(self, seg_path: str):
812
+ """Load segmentation map as uint32 array."""
813
+ import cv2
814
+ try:
815
+ with Image.open(seg_path) as im:
816
+ if im.mode == 'P':
817
+ seg = np.array(im)
818
+ elif im.mode in ('I;16', 'I', 'L'):
819
+ seg = np.array(im)
820
+ else:
821
+ seg = np.array(im.convert('L'))
822
+ except Exception:
823
+ return None
824
+
825
+ if seg.ndim == 3:
826
+ seg = cv2.cvtColor(seg, cv2.COLOR_BGR2GRAY)
827
+ return seg.astype(np.uint32)
828
+
829
+ def _extract_multiple_segments(
830
+ self,
831
+ image_h: int,
832
+ image_w: int,
833
+ seg_path: str,
834
+ min_area_ratio: float = 0.02,
835
+ max_area_ratio: float = 0.4,
836
+ ):
837
+ """Extract 2-5 individual segment masks from segmentation map."""
838
+ import cv2
839
+ seg = self._load_segmap_uint32(seg_path)
840
+ if seg is None:
841
+ return []
842
+
843
+ if seg.shape != (image_h, image_w):
844
+ seg = cv2.resize(seg.astype(np.uint16), (image_w, image_h), interpolation=cv2.INTER_NEAREST).astype(np.uint32)
845
+
846
+ labels, counts = np.unique(seg, return_counts=True)
847
+ if labels.size == 0:
848
+ return []
849
+
850
+ # Exclude background label 0
851
+ bg_mask = labels == 0
852
+ labels = labels[~bg_mask]
853
+ counts = counts[~bg_mask]
854
+ if labels.size == 0:
855
+ return []
856
+
857
+ area = image_h * image_w
858
+ min_px = int(round(min_area_ratio * area))
859
+ max_px = int(round(max_area_ratio * area))
860
+ keep = (counts >= min_px) & (counts <= max_px)
861
+ cand_labels = labels[keep]
862
+ if cand_labels.size == 0:
863
+ return []
864
+
865
+ # Randomly select 2-5 labels (fewer if not enough candidates survive the area filter)
866
+ max_sel = min(5, cand_labels.size)
867
+ min_sel = min(2, cand_labels.size)
868
+ num_to_select = random.randint(min_sel, max_sel)
869
+ chosen = np.random.choice(cand_labels, size=num_to_select, replace=False)
870
+
871
+ # Create individual masks for each chosen label
872
+ individual_masks = []
873
+ for lab in chosen:
874
+ binm = (seg == int(lab)).astype(np.uint8)
875
+ # Apply opening operation to clean up mask
876
+ k = max(3, int(round(max(image_h, image_w) * 0.01)))
877
+ if k % 2 == 0:
878
+ k += 1
879
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
880
+ eroded = cv2.erode(binm, kernel, iterations=1)
881
+ opened = cv2.dilate(eroded, kernel, iterations=1)
882
+ individual_masks.append(opened)
883
+
884
+ return individual_masks
885
+
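The area filter in `_extract_multiple_segments` sketched on a toy segmentation map (labels and sizes are made up):

```python
import numpy as np

# 16x16 toy map: label 0 is background, label 1 covers 25% of the image and survives
# the 2%-40% area filter, label 2 is a single pixel (~0.4%) and is dropped.
seg = np.zeros((16, 16), dtype=np.uint32)
seg[:8, :8] = 1
seg[15, 15] = 2
labels, counts = np.unique(seg, return_counts=True)
fg = labels != 0
labels, counts = labels[fg], counts[fg]
area = seg.size
keep = (counts >= int(round(0.02 * area))) & (counts <= int(round(0.4 * area)))
candidate_labels = labels[keep]   # -> array([1], dtype=uint32)
```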
886
+ def __getitem__(self, idx):
887
+ rec = self.ds[idx % len(self.ds)]
888
+
889
+ t_rel = rec.get("input_path", "")
890
+ r_rel = rec.get("output_path", "")
891
+
892
+ t_p = os.path.join(self.base_dir, t_rel)
893
+ relight_p = os.path.join(self.relight_base_dir, r_rel)
894
+
895
+ import cv2
896
+ tgt_img = Image.open(t_p).convert("RGB")
897
+
898
+ # Load relighted image, fallback to target if not available
899
+ try:
900
+ relighted_img = Image.open(relight_p).convert("RGB")
901
+ except Exception:
902
+ relighted_img = tgt_img.copy()
903
+
904
+ final_size = rec.get("final_size", {}) or {}
905
+ fw = int(final_size.get("width", tgt_img.width))
906
+ fh = int(final_size.get("height", tgt_img.height))
907
+ tgt_img = tgt_img.resize((fw, fh), resample=Image.BILINEAR)
908
+ relighted_img = relighted_img.resize((fw, fh), resample=Image.BILINEAR)
909
+
910
+ target_tensor = to_tensor_and_norm(tgt_img)
911
+
912
+ # Build segmentation map path and extract multiple segments
913
+ input_filename = os.path.basename(t_rel)
914
+ seg_path = self._build_segmap_path(input_filename)
915
+ individual_masks = self._extract_multiple_segments(fh, fw, seg_path)
916
+
917
+ if not individual_masks:
918
+ # Fallback: create empty mask (will be handled gracefully)
919
+ union_mask = np.zeros((fh, fw), dtype=np.uint8)
920
+ individual_masks = []
921
+ else:
922
+ # Apply augmentation to each segment mask and take union
923
+ augmented_masks = []
924
+ for seg_mask in individual_masks:
925
+ # 1) Grow mask by random 50-200 pixels
926
+ grown = _dilate_mask(seg_mask, 50, 200)
927
+ # 2) Optional random augmentation mask constrained by this segment
928
+ rand_mask = _constrained_random_mask(seg_mask, fh, fw, aug_prob=0.7)
929
+ # 3) Union for this segment
930
+ seg_union = np.clip(grown | rand_mask, 0, 1).astype(np.uint8)
931
+ augmented_masks.append(seg_union)
932
+
933
+ # Take union of all augmented masks
934
+ union_mask = np.zeros((fh, fw), dtype=np.uint8)
935
+ for m in augmented_masks:
936
+ union_mask = np.clip(union_mask | m, 0, 1).astype(np.uint8)
937
+
938
+ tgt_np = np.array(tgt_img)
939
+
940
+ # Build masked target first
941
+ masked_t_np = tgt_np.copy()
942
+ masked_t_np[union_mask.astype(bool)] = 255
943
+ composed_np = masked_t_np.copy()
944
+
945
+ # Process each segment independently with different augmentations
946
+ # This simulates multiple foreground objects from different sources
947
+ for seg_mask in individual_masks:
948
+ # 1) Choose source for this segment: 20% original, 80% color-augmented (relight branch disabled below)
949
+ r = random.random()
950
+ if r < 0.2:
951
+ # Original image
952
+ seg_source_img = tgt_img
953
+ else:
954
+ seg_source_img = _color_augment(tgt_img)
955
+ # elif r < 0.4:
956
+ # # Color augmentation
957
+ # seg_source_img = _color_augment(tgt_img)
958
+ # else:
959
+ # # Relighted image
960
+ # seg_source_img = relighted_img
961
+
962
+ seg_source_np = np.array(seg_source_img)
963
+
964
+ # 2) Apply resolution augmentation to this segment's region
965
+ if random.random() < 0.7:
966
+ ys, xs = np.where(seg_mask > 0)
967
+ if len(xs) > 0 and len(ys) > 0:
968
+ x1, x2 = int(xs.min()), int(xs.max())
969
+ y1, y2 = int(ys.min()), int(ys.max())
970
+ if x2 > x1 and y2 > y1:
971
+ crop = seg_source_np[y1:y2 + 1, x1:x2 + 1]
972
+ ch, cw = crop.shape[:2]
973
+ scale = random.uniform(0.2, 0.9)
974
+ dw = max(1, int(cw * scale))
975
+ dh = max(1, int(ch * scale))
976
+ try:
977
+ small = Image.fromarray(crop.astype(np.uint8)).resize((dw, dh), Image.BICUBIC)
978
+ back = small.resize((cw, ch), Image.BICUBIC)
979
+ crop_blurred = np.array(back).astype(np.uint8)
980
+ seg_source_np[y1:y2 + 1, x1:x2 + 1] = crop_blurred
981
+ except Exception:
982
+ pass
983
+
984
+ # 3) Paste this segment onto composed image
985
+ m_fore = seg_mask.astype(bool)
986
+ composed_np[m_fore] = seg_source_np[m_fore]
987
+
988
+ # Build tensors
989
+ source_tensor = to_tensor_and_norm(Image.fromarray(composed_np.astype(np.uint8)))
990
+ mask_tensor = torch.from_numpy(union_mask.astype(np.float32)).unsqueeze(0)
991
+
992
+ # Caption: prepend instruction
993
+ cap_orig = rec.get("caption", "") or ""
994
+ cap = _prepend_caption(cap_orig)
995
+ ids1, ids2 = _tokenize(tokenizers, cap)
996
+
997
+ return {
998
+ "source_pixel_values": source_tensor,
999
+ "pixel_values": target_tensor,
1000
+ "token_ids_clip": ids1,
1001
+ "token_ids_t5": ids2,
1002
+ "mask_values": mask_tensor,
1003
+ }
1004
+
1005
+ return PexelsDataset(records, base_dir, relight_base_dir, seg_base_dir)
1006
+
1007
+
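A hypothetical pexels record for `make_pexels_dataset_subjects`; the field values are illustrative only:

```python
example_record = {
    "input_path": "pexels/photos/pexels-photo_abc123.jpg",   # relative to pexels_base_dir
    "output_path": "relight/abc123_low.jpg",                  # relative to pexels_relight_base_dir
    "final_size": {"width": 1024, "height": 768},
    "caption": "A wooden table with a cup of coffee by a window.",
}
# The segmentation map is resolved from seg_base_dir as "<hash>_map.png", where the hash is
# the last underscore-separated token of the input filename ("abc123" in this made-up case).
```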
1008
+ def make_mixed_dataset(args, tokenizers, interactive_jsonl_path=None, placement_jsonl_path=None,
1009
+ pexels_jsonl_path=None, interactive_base_dir=None, placement_base_dir=None,
1010
+ pexels_base_dir=None, interactive_weight=1.0, placement_weight=1.0,
1011
+ pexels_weight=1.0, accelerator=None):
1012
+ """
1013
+ Create a mixed dataset combining interactive, placement, and pexels datasets.
1014
+
1015
+ Args:
1016
+ args: Arguments object with dataset configuration
1017
+ tokenizers: Tuple of tokenizers for text encoding
1018
+ interactive_jsonl_path: Path to interactive dataset JSONL (optional)
1019
+ placement_jsonl_path: Path to placement dataset JSONL (optional)
1020
+ pexels_jsonl_path: Path to pexels dataset JSONL (optional)
1021
+ interactive_base_dir: Base directory for interactive dataset paths (optional)
1022
+ placement_base_dir: Base directory for placement dataset paths (optional)
1023
+ pexels_base_dir: Base directory for pexels dataset paths (optional)
1024
+ interactive_weight: Sampling weight for interactive dataset (default: 1.0)
1025
+ placement_weight: Sampling weight for placement dataset (default: 1.0)
1026
+ pexels_weight: Sampling weight for pexels dataset (default: 1.0)
1027
+ accelerator: Optional accelerator object
1028
+
1029
+ Returns:
1030
+ Mixed dataset that samples from all provided datasets with specified weights
1031
+ """
1032
+ datasets = []
1033
+ dataset_names = []
1034
+ dataset_weights = []
1035
+
1036
+ # Create interactive dataset if path provided
1037
+ if interactive_jsonl_path:
1038
+ interactive_args = type('Args', (), {})()
1039
+ for k, v in vars(args).items():
1040
+ setattr(interactive_args, k, v)
1041
+ interactive_args.train_data_jsonl = interactive_jsonl_path
1042
+ if interactive_base_dir:
1043
+ interactive_args.interactive_base_dir = interactive_base_dir
1044
+ interactive_ds = make_interactive_dataset_subjects(interactive_args, tokenizers, accelerator)
1045
+ datasets.append(interactive_ds)
1046
+ dataset_names.append("interactive")
1047
+ dataset_weights.append(interactive_weight)
1048
+
1049
+ # Create placement dataset if path provided
1050
+ if placement_jsonl_path:
1051
+ placement_args = type('Args', (), {})()
1052
+ for k, v in vars(args).items():
1053
+ setattr(placement_args, k, v)
1054
+ placement_args.placement_data_jsonl = placement_jsonl_path
1055
+ if placement_base_dir:
1056
+ placement_args.placement_base_dir = placement_base_dir
1057
+ placement_ds = make_placement_dataset_subjects(placement_args, tokenizers, accelerator)
1058
+ datasets.append(placement_ds)
1059
+ dataset_names.append("placement")
1060
+ dataset_weights.append(placement_weight)
1061
+
1062
+ # Create pexels dataset if path provided
1063
+ if pexels_jsonl_path:
1064
+ pexels_args = type('Args', (), {})()
1065
+ for k, v in vars(args).items():
1066
+ setattr(pexels_args, k, v)
1067
+ pexels_args.pexels_data_jsonl = pexels_jsonl_path
1068
+ if pexels_base_dir:
1069
+ pexels_args.pexels_base_dir = pexels_base_dir
1070
+ pexels_ds = make_pexels_dataset_subjects(pexels_args, tokenizers, accelerator)
1071
+ datasets.append(pexels_ds)
1072
+ dataset_names.append("pexels")
1073
+ dataset_weights.append(pexels_weight)
1074
+
1075
+ if not datasets:
1076
+ raise ValueError("At least one dataset path must be provided")
1077
+
1078
+ if len(datasets) == 1:
1079
+ return datasets[0]
1080
+
1081
+ # Mixed dataset class with balanced sampling (based on smallest dataset)
1082
+ class MixedDataset(torch.utils.data.Dataset):
1083
+ def __init__(self, datasets, dataset_names, dataset_weights):
1084
+ self.datasets = datasets
1085
+ self.dataset_names = dataset_names
1086
+ self.lengths = [len(ds) for ds in datasets]
1087
+
1088
+ # Normalize weights
1089
+ total_weight = sum(dataset_weights)
1090
+ self.weights = [w / total_weight for w in dataset_weights]
1091
+
1092
+ # Calculate samples per dataset based on smallest dataset and weights
1093
+ # Find the minimum weighted size
1094
+ min_weighted_size = min(length / weight for length, weight in zip(self.lengths, dataset_weights) if weight > 0)
1095
+
1096
+ # Each dataset contributes samples proportional to its weight, scaled by min_weighted_size
1097
+ self.samples_per_dataset = [int(min_weighted_size * w) for w in dataset_weights]
1098
+ self.total_length = sum(self.samples_per_dataset)
1099
+
1100
+ # Build cumulative sample counts for indexing
1101
+ self.cumsum_samples = [0]
1102
+ for count in self.samples_per_dataset:
1103
+ self.cumsum_samples.append(self.cumsum_samples[-1] + count)
1104
+
1105
+ print(f"Balanced mixed dataset created:")
1106
+ for i, name in enumerate(dataset_names):
1107
+ print(f" {name}: {self.lengths[i]} total, {self.samples_per_dataset[i]} per epoch")
1108
+ print(f" Total samples per epoch: {self.total_length}")
1109
+
1110
+ def __len__(self):
1111
+ return self.total_length
1112
+
1113
+ def __getitem__(self, idx):
1114
+ # Determine which dataset this idx belongs to
1115
+ dataset_idx = 0
1116
+ for i in range(len(self.cumsum_samples) - 1):
1117
+ if self.cumsum_samples[i] <= idx < self.cumsum_samples[i + 1]:
1118
+ dataset_idx = i
1119
+ break
1120
+
1121
+ # Randomly sample from the selected dataset (enables different samples each epoch)
1122
+ local_idx = random.randint(0, self.lengths[dataset_idx] - 1)
1123
+ sample = self.datasets[dataset_idx][local_idx]
1124
+ # Add dataset source information
1125
+ sample["dataset_source"] = self.dataset_names[dataset_idx]
1126
+ return sample
1127
+
1128
+ return MixedDataset(datasets, dataset_names, dataset_weights)
1129
+
1130
+
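A small worked example (numbers invented) of the per-epoch balancing computed in `MixedDataset.__init__`:

```python
lengths = [10_000, 2_000, 4_000]          # interactive, placement, pexels (hypothetical sizes)
weights = [1.0, 1.0, 2.0]
min_weighted_size = min(l / w for l, w in zip(lengths, weights))      # 2000 / 1.0 = 2000.0
samples_per_dataset = [int(min_weighted_size * w) for w in weights]   # [2000, 2000, 4000]
# The smallest weighted dataset sets the scale; each epoch then draws that many
# (weight-scaled) random samples from every dataset, 8000 samples in total here.
```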
1131
+ def _run_test_mode(
1132
+ interactive_jsonl: str = None,
1133
+ placement_jsonl: str = None,
1134
+ pexels_jsonl: str = None,
1135
+ interactive_base_dir: str = None,
1136
+ placement_base_dir: str = None,
1137
+ pexels_base_dir: str = None,
1138
+ pexels_relight_base_dir: str = None,
1139
+ seg_base_dir: str = None,
1140
+ interactive_weight: float = 1.0,
1141
+ placement_weight: float = 1.0,
1142
+ pexels_weight: float = 1.0,
1143
+ output_dir: str = "test_output",
1144
+ num_samples: int = 100
1145
+ ):
1146
+ """Test dataset by saving samples with source labels.
1147
+
1148
+ Args:
1149
+ interactive_jsonl: Path to interactive dataset JSONL (optional)
1150
+ placement_jsonl: Path to placement dataset JSONL (optional)
1151
+ pexels_jsonl: Path to pexels dataset JSONL (optional)
1152
+ interactive_base_dir: Base directory for interactive dataset
1153
+ placement_base_dir: Base directory for placement dataset
1154
+ pexels_base_dir: Base directory for pexels dataset
1155
+ pexels_relight_base_dir: Base directory for pexels relighted images
1156
+ seg_base_dir: Directory containing segmentation maps for pexels dataset
1157
+ interactive_weight: Sampling weight for interactive dataset (default: 1.0)
1158
+ placement_weight: Sampling weight for placement dataset (default: 1.0)
1159
+ pexels_weight: Sampling weight for pexels dataset (default: 1.0)
1160
+ output_dir: Output directory for test images
1161
+ num_samples: Number of samples to save
1162
+ """
1163
+ if not interactive_jsonl and not placement_jsonl and not pexels_jsonl:
1164
+ raise ValueError("At least one dataset path must be provided")
1165
+
1166
+ os.makedirs(output_dir, exist_ok=True)
1167
+
1168
+ # Create dummy tokenizers for testing
1169
+ class DummyTokenizer:
1170
+ def __call__(self, text, **kwargs):
1171
+ class Result:
1172
+ input_ids = torch.zeros(1, 77, dtype=torch.long)
1173
+ return Result()
1174
+
1175
+ tokenizers = (DummyTokenizer(), DummyTokenizer())
1176
+
1177
+ # Create args object
1178
+ class Args:
1179
+ cond_size = 512
1180
+
1181
+ args = Args()
1182
+ args.train_data_jsonl = interactive_jsonl
1183
+ args.placement_data_jsonl = placement_jsonl
1184
+ args.pexels_data_jsonl = pexels_jsonl
1185
+ args.interactive_base_dir = interactive_base_dir
1186
+ args.placement_base_dir = placement_base_dir
1187
+ args.pexels_base_dir = pexels_base_dir
1188
+ args.pexels_relight_base_dir = pexels_relight_base_dir if pexels_relight_base_dir else "/robby/share/Editing/lzc/data/relight_outputs"
1189
+ args.seg_base_dir = seg_base_dir if seg_base_dir else "/mnt/robby-b1/common/datasets/pexels-mask/20190515093182"
1190
+
1191
+ # Create dataset (single or mixed)
1192
+ try:
1193
+ # Count how many datasets are provided
1194
+ num_datasets = sum([bool(interactive_jsonl), bool(placement_jsonl), bool(pexels_jsonl)])
1195
+
1196
+ if num_datasets > 1:
1197
+ dataset = make_mixed_dataset(
1198
+ args, tokenizers,
1199
+ interactive_jsonl_path=interactive_jsonl,
1200
+ placement_jsonl_path=placement_jsonl,
1201
+ pexels_jsonl_path=pexels_jsonl,
1202
+ interactive_base_dir=args.interactive_base_dir,
1203
+ placement_base_dir=args.placement_base_dir,
1204
+ pexels_base_dir=args.pexels_base_dir,
1205
+ interactive_weight=interactive_weight,
1206
+ placement_weight=placement_weight,
1207
+ pexels_weight=pexels_weight
1208
+ )
1209
+ print(f"Created mixed dataset with {len(dataset)} samples")
1210
+ weights_str = []
1211
+ if interactive_jsonl:
1212
+ weights_str.append(f"Interactive: {interactive_weight:.2f}")
1213
+ if placement_jsonl:
1214
+ weights_str.append(f"Placement: {placement_weight:.2f}")
1215
+ if pexels_jsonl:
1216
+ weights_str.append(f"Pexels: {pexels_weight:.2f}")
1217
+ print(f"Sampling weights - {', '.join(weights_str)}")
1218
+ elif pexels_jsonl:
1219
+ dataset = make_pexels_dataset_subjects(args, tokenizers, base_dir=pexels_base_dir)
1220
+ print(f"Created pexels dataset with {len(dataset)} samples")
1221
+ elif placement_jsonl:
1222
+ dataset = make_placement_dataset_subjects(args, tokenizers, base_dir=args.placement_base_dir)
1223
+ print(f"Created placement dataset with {len(dataset)} samples")
1224
+ else:
1225
+ dataset = make_interactive_dataset_subjects(args, tokenizers, base_dir=args.interactive_base_dir)
1226
+ print(f"Created interactive dataset with {len(dataset)} samples")
1227
+ except Exception as e:
1228
+ print(f"Failed to create dataset: {e}")
1229
+ import traceback
1230
+ traceback.print_exc()
1231
+ return
1232
+
1233
+ # Sample and save
1234
+ saved = 0
1235
+ counts = {}
1236
+
1237
+ for attempt in range(min(num_samples * 3, len(dataset))):
1238
+ try:
1239
+ idx = random.randint(0, len(dataset) - 1)
1240
+ sample = dataset[idx]
1241
+
1242
+ source_name = sample.get("dataset_source", "single")
1243
+ counts[source_name] = counts.get(source_name, 0) + 1
1244
+
1245
+ # Denormalize tensors from [-1, 1] to [0, 255]
1246
+ source_np = ((sample["source_pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
1247
+ target_np = ((sample["pixel_values"].permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
1248
+
1249
+ # Save images
1250
+ idx_str = f"{saved:05d}"
1251
+ Image.fromarray(source_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_source.jpg"))
1252
+ Image.fromarray(target_np).save(os.path.join(output_dir, f"{idx_str}_{source_name}_target.jpg"))
1253
+
1254
+ saved += 1
1255
+ if saved % 10 == 0:
1256
+ print(f"Saved {saved}/{num_samples} samples - {counts}")
1257
+ if saved >= num_samples:
1258
+ break
1259
+
1260
+ except Exception as e:
1261
+ print(f"Failed to process sample: {e}")
1262
+ continue
1263
+
1264
+ print(f"\nTest complete. Saved {saved} samples to {output_dir}")
1265
+ print(f"Distribution: {counts}")
1266
+
1267
+
1268
+ def _parse_test_args():
1269
+ import argparse
1270
+ parser = argparse.ArgumentParser(description="Test visualization for Kontext datasets")
1271
+ parser.add_argument("--interactive_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
1272
+ help="Path to interactive dataset JSONL")
1273
+ parser.add_argument("--placement_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
1274
+ help="Path to placement dataset JSONL")
1275
+ parser.add_argument("--pexels_jsonl", type=str, default=None,
1276
+ help="Path to pexels dataset JSONL")
1277
+ parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
1278
+ help="Base directory for interactive dataset")
1279
+ parser.add_argument("--placement_base_dir", type=str, default=None,
1280
+ help="Base directory for placement dataset")
1281
+ parser.add_argument("--pexels_base_dir", type=str, default=None,
1282
+ help="Base directory for pexels dataset")
1283
+ parser.add_argument("--pexels_relight_base_dir", type=str, default="/robby/share/Editing/lzc/data/relight_outputs",
1284
+ help="Base directory for pexels relighted images")
1285
+ parser.add_argument("--seg_base_dir", type=str, default=None,
1286
+ help="Directory containing segmentation maps for pexels dataset")
1287
+ parser.add_argument("--interactive_weight", type=float, default=1.0,
1288
+ help="Sampling weight for interactive dataset (default: 1.0)")
1289
+ parser.add_argument("--placement_weight", type=float, default=1.0,
1290
+ help="Sampling weight for placement dataset (default: 1.0)")
1291
+ parser.add_argument("--pexels_weight", type=float, default=0,
1292
+ help="Sampling weight for pexels dataset (default: 1.0)")
1293
+ parser.add_argument("--output_dir", type=str, default="visualize_output",
1294
+ help="Output directory to save pairs")
1295
+ parser.add_argument("--num_samples", type=int, default=100,
1296
+ help="Number of pairs to save")
1297
+
1298
+ # Legacy arguments
1299
+ parser.add_argument("--test_jsonl", type=str, default=None,
1300
+ help="Legacy: Path to JSONL (uses as interactive_jsonl)")
1301
+ parser.add_argument("--base_dir", type=str, default=None,
1302
+ help="Legacy: Base directory (uses as interactive_base_dir)")
1303
+ return parser.parse_args()
1304
+
1305
+
1306
+ if __name__ == "__main__":
1307
+ try:
1308
+ args = _parse_test_args()
1309
+
1310
+ # Handle legacy args
1311
+ interactive_jsonl = args.interactive_jsonl or args.test_jsonl
1312
+ interactive_base_dir = args.interactive_base_dir or args.base_dir
1313
+
1314
+ _run_test_mode(
1315
+ interactive_jsonl=interactive_jsonl,
1316
+ placement_jsonl=args.placement_jsonl,
1317
+ pexels_jsonl=args.pexels_jsonl,
1318
+ interactive_base_dir=interactive_base_dir,
1319
+ placement_base_dir=args.placement_base_dir,
1320
+ pexels_base_dir=args.pexels_base_dir,
1321
+ pexels_relight_base_dir=args.pexels_relight_base_dir,
1322
+ seg_base_dir=args.seg_base_dir,
1323
+ interactive_weight=args.interactive_weight,
1324
+ placement_weight=args.placement_weight,
1325
+ pexels_weight=args.pexels_weight,
1326
+ output_dir=args.output_dir,
1327
+ num_samples=args.num_samples
1328
+ )
1329
+ except SystemExit:
1330
+ # Allow import usage without triggering test mode
1331
+ pass
1332
+
train/src/jsonl_datasets_kontext_local.py ADDED
@@ -0,0 +1,312 @@
1
+ from PIL import Image
2
+ from datasets import Dataset
3
+ from torchvision import transforms
4
+ import random
5
+ import torch
6
+ import os
7
+ from .pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
8
+ from .jsonl_datasets_kontext import make_train_dataset_inpaint_mask
9
+ import numpy as np
10
+ import json
11
+ from .generate_diff_mask import generate_final_difference_mask, align_images
12
+
13
+ Image.MAX_IMAGE_PIXELS = None
14
+ BLEND_PIXEL_VALUES = True
15
+
16
+ def multiple_16(num: float):
17
+ return int(round(num / 16) * 16)
18
+
19
+ def choose_kontext_resolution_from_wh(width: int, height: int):
20
+ aspect_ratio = width / max(1, height)
21
+ _, best_w, best_h = min(
22
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
23
+ )
24
+ return best_w, best_h
25
+
26
+ def collate_fn(examples):
27
+ if examples[0].get("cond_pixel_values") is not None:
28
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
29
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
30
+ else:
31
+ cond_pixel_values = None
32
+ if examples[0].get("source_pixel_values") is not None:
33
+ source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
34
+ source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
35
+ else:
36
+ source_pixel_values = None
37
+
38
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
39
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
40
+ token_ids_clip = torch.stack([example["token_ids_clip"] for example in examples])
41
+ token_ids_t5 = torch.stack([example["token_ids_t5"] for example in examples])
42
+
43
+ mask_values = None
44
+ if examples[0].get("mask_values") is not None:
45
+ mask_values = torch.stack([example["mask_values"] for example in examples])
46
+ mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
47
+
48
+ return {
49
+ "cond_pixel_values": cond_pixel_values,
50
+ "source_pixel_values": source_pixel_values,
51
+ "pixel_values": target_pixel_values,
52
+ "text_ids_1": token_ids_clip,
53
+ "text_ids_2": token_ids_t5,
54
+ "mask_values": mask_values,
55
+ }
56
+
57
+
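A self-contained sketch of `collate_fn` (defined above) on dummy samples that only mimic the keys produced by the datasets in this file; shapes are illustrative:

```python
import torch

dummy = [{
    "source_pixel_values": torch.zeros(3, 64, 64),
    "pixel_values": torch.zeros(3, 64, 64),
    "cond_pixel_values": torch.zeros(3, 64, 64),
    "token_ids_clip": torch.zeros(77, dtype=torch.long),
    "token_ids_t5": torch.zeros(128, dtype=torch.long),
    "mask_values": torch.zeros(1, 64, 64),
} for _ in range(2)]
batch = collate_fn(dummy)
# batch["pixel_values"]: [2, 3, 64, 64]; batch["text_ids_1"]: [2, 77]; batch["mask_values"]: [2, 1, 64, 64]
# In training this is normally wired up via
# torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn).
```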
58
+ # Dataset for the local_edits JSON mapping; masks are loaded from args.masks_dir (precomputed) and cleaned up on the fly
59
+ def make_train_dataset_local_edits(args, tokenizers, accelerator=None):
60
+ # Read JSON entries
61
+ with open(args.local_edits_json, "r", encoding="utf-8") as f:
62
+ entries = json.load(f)
63
+
64
+ samples = []
65
+ for item in entries:
66
+ rel_path = item.get("path", "")
67
+ local_edits = item.get("local_edits", []) or []
68
+ if not rel_path or not local_edits:
69
+ continue
70
+
71
+ base_name = os.path.basename(rel_path)
72
+ prefix = os.path.splitext(base_name)[0]
73
+ group_dir = os.path.basename(os.path.dirname(rel_path))
74
+ gid_int = None
75
+ try:
76
+ gid_int = int(group_dir)
77
+ except Exception:
78
+ try:
79
+ digits = "".join([ch for ch in group_dir if ch.isdigit()])
80
+ gid_int = int(digits) if digits else None
81
+ except Exception:
82
+ gid_int = None
83
+
84
+ group_str = group_dir # e.g., "0139" from the JSON path segment
85
+
86
+ # Resolve source/target directories strictly as base/<0139>
87
+ src_dir_candidates = [os.path.join(args.source_frames_dir, group_str)]
88
+ tgt_dir_candidates = [os.path.join(args.target_frames_dir, group_str)]
89
+ src_dir = next((d for d in src_dir_candidates if d and os.path.isdir(d)), None)
90
+ tgt_dir = next((d for d in tgt_dir_candidates if d and os.path.isdir(d)), None)
91
+ if src_dir is None or tgt_dir is None:
92
+ continue
93
+
94
+ src_path = os.path.join(src_dir, f"{prefix}.png")
95
+ for idx, prompt in enumerate(local_edits, start=1):
96
+ tgt_path = os.path.join(tgt_dir, f"{prefix}_{idx}.png")
97
+ mask_path = os.path.join(args.masks_dir, group_str, f"{prefix}_{idx}.png")
98
+ if not (os.path.exists(src_path) and os.path.exists(tgt_path) and os.path.exists(mask_path)):
99
+ continue
100
+ samples.append({
101
+ "source_image": src_path,
102
+ "target_image": tgt_path,
103
+ "mask_image": mask_path,
104
+ "prompt": prompt,
105
+ })
106
+
107
+ size = args.cond_size
108
+
109
+ to_tensor_and_norm = transforms.Compose([
110
+ transforms.ToTensor(),
111
+ transforms.Normalize([0.5], [0.5]),
112
+ ])
113
+
114
+ cond_train_transforms = transforms.Compose(
115
+ [
116
+ transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize([0.5], [0.5]),
119
+ ]
120
+ )
121
+
122
+ tokenizer_clip = tokenizers[0]
123
+ tokenizer_t5 = tokenizers[1]
124
+
125
+ def tokenize_prompt_single(caption: str):
126
+ text_inputs_clip = tokenizer_clip(
127
+ [caption],
128
+ padding="max_length",
129
+ max_length=77,
130
+ truncation=True,
131
+ return_tensors="pt",
132
+ )
133
+ text_input_ids_1 = text_inputs_clip.input_ids[0]
134
+
135
+ text_inputs_t5 = tokenizer_t5(
136
+ [caption],
137
+ padding="max_length",
138
+ max_length=128,
139
+ truncation=True,
140
+ return_tensors="pt",
141
+ )
142
+ text_input_ids_2 = text_inputs_t5.input_ids[0]
143
+ return text_input_ids_1, text_input_ids_2
144
+
145
+ class LocalEditsDataset(torch.utils.data.Dataset):
146
+ def __init__(self, samples_ls):
147
+ self.samples = samples_ls
148
+ def __len__(self):
149
+ return len(self.samples)
150
+ def __getitem__(self, idx):
151
+ sample = self.samples[idx]
152
+ s_p = sample["source_image"]
153
+ t_p = sample["target_image"]
154
+ m_p = sample["mask_image"]
155
+ cap = sample["prompt"]
156
+
157
+ rr = random.randint(10, 20)
158
+ ri = random.randint(3, 5)
159
+ import cv2
160
+ mask_loaded = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
161
+ if mask_loaded is None:
162
+ raise ValueError("mask load failed")
163
+ mask = mask_loaded.copy()
164
+
165
+ # Pre-expand mask by a fixed number of pixels before any random expansion
166
+ # Uses a cross-shaped kernel when tapered_corners is True to emulate "tapered" growth
167
+ pre_expand_px = int(getattr(args, "pre_expand_mask_px", 50))
168
+ pre_expand_tapered = bool(getattr(args, "pre_expand_tapered_corners", True))
169
+ if pre_expand_px != 0:
170
+ c = 0 if pre_expand_tapered else 1
171
+ pre_kernel = np.array([[c, 1, c],
172
+ [1, 1, 1],
173
+ [c, 1, c]], dtype=np.uint8)
174
+ if pre_expand_px > 0:
175
+ mask = cv2.dilate(mask, pre_kernel, iterations=pre_expand_px)
176
+ else:
177
+ mask = cv2.erode(mask, pre_kernel, iterations=abs(pre_expand_px))
178
+ if rr > 0 and ri > 0:
179
+ ksize = max(1, 2 * int(rr) + 1)
180
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
181
+ for _ in range(max(1, ri)):
182
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
183
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
184
+
185
+ src_aligned, tgt_aligned = align_images(s_p, t_p)
186
+
187
+ best_w, best_h = choose_kontext_resolution_from_wh(tgt_aligned.width, tgt_aligned.height)
188
+ final_img_rs = tgt_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
189
+ raw_img_rs = src_aligned.resize((best_w, best_h), resample=Image.BILINEAR)
190
+
191
+ target_tensor = to_tensor_and_norm(final_img_rs)
192
+ source_tensor = to_tensor_and_norm(raw_img_rs)
193
+
194
+ mask_img = Image.fromarray(mask.astype(np.uint8)).convert("L")
195
+ if mask_img.size != src_aligned.size:
196
+ mask_img = mask_img.resize(src_aligned.size, Image.NEAREST)
197
+ mask_np = np.array(mask_img)
198
+
199
+ mask_bin = (mask_np > 127).astype(np.uint8)
200
+ inv_mask = (1 - mask_bin).astype(np.uint8)
201
+ src_np = np.array(src_aligned)
202
+ masked_raw_np = src_np * inv_mask[..., None]
203
+ masked_raw_img = Image.fromarray(masked_raw_np.astype(np.uint8))
204
+ cond_tensor = cond_train_transforms(masked_raw_img)
205
+
206
+ # Prepare mask_values tensor at training resolution (best_w, best_h)
207
+ mask_img_rs = mask_img.resize((best_w, best_h), Image.NEAREST)
208
+ mask_np_rs = np.array(mask_img_rs)
209
+ mask_bin_rs = (mask_np_rs > 127).astype(np.float32)
210
+ mask_tensor = torch.from_numpy(mask_bin_rs).unsqueeze(0) # [1, H, W]
211
+
212
+ ids1, ids2 = tokenize_prompt_single(cap if isinstance(cap, str) else "")
213
+
214
+ # Optionally blend target and source using a blurred mask, controlled by args
215
+ if getattr(args, "blend_pixel_values", BLEND_PIXEL_VALUES):
216
+ blend_kernel = int(getattr(args, "blend_kernel", 21))
217
+ if blend_kernel % 2 == 0:
218
+ blend_kernel += 1
219
+ blend_sigma = float(getattr(args, "blend_sigma", 10.0))
220
+ gb = transforms.GaussianBlur(kernel_size=(blend_kernel, blend_kernel), sigma=(blend_sigma, blend_sigma))
221
+ # mask_tensor: [1, H, W] in [0,1]
222
+ blurred_mask = gb(mask_tensor) # [1, H, W]
223
+ # Expand to 3 channels to match image tensors
224
+ blurred_mask_3c = blurred_mask.expand(target_tensor.shape[0], -1, -1) # [3, H, W]
225
+ # Blend in normalized space (both tensors already normalized to [-1, 1])
226
+ target_tensor = (source_tensor * (1.0 - blurred_mask_3c)) + (target_tensor * blurred_mask_3c)
227
+ target_tensor = target_tensor.clamp(-1.0, 1.0)
228
+
229
+ return {
230
+ "source_pixel_values": source_tensor,
231
+ "pixel_values": target_tensor,
232
+ "cond_pixel_values": cond_tensor,
233
+ "token_ids_clip": ids1,
234
+ "token_ids_t5": ids2,
235
+ "mask_values": mask_tensor,
236
+ }
237
+
238
+ return LocalEditsDataset(samples)
239
+
240
+
241
+ class BalancedMixDataset(torch.utils.data.Dataset):
242
+ """
243
+ A wrapper dataset that mixes two datasets with a configurable ratio.
244
+
245
+ ratio_b_per_a defines how many samples from dataset_b for each sample from dataset_a:
246
+ - 0 => only dataset_a (local edits)
247
+ - 1 => 1:1 mix (default)
248
+ - 2 => 1:2 mix (A:B)
249
+ - any float supported (e.g., 0.5 => 2:1 mix)
250
+ """
251
+ def __init__(self, dataset_a, dataset_b, ratio_b_per_a: float = 1.0):
252
+ self.dataset_a = dataset_a
253
+ self.dataset_b = dataset_b
254
+ self.ratio_b_per_a = max(0.0, float(ratio_b_per_a))
255
+
256
+ len_a = len(dataset_a)
257
+ len_b = len(dataset_b)
258
+
259
+ # If ratio is 0, use all of dataset_a only
260
+ if self.ratio_b_per_a == 0 or len_b == 0:
261
+ a_indices = list(range(len_a))
262
+ random.shuffle(a_indices)
263
+ self.mapping = [(0, i) for i in a_indices]
264
+ return
265
+
266
+ # Determine how many we can draw without replacement
267
+ # n_a limited by A size and B availability according to ratio
268
+ n_a_by_ratio = int(len_b / self.ratio_b_per_a)
269
+ n_a = min(len_a, max(1, n_a_by_ratio))
270
+ n_b = min(len_b, max(1, int(round(n_a * self.ratio_b_per_a))))
271
+
272
+ a_indices = list(range(len_a))
273
+ b_indices = list(range(len_b))
274
+ random.shuffle(a_indices)
275
+ random.shuffle(b_indices)
276
+ a_indices = a_indices[: n_a]
277
+ b_indices = b_indices[: n_b]
278
+
279
+ mixed = [(0, i) for i in a_indices] + [(1, i) for i in b_indices]
280
+ random.shuffle(mixed)
281
+ self.mapping = mixed
282
+
283
+ def __len__(self):
284
+ return len(self.mapping)
285
+
286
+ def __getitem__(self, idx):
287
+ which, real_idx = self.mapping[idx]
288
+ if which == 0:
289
+ return self.dataset_a[real_idx]
290
+ else:
291
+ return self.dataset_b[real_idx]
292
+
293
+
294
+ def make_train_dataset_mixed(args, tokenizers, accelerator=None):
295
+ """
296
+ Create a mixed dataset from:
297
+ - Local edits dataset (this file)
298
+ - Inpaint-mask JSONL dataset (jsonl_datasets_kontext.make_train_dataset_inpaint_mask)
299
+
300
+ Ratio control via args.mix_ratio (float):
301
+ - 0 => only local edits dataset
302
+ - 1 => 1:1 mix (local:inpaint)
303
+ - 2 => 1:2 mix, etc. (a toy sketch of the capping follows this function)
304
+
305
+ Requirements:
306
+ - args.local_edits_json and related dirs must be set for local edits
307
+ - args.train_data_dir must be set for the JSONL inpaint dataset
308
+ """
309
+ ds_local = make_train_dataset_local_edits(args, tokenizers, accelerator)
310
+ ds_inpaint = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
311
+ mix_ratio = getattr(args, "mix_ratio", 1.0)
312
+ return BalancedMixDataset(ds_local, ds_inpaint, ratio_b_per_a=mix_ratio)
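A toy illustration (values invented) of how `ratio_b_per_a` caps the two pools, using the same arithmetic as `BalancedMixDataset.__init__`:

```python
len_a, len_b = 100, 1_000        # e.g. local-edit samples vs. inpaint-mask samples
for ratio in (0.5, 1.0, 2.0):
    n_a = min(len_a, max(1, int(len_b / ratio)))
    n_b = min(len_b, max(1, int(round(n_a * ratio))))
    print(ratio, n_a, n_b)        # 0.5 -> 100, 50; 1.0 -> 100, 100; 2.0 -> 100, 200
```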
train/src/layers.py ADDED
@@ -0,0 +1,279 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+ from einops import rearrange
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch import Tensor
9
+ from diffusers.models.attention_processor import Attention
10
+
11
+ class LoRALinearLayer(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_features: int,
15
+ out_features: int,
16
+ rank: int = 4,
17
+ network_alpha: Optional[float] = None,
18
+ device: Optional[Union[torch.device, str]] = None,
19
+ dtype: Optional[torch.dtype] = None,
20
+ cond_width=512,
21
+ cond_height=512,
22
+ number=0,
23
+ n_loras=1
24
+ ):
25
+ super().__init__()
26
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
27
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
28
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
29
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
30
+ self.network_alpha = network_alpha
31
+ self.rank = rank
32
+ self.out_features = out_features
33
+ self.in_features = in_features
34
+
35
+ nn.init.normal_(self.down.weight, std=1 / rank)
36
+ nn.init.zeros_(self.up.weight)
37
+
38
+ self.cond_height = cond_height
39
+ self.cond_width = cond_width
40
+ self.number = number
41
+ self.n_loras = n_loras
42
+
43
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
44
+ orig_dtype = hidden_states.dtype
45
+ dtype = self.down.weight.dtype
46
+
47
+ #### img condition
48
+ batch_size = hidden_states.shape[0]
49
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
50
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
51
+ shape = (batch_size, hidden_states.shape[1], 3072)
52
+ mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
53
+ mask[:, :block_size+self.number*cond_size, :] = 0
54
+ mask[:, block_size+(self.number+1)*cond_size:, :] = 0
55
+ hidden_states = mask * hidden_states
56
+ ####
57
+
58
+ down_hidden_states = self.down(hidden_states.to(dtype))
59
+ up_hidden_states = self.up(down_hidden_states)
60
+
61
+ if self.network_alpha is not None:
62
+ up_hidden_states *= self.network_alpha / self.rank
63
+
64
+ return up_hidden_states.to(orig_dtype)
65
+
66
+
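A toy sketch of the token gating in `LoRALinearLayer.forward`: with 512x512 conditions each condition block spans 512//8 * 512//8 * 16 // 64 = 1024 latent tokens, and the LoRA delta is confined to the block selected by `number` (the hidden size must be 3072 to match the hard-coded mask width). Sizes below are arbitrary:

```python
import torch

cond_tokens = 512 // 8 * 512 // 8 * 16 // 64          # 1024 tokens per condition image
layer = LoRALinearLayer(3072, 3072, rank=4,
                        cond_width=512, cond_height=512, number=0, n_loras=1)
hidden = torch.randn(1, 1024 + cond_tokens, 3072)      # [batch, base tokens + cond tokens, dim]
delta = layer(hidden)                                   # same shape as `hidden`
# `up` is zero-initialized, so the delta is all zeros before training; once trained,
# the position mask keeps it non-zero only on this LoRA's 1024 condition tokens.
assert delta.shape == hidden.shape
```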
67
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
68
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
69
+ super().__init__()
70
+ # Initialize a list to store the LoRA layers
71
+ self.n_loras = n_loras
72
+ self.cond_width = cond_width
73
+ self.cond_height = cond_height
74
+
75
+ self.q_loras = nn.ModuleList([
76
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
77
+ for i in range(n_loras)
78
+ ])
79
+ self.k_loras = nn.ModuleList([
80
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
81
+ for i in range(n_loras)
82
+ ])
83
+ self.v_loras = nn.ModuleList([
84
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
85
+ for i in range(n_loras)
86
+ ])
87
+ self.lora_weights = lora_weights
88
+
89
+
90
+ def __call__(self,
91
+ attn: Attention,
92
+ hidden_states: torch.FloatTensor,
93
+ encoder_hidden_states: torch.FloatTensor = None,
94
+ attention_mask: Optional[torch.FloatTensor] = None,
95
+ image_rotary_emb: Optional[torch.Tensor] = None,
96
+ use_cond = False,
97
+ ) -> torch.FloatTensor:
98
+
99
+ batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
100
+ query = attn.to_q(hidden_states)
101
+ key = attn.to_k(hidden_states)
102
+ value = attn.to_v(hidden_states)
103
+
104
+ for i in range(self.n_loras):
105
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
106
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
107
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
108
+
109
+ inner_dim = key.shape[-1]
110
+ head_dim = inner_dim // attn.heads
111
+
112
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
113
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
114
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
115
+
116
+ if attn.norm_q is not None:
117
+ query = attn.norm_q(query)
118
+ if attn.norm_k is not None:
119
+ key = attn.norm_k(key)
120
+
121
+ if image_rotary_emb is not None:
122
+ from diffusers.models.embeddings import apply_rotary_emb
123
+ query = apply_rotary_emb(query, image_rotary_emb)
124
+ key = apply_rotary_emb(key, image_rotary_emb)
125
+
126
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
127
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
128
+ scaled_cond_size = cond_size
129
+ scaled_block_size = block_size
130
+ scaled_seq_len = query.shape[2]
131
+
132
+ num_cond_blocks = self.n_loras
133
+ # mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
134
+ # mask[ :scaled_block_size, :] = 0 # First block_size row
135
+ # for i in range(num_cond_blocks):
136
+ # start = i * scaled_cond_size + scaled_block_size
137
+ # end = (i + 1) * scaled_cond_size + scaled_block_size
138
+ # mask[start:end, start:end] = 0 # Diagonal blocks
139
+ # mask = mask * -1e20
140
+ # mask = mask.to(query.dtype)
141
+
142
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
143
+
144
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
145
+ hidden_states = hidden_states.to(query.dtype)
146
+
147
+ cond_hidden_states = hidden_states[:, block_size:,:]
148
+ hidden_states = hidden_states[:, : block_size,:]
149
+
150
+ return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
151
+
152
+
153
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
154
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
155
+ super().__init__()
156
+
157
+ # Initialize a list to store the LoRA layers
158
+ self.n_loras = n_loras
159
+ self.cond_width = cond_width
160
+ self.cond_height = cond_height
161
+ self.q_loras = nn.ModuleList([
162
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
163
+ for i in range(n_loras)
164
+ ])
165
+ self.k_loras = nn.ModuleList([
166
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
167
+ for i in range(n_loras)
168
+ ])
169
+ self.v_loras = nn.ModuleList([
170
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
171
+ for i in range(n_loras)
172
+ ])
173
+ self.proj_loras = nn.ModuleList([
174
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
175
+ for i in range(n_loras)
176
+ ])
177
+ self.lora_weights = lora_weights
178
+
179
+
180
+ def __call__(self,
181
+ attn: Attention,
182
+ hidden_states: torch.FloatTensor,
183
+ encoder_hidden_states: torch.FloatTensor = None,
184
+ attention_mask: Optional[torch.FloatTensor] = None,
185
+ image_rotary_emb: Optional[torch.Tensor] = None,
186
+ use_cond=False,
187
+ ) -> torch.FloatTensor:
188
+
189
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
190
+
191
+ # `context` projections.
192
+ inner_dim = 3072
193
+ head_dim = inner_dim // attn.heads
194
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
195
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
196
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
197
+
198
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
199
+ batch_size, -1, attn.heads, head_dim
200
+ ).transpose(1, 2)
201
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
202
+ batch_size, -1, attn.heads, head_dim
203
+ ).transpose(1, 2)
204
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
205
+ batch_size, -1, attn.heads, head_dim
206
+ ).transpose(1, 2)
207
+
208
+ if attn.norm_added_q is not None:
209
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
210
+ if attn.norm_added_k is not None:
211
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
212
+
213
+ query = attn.to_q(hidden_states)
214
+ key = attn.to_k(hidden_states)
215
+ value = attn.to_v(hidden_states)
216
+ for i in range(self.n_loras):
217
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
218
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
219
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
220
+
221
+ inner_dim = key.shape[-1]
222
+ head_dim = inner_dim // attn.heads
223
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
224
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
225
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
226
+
227
+ if attn.norm_q is not None:
228
+ query = attn.norm_q(query)
229
+ if attn.norm_k is not None:
230
+ key = attn.norm_k(key)
231
+
232
+ # attention
233
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
234
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
235
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
236
+
237
+ if image_rotary_emb is not None:
238
+ from diffusers.models.embeddings import apply_rotary_emb
239
+ query = apply_rotary_emb(query, image_rotary_emb)
240
+ key = apply_rotary_emb(key, image_rotary_emb)
241
+
242
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
243
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
244
+ scaled_cond_size = cond_size
245
+ scaled_seq_len = query.shape[2]
246
+ scaled_block_size = scaled_seq_len - cond_size * self.n_loras
247
+
248
+ num_cond_blocks = self.n_loras
249
+ # mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
250
+ # mask[ :scaled_block_size, :] = 0 # First block_size row
251
+ # for i in range(num_cond_blocks):
252
+ # start = i * scaled_cond_size + scaled_block_size
253
+ # end = (i + 1) * scaled_cond_size + scaled_block_size
254
+ # mask[start:end, start:end] = 0 # Diagonal blocks
255
+ # mask = mask * -1e20
256
+ # mask = mask.to(query.dtype)
257
+
258
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
259
+
260
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
261
+ hidden_states = hidden_states.to(query.dtype)
262
+
263
+ encoder_hidden_states, hidden_states = (
264
+ hidden_states[:, : encoder_hidden_states.shape[1]],
265
+ hidden_states[:, encoder_hidden_states.shape[1] :],
266
+ )
267
+
268
+ # Linear projection (with LoRA weight applied to each proj layer)
269
+ hidden_states = attn.to_out[0](hidden_states)
270
+ for i in range(self.n_loras):
271
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
272
+ # dropout
273
+ hidden_states = attn.to_out[1](hidden_states)
274
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
275
+
276
+ cond_hidden_states = hidden_states[:, block_size:,:]
277
+ hidden_states = hidden_states[:, :block_size,:]
278
+
279
+ return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
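The processor above runs one joint attention over the concatenated text and image streams and then splits the condition tokens (`cond_hidden_states`) back off the tail of the sequence. The LoRA injection itself is purely additive: each projection is the frozen weight applied to the input plus a weighted sum of low-rank updates, one per control branch. Below is a minimal, self-contained sketch of that additive pattern; the `LoRALinearLayer` here is a simplified stand-in for illustration, not the class this file imports, and the dimensions are placeholders.

    import torch
    import torch.nn as nn

    class LoRALinearLayer(nn.Module):
        """Hypothetical minimal LoRA layer: rank-r down/up projection with an alpha/rank scale."""
        def __init__(self, in_features, out_features, rank=4, alpha=None):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            self.scale = (alpha / rank) if alpha is not None else 1.0
            nn.init.normal_(self.down.weight, std=1.0 / rank)
            nn.init.zeros_(self.up.weight)  # the low-rank update starts at zero

        def forward(self, x):
            return self.up(self.down(x)) * self.scale

    dim, n_loras = 3072, 2                       # dim matches the Flux hidden size used above
    to_q = nn.Linear(dim, dim)                   # stand-in for the frozen attn.to_q projection
    q_loras = nn.ModuleList([LoRALinearLayer(dim, dim, rank=128) for _ in range(n_loras)])
    lora_weights = [1.0, 1.0]

    x = torch.randn(1, 16, dim)                  # (batch, tokens, dim)
    query = to_q(x)
    for i in range(n_loras):                     # additive injection, as in the processor above
        query = query + lora_weights[i] * q_loras[i](x)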
train/src/lora_helper.py ADDED
@@ -0,0 +1,196 @@
1
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
2
+ from safetensors import safe_open
3
+ import re
4
+ import torch
5
+ from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
6
+
7
+ device = "cuda"
8
+
9
+ def load_safetensors(path):
10
+ tensors = {}
11
+ with safe_open(path, framework="pt", device="cpu") as f:
12
+ for key in f.keys():
13
+ tensors[key] = f.get_tensor(key)
14
+ return tensors
15
+
16
+ def get_lora_rank(checkpoint):
17
+ for k in checkpoint.keys():
18
+ if k.endswith(".down.weight"):
19
+ return checkpoint[k].shape[0]
20
+
21
+ def load_checkpoint(local_path):
22
+ if local_path is not None:
23
+ if '.safetensors' in local_path:
24
+ print(f"Loading .safetensors checkpoint from {local_path}")
25
+ checkpoint = load_safetensors(local_path)
26
+ else:
27
+ print(f"Loading checkpoint from {local_path}")
28
+ checkpoint = torch.load(local_path, map_location='cpu')
29
+ return checkpoint
30
+
31
+ def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
32
+ number = len(lora_weights)
33
+ ranks = [get_lora_rank(checkpoint) for _ in range(number)]
34
+ lora_attn_procs = {}
35
+ double_blocks_idx = list(range(19))
36
+ single_blocks_idx = list(range(38))
37
+ for name, attn_processor in transformer.attn_processors.items():
38
+ match = re.search(r'\.(\d+)\.', name)
39
+ if match:
40
+ layer_index = int(match.group(1))
41
+
42
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
43
+
44
+ lora_state_dicts = {}
45
+ for key, value in checkpoint.items():
46
+ # Match based on the layer index in the key (assuming the key contains layer index)
47
+ if re.search(r'\.(\d+)\.', key):
48
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
49
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
50
+ lora_state_dicts[key] = value
51
+
52
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
53
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
54
+ )
55
+
56
+ # Load the weights from the checkpoint dictionary into the corresponding layers
57
+ for n in range(number):
58
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
59
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
60
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
61
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
62
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
63
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
64
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
65
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
66
+ lora_attn_procs[name].to(device)
67
+
68
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
69
+
70
+ lora_state_dicts = {}
71
+ for key, value in checkpoint.items():
72
+ # Match based on the layer index in the key (assuming the key contains layer index)
73
+ if re.search(r'\.(\d+)\.', key):
74
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
75
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
76
+ lora_state_dicts[key] = value
77
+
78
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
79
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
80
+ )
81
+ # Load the weights from the checkpoint dictionary into the corresponding layers
82
+ for n in range(number):
83
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
84
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
85
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
86
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
87
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
88
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
89
+ lora_attn_procs[name].to(device)
90
+ else:
91
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
92
+
93
+ transformer.set_attn_processor(lora_attn_procs)
94
+
95
+
96
+ def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
97
+ ck_number = len(checkpoints)
98
+ cond_lora_number = [len(ls) for ls in lora_weights]
99
+ cond_number = sum(cond_lora_number)
100
+ ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
101
+ multi_lora_weight = []
102
+ for ls in lora_weights:
103
+ for n in ls:
104
+ multi_lora_weight.append(n)
105
+
106
+ lora_attn_procs = {}
107
+ double_blocks_idx = list(range(19))
108
+ single_blocks_idx = list(range(38))
109
+ for name, attn_processor in transformer.attn_processors.items():
110
+ match = re.search(r'\.(\d+)\.', name)
111
+ if match:
112
+ layer_index = int(match.group(1))
113
+
114
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
115
+ lora_state_dicts = [{} for _ in range(ck_number)]
116
+ for idx, checkpoint in enumerate(checkpoints):
117
+ for key, value in checkpoint.items():
118
+ # Match based on the layer index in the key (assuming the key contains layer index)
119
+ if re.search(r'\.(\d+)\.', key):
120
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
121
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
122
+ lora_state_dicts[idx][key] = value
123
+
124
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
125
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
126
+ )
127
+
128
+ # Load the weights from the checkpoint dictionary into the corresponding layers
129
+ num = 0
130
+ for idx in range(ck_number):
131
+ for n in range(cond_lora_number[idx]):
132
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
133
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
134
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
135
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
136
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
137
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
138
+ lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
139
+ lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
140
+ lora_attn_procs[name].to(device)
141
+ num += 1
142
+
143
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
144
+
145
+ lora_state_dicts = [{} for _ in range(ck_number)]
146
+ for idx, checkpoint in enumerate(checkpoints):
147
+ for key, value in checkpoint.items():
148
+ # Match based on the layer index in the key (assuming the key contains layer index)
149
+ if re.search(r'\.(\d+)\.', key):
150
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
151
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
152
+ lora_state_dicts[idx][key] = value
153
+
154
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
155
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
156
+ )
157
+ # Load the weights from the checkpoint dictionary into the corresponding layers
158
+ num = 0
159
+ for idx in range(ck_number):
160
+ for n in range(cond_lora_number[idx]):
161
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
162
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
163
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
164
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
165
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
166
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
167
+ lora_attn_procs[name].to(device)
168
+ num += 1
169
+
170
+ else:
171
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
172
+
173
+ transformer.set_attn_processor(lora_attn_procs)
174
+
175
+
176
+ def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
177
+ checkpoint = load_checkpoint(local_path)
178
+ update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
179
+
180
+ def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
181
+ checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
182
+ update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
183
+
184
+ def unset_lora(transformer):
185
+ lora_attn_procs = {}
186
+ for name, attn_processor in transformer.attn_processors.items():
187
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
188
+ transformer.set_attn_processor(lora_attn_procs)
189
+
190
+
191
+ '''
192
+ unset_lora(pipe.transformer)
193
+ lora_path = "./lora.safetensors"
194
+ lora_weights = [1, 1]
195
+ set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
196
+ '''
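The helpers above rebuild the transformer's attention-processor dict: blocks whose index appears in `double_blocks_idx`/`single_blocks_idx` get a multi-LoRA processor with the checkpoint's `down`/`up` weights copied in, and every other block keeps the stock `FluxAttnProcessor2_0`. A hedged usage sketch follows; the checkpoint paths, weights, and the `pipe` object are placeholders, and the import assumes the `train/` directory is on `sys.path`.

    from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

    # One control LoRA: a single checkpoint with one weight per condition branch it contains.
    set_single_lora(pipe.transformer, "./checkpoints/edge.safetensors",
                    lora_weights=[1.0], cond_size=512)

    # Several control LoRAs: one checkpoint per control type, weights nested per checkpoint.
    set_multi_lora(pipe.transformer,
                   ["./checkpoints/edge.safetensors", "./checkpoints/color.safetensors"],
                   lora_weights=[[1.0], [0.8]], cond_size=512)

    # Restore the stock FluxAttnProcessor2_0 on every block.
    unset_lora(pipe.transformer)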
train/src/masks_integrated.py ADDED
@@ -0,0 +1,322 @@
1
+ import math
2
+ import random
3
+ import logging
4
+ from enum import Enum
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import random
9
+
10
+ LOGGER = logging.getLogger(__name__)
11
+
12
+ class LinearRamp:
13
+ def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
14
+ self.start_value = start_value
15
+ self.end_value = end_value
16
+ self.start_iter = start_iter
17
+ self.end_iter = end_iter
18
+
19
+ def __call__(self, i):
20
+ if i < self.start_iter:
21
+ return self.start_value
22
+ if i >= self.end_iter:
23
+ return self.end_value
24
+ part = (i - self.start_iter) / (self.end_iter - self.start_iter)
25
+ return self.start_value * (1 - part) + self.end_value * part
26
+
27
+ class DrawMethod(Enum):
28
+ LINE = 'line'
29
+ CIRCLE = 'circle'
30
+ SQUARE = 'square'
31
+
32
+ def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
33
+ draw_method=DrawMethod.LINE):
34
+ """Generate an irregular mask from angle/length-based line strokes."""
35
+ draw_method = DrawMethod(draw_method)
36
+
37
+ height, width = shape
38
+ mask = np.zeros((height, width), np.float32)
39
+ times = np.random.randint(min_times, max_times + 1)
40
+ for i in range(times):
41
+ start_x = np.random.randint(width)
42
+ start_y = np.random.randint(height)
43
+ for j in range(1 + np.random.randint(5)):
44
+ angle = 0.01 + np.random.randint(max_angle)
45
+ if i % 2 == 0:
46
+ angle = 2 * 3.1415926 - angle
47
+ length = 10 + np.random.randint(max_len)
48
+ brush_w = 5 + np.random.randint(max_width)
49
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
50
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
51
+ if draw_method == DrawMethod.LINE:
52
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
53
+ elif draw_method == DrawMethod.CIRCLE:
54
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
55
+ elif draw_method == DrawMethod.SQUARE:
56
+ radius = brush_w // 2
57
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
58
+ start_x, start_y = end_x, end_y
59
+ return mask[None, ...]
60
+
61
+
62
+ def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
63
+ """Generate a mask made of random rectangles."""
64
+ height, width = shape
65
+ mask = np.zeros((height, width), np.float32)
66
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
67
+ times = np.random.randint(min_times, max_times + 1)
68
+ for i in range(times):
69
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
70
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
71
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
72
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
73
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
74
+ return mask[None, ...]
75
+
76
+
77
+ def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
78
+ """Generate a regular grid mask in the style of super-resolution masks."""
79
+ height, width = shape
80
+ mask = np.zeros((height, width), np.float32)
81
+ step_x = np.random.randint(min_step, max_step + 1)
82
+ width_x = np.random.randint(min_width, min(step_x, max_width + 1))
83
+ offset_x = np.random.randint(0, step_x)
84
+
85
+ step_y = np.random.randint(min_step, max_step + 1)
86
+ width_y = np.random.randint(min_width, min(step_y, max_width + 1))
87
+ offset_y = np.random.randint(0, step_y)
88
+
89
+ for dy in range(width_y):
90
+ mask[offset_y + dy::step_y] = 1
91
+ for dx in range(width_x):
92
+ mask[:, offset_x + dx::step_x] = 1
93
+ return mask[None, ...]
94
+
95
+
96
+ def make_brush_stroke_mask(shape, num_strokes_range=(1, 5), stroke_width_range=(5, 30),
97
+ max_offset=50, num_points_range=(5, 15)):
98
+ """Generate a brush-stroke style mask: continuous strokes built from random offsets."""
99
+ num_strokes = random.randint(*num_strokes_range)
100
+ height, width = shape
101
+ mask = np.zeros((height, width), dtype=np.float32)
102
+
103
+ for _ in range(num_strokes):
104
+ # Random starting point
105
+ start_x = random.randint(0, width)
106
+ start_y = random.randint(0, height)
107
+
108
+ # Random stroke parameters
109
+ num_points = random.randint(*num_points_range)
110
+ stroke_width = random.randint(*stroke_width_range)
111
+
112
+ points = [(start_x, start_y)]
113
+
114
+ # Generate a chain of connected points
115
+ for i in range(num_points):
116
+ prev_x, prev_y = points[-1]
117
+ # Add a random offset
118
+ dx = random.randint(-max_offset, max_offset)
119
+ dy = random.randint(-max_offset, max_offset)
120
+ new_x = max(0, min(width, prev_x + dx))
121
+ new_y = max(0, min(height, prev_y + dy))
122
+ points.append((new_x, new_y))
123
+
124
+ # Draw the stroke
125
+ for i in range(len(points) - 1):
126
+ cv2.line(mask, points[i], points[i+1], 1.0, stroke_width)
127
+
128
+ return mask[None, ...]
129
+
130
+
131
+ class RandomIrregularMaskGenerator:
132
+ """Irregular mask generator."""
133
+ def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
134
+ draw_method=DrawMethod.LINE):
135
+ self.max_angle = max_angle
136
+ self.max_len = max_len
137
+ self.max_width = max_width
138
+ self.min_times = min_times
139
+ self.max_times = max_times
140
+ self.draw_method = draw_method
141
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
142
+
143
+ def __call__(self, img, iter_i=None, raw_image=None):
144
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
145
+ cur_max_len = int(max(1, self.max_len * coef))
146
+ cur_max_width = int(max(1, self.max_width * coef))
147
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
148
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
149
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
150
+ draw_method=self.draw_method)
151
+
152
+
153
+ class RandomRectangleMaskGenerator:
154
+ """Rectangle mask generator."""
155
+ def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
156
+ self.margin = margin
157
+ self.bbox_min_size = bbox_min_size
158
+ self.bbox_max_size = bbox_max_size
159
+ self.min_times = min_times
160
+ self.max_times = max_times
161
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
162
+
163
+ def __call__(self, img, iter_i=None, raw_image=None):
164
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
165
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
166
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
167
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
168
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
169
+ max_times=cur_max_times)
170
+
171
+
172
+ class RandomSuperresMaskGenerator:
173
+ """Super-resolution mask generator."""
174
+ def __init__(self, **kwargs):
175
+ self.kwargs = kwargs
176
+
177
+ def __call__(self, img, iter_i=None):
178
+ return make_random_superres_mask(img.shape[1:], **self.kwargs)
179
+
180
+
181
+ class BrushStrokeMaskGenerator:
182
+ """Brush-stroke mask generator."""
183
+ def __init__(self, num_strokes_range=(1, 5), stroke_width_range=(5, 30),
184
+ max_offset=50, num_points_range=(5, 15), ramp_kwargs=None):
185
+ self.num_strokes_range = num_strokes_range
186
+ self.stroke_width_range = stroke_width_range
187
+ self.max_offset = max_offset
188
+ self.num_points_range = num_points_range
189
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
190
+
191
+ def __call__(self, img, iter_i=None, raw_image=None):
192
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
193
+ cur_num_strokes = int(max(1, self.num_strokes_range[1] * coef))
194
+ cur_stroke_width_range = (
195
+ int(max(1, self.stroke_width_range[0] * coef)),
196
+ int(max(1, self.stroke_width_range[1] * coef))
197
+ )
198
+ return make_brush_stroke_mask(
199
+ img.shape[1:],
200
+ num_strokes_range=(cur_num_strokes, cur_num_strokes),
201
+ stroke_width_range=cur_stroke_width_range,
202
+ max_offset=self.max_offset,
203
+ num_points_range=self.num_points_range
204
+ )
205
+
206
+
207
+ class DumbAreaMaskGenerator:
208
+ """Simple area mask generator."""
209
+ min_ratio = 0.1
210
+ max_ratio = 0.35
211
+ default_ratio = 0.225
212
+
213
+ def __init__(self, is_training):
214
+ # Parameters:
215
+ # is_training(bool): If true - random rectangular mask, if false - central square mask
216
+ self.is_training = is_training
217
+
218
+ def _random_vector(self, dimension):
219
+ if self.is_training:
220
+ lower_limit = math.sqrt(self.min_ratio)
221
+ upper_limit = math.sqrt(self.max_ratio)
222
+ mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
223
+ u = random.randint(0, dimension-mask_side-1)
224
+ v = u+mask_side
225
+ else:
226
+ margin = (math.sqrt(self.default_ratio) / 2) * dimension
227
+ u = round(dimension/2 - margin)
228
+ v = round(dimension/2 + margin)
229
+ return u, v
230
+
231
+ def __call__(self, img, iter_i=None, raw_image=None):
232
+ c, height, width = img.shape
233
+ mask = np.zeros((height, width), np.float32)
234
+ x1, x2 = self._random_vector(width)
235
+ y1, y2 = self._random_vector(height)
236
+ mask[y1:y2, x1:x2] = 1  # rows are indexed by y (height), columns by x (width)
237
+ return mask[None, ...]
238
+
239
+
240
+ class IntegratedMaskGenerator:
241
+ """Integrated mask generator that mixes multiple mask types."""
242
+ def __init__(self, irregular_proba=1/4, irregular_kwargs=None,
243
+ box_proba=1/4, box_kwargs=None,
244
+ segm_proba=1/4, segm_kwargs=None,
245
+ brush_stroke_proba=1/4, brush_stroke_kwargs=None,
246
+ superres_proba=0, superres_kwargs=None,
247
+ squares_proba=0, squares_kwargs=None,
248
+ invert_proba=0):
249
+ self.probas = []
250
+ self.gens = []
251
+
252
+ if irregular_proba > 0:
253
+ self.probas.append(irregular_proba)
254
+ if irregular_kwargs is None:
255
+ irregular_kwargs = {}
256
+ else:
257
+ irregular_kwargs = dict(irregular_kwargs)
258
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
259
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
260
+
261
+ if box_proba > 0:
262
+ self.probas.append(box_proba)
263
+ if box_kwargs is None:
264
+ box_kwargs = {}
265
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
266
+
267
+ if brush_stroke_proba > 0:
268
+ self.probas.append(brush_stroke_proba)
269
+ if brush_stroke_kwargs is None:
270
+ brush_stroke_kwargs = {}
271
+ self.gens.append(BrushStrokeMaskGenerator(**brush_stroke_kwargs))
272
+
273
+ if superres_proba > 0:
274
+ self.probas.append(superres_proba)
275
+ if superres_kwargs is None:
276
+ superres_kwargs = {}
277
+ self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
278
+
279
+ if squares_proba > 0:
280
+ self.probas.append(squares_proba)
281
+ if squares_kwargs is None:
282
+ squares_kwargs = {}
283
+ else:
284
+ squares_kwargs = dict(squares_kwargs)
285
+ squares_kwargs['draw_method'] = DrawMethod.SQUARE
286
+ self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
287
+
288
+ self.probas = np.array(self.probas, dtype='float32')
289
+ self.probas /= self.probas.sum()
290
+ self.invert_proba = invert_proba
291
+
292
+ def __call__(self, img, iter_i=None, raw_image=None):
293
+ kind = np.random.choice(len(self.probas), p=self.probas)
294
+ gen = self.gens[kind]
295
+ result = gen(img, iter_i=iter_i, raw_image=raw_image)
296
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
297
+ result = 1 - result
298
+ return result
299
+
300
+
301
+ def get_mask_generator(kind, kwargs):
302
+ """Factory function that returns a mask generator."""
303
+ if kind is None:
304
+ kind = "integrated"
305
+ if kwargs is None:
306
+ kwargs = {}
307
+
308
+ if kind == "integrated":
309
+ cl = IntegratedMaskGenerator
310
+ elif kind == "irregular":
311
+ cl = RandomIrregularMaskGenerator
312
+ elif kind == "rectangle":
313
+ cl = RandomRectangleMaskGenerator
314
+ elif kind == "brush_stroke":
315
+ cl = BrushStrokeMaskGenerator
316
+ elif kind == "superres":
317
+ cl = RandomSuperresMaskGenerator
318
+ elif kind == "dumb":
319
+ cl = DumbAreaMaskGenerator
320
+ else:
321
+ raise NotImplementedError(f"No such generator kind = {kind}")
322
+ return cl(**kwargs)
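The factory at the end is the intended entry point: `get_mask_generator` maps a kind string to one of the generator classes, and `IntegratedMaskGenerator` then samples one mask type per call according to the normalized probabilities. A small usage sketch, with illustrative (not tuned) probabilities and a dummy input image; the import assumes the `train/` directory is on `sys.path`.

    import numpy as np
    from src.masks_integrated import get_mask_generator

    # Mix irregular lines, rectangles and brush strokes; occasionally invert the result.
    gen = get_mask_generator("integrated", {
        "irregular_proba": 0.3,
        "box_proba": 0.3,
        "brush_stroke_proba": 0.4,
        "invert_proba": 0.1,
    })

    img = np.zeros((3, 512, 512), dtype=np.float32)   # dummy CHW image
    mask = gen(img)                                   # shape (1, 512, 512), values in {0.0, 1.0}
    print(mask.shape, float(mask.mean()))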
train/src/pipeline_flux_kontext_control.py ADDED
@@ -0,0 +1,1009 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from .transformer_flux import FluxTransformer2DModel
7
+ from transformers import (
8
+ CLIPImageProcessor,
9
+ CLIPTextModel,
10
+ CLIPTokenizer,
11
+ CLIPVisionModelWithProjection,
12
+ T5EncoderModel,
13
+ T5TokenizerFast,
14
+ )
15
+
16
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
17
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
18
+ from diffusers.models import AutoencoderKL
19
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
20
+ from diffusers.utils import (
21
+ USE_PEFT_BACKEND,
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ scale_lora_layers,
26
+ unscale_lora_layers,
27
+ )
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
31
+ from torchvision.transforms.functional import pad
32
+
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ PREFERRED_KONTEXT_RESOLUTIONS = [
45
+ (672, 1568),
46
+ (688, 1504),
47
+ (720, 1456),
48
+ (752, 1392),
49
+ (800, 1328),
50
+ (832, 1248),
51
+ (880, 1184),
52
+ (944, 1104),
53
+ (1024, 1024),
54
+ (1104, 944),
55
+ (1184, 880),
56
+ (1248, 832),
57
+ (1328, 800),
58
+ (1392, 752),
59
+ (1456, 720),
60
+ (1504, 688),
61
+ (1568, 672),
62
+ ]
63
+
64
+
65
+ def calculate_shift(
66
+ image_seq_len,
67
+ base_seq_len: int = 256,
68
+ max_seq_len: int = 4096,
69
+ base_shift: float = 0.5,
70
+ max_shift: float = 1.15,
71
+ ):
72
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
73
+ b = base_shift - m * base_seq_len
74
+ mu = image_seq_len * m + b
75
+ return mu
76
+
77
+
78
+ def prepare_latent_image_ids_(height, width, device, dtype):
79
+ latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
80
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None] # y
81
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :] # x
82
+ return latent_image_ids
83
+
84
+
85
+ def prepare_latent_subject_ids(height, width, device, dtype):
86
+ latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
87
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
88
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
89
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
90
+ latent_image_ids = latent_image_ids.reshape(
91
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
92
+ )
93
+ return latent_image_ids.to(device=device, dtype=dtype)
94
+
95
+
96
+ def resize_position_encoding(
97
+ batch_size, original_height, original_width, target_height, target_width, device, dtype
98
+ ):
99
+ latent_image_ids = prepare_latent_image_ids_(original_height // 2, original_width // 2, device, dtype)
100
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
101
+ latent_image_ids = latent_image_ids.reshape(
102
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
103
+ )
104
+
105
+ scale_h = original_height / target_height
106
+ scale_w = original_width / target_width
107
+ latent_image_ids_resized = torch.zeros(target_height // 2, target_width // 2, 3, device=device, dtype=dtype)
108
+ latent_image_ids_resized[..., 1] = (
109
+ latent_image_ids_resized[..., 1] + torch.arange(target_height // 2, device=device)[:, None] * scale_h
110
+ )
111
+ latent_image_ids_resized[..., 2] = (
112
+ latent_image_ids_resized[..., 2] + torch.arange(target_width // 2, device=device)[None, :] * scale_w
113
+ )
114
+
115
+ cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = (
116
+ latent_image_ids_resized.shape
117
+ )
118
+ cond_latent_image_ids = latent_image_ids_resized.reshape(
119
+ cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
120
+ )
121
+ return latent_image_ids, cond_latent_image_ids
122
+
123
+
124
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
125
+ def retrieve_timesteps(
126
+ scheduler,
127
+ num_inference_steps: Optional[int] = None,
128
+ device: Optional[Union[str, torch.device]] = None,
129
+ timesteps: Optional[List[int]] = None,
130
+ sigmas: Optional[List[float]] = None,
131
+ **kwargs,
132
+ ):
133
+ r"""
134
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
135
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
136
+
137
+ Args:
138
+ scheduler (`SchedulerMixin`):
139
+ The scheduler to get timesteps from.
140
+ num_inference_steps (`int`):
141
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
142
+ must be `None`.
143
+ device (`str` or `torch.device`, *optional*):
144
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
145
+ timesteps (`List[int]`, *optional*):
146
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
147
+ `num_inference_steps` and `sigmas` must be `None`.
148
+ sigmas (`List[float]`, *optional*):
149
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
150
+ `num_inference_steps` and `timesteps` must be `None`.
151
+
152
+ Returns:
153
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
154
+ second element is the number of inference steps.
155
+ """
156
+ if timesteps is not None and sigmas is not None:
157
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
158
+ if timesteps is not None:
159
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
160
+ if not accepts_timesteps:
161
+ raise ValueError(
162
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
163
+ f" timestep schedules. Please check whether you are using the correct scheduler."
164
+ )
165
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
166
+ timesteps = scheduler.timesteps
167
+ num_inference_steps = len(timesteps)
168
+ elif sigmas is not None:
169
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
170
+ if not accept_sigmas:
171
+ raise ValueError(
172
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
173
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
174
+ )
175
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
176
+ timesteps = scheduler.timesteps
177
+ num_inference_steps = len(timesteps)
178
+ else:
179
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
180
+ timesteps = scheduler.timesteps
181
+ return timesteps, num_inference_steps
182
+
183
+
184
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
185
+ def retrieve_latents(
186
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
187
+ ):
188
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
189
+ return encoder_output.latent_dist.sample(generator)
190
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
191
+ return encoder_output.latent_dist.mode()
192
+ elif hasattr(encoder_output, "latents"):
193
+ return encoder_output.latents
194
+ else:
195
+ raise AttributeError("Could not access latents of provided encoder_output")
196
+
197
+
198
+ class FluxKontextControlPipeline(
199
+ DiffusionPipeline,
200
+ FluxLoraLoaderMixin,
201
+ FromSingleFileMixin,
202
+ TextualInversionLoaderMixin,
203
+ ):
204
+ r"""
205
+ The Flux Kontext pipeline for image-to-image and text-to-image generation with EasyControl.
206
+
207
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
208
+
209
+ Args:
210
+ transformer ([`FluxTransformer2DModel`]):
211
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
212
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
213
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
214
+ vae ([`AutoencoderKL`]):
215
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
216
+ text_encoder ([`CLIPTextModel`]):
217
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
218
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
219
+ text_encoder_2 ([`T5EncoderModel`]):
220
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
221
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
222
+ tokenizer (`CLIPTokenizer`):
223
+ Tokenizer of class
224
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
225
+ tokenizer_2 (`T5TokenizerFast`):
226
+ Second Tokenizer of class
227
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
228
+ """
229
+
230
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
231
+ _optional_components = []
232
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
233
+
234
+ def __init__(
235
+ self,
236
+ scheduler: FlowMatchEulerDiscreteScheduler,
237
+ vae: AutoencoderKL,
238
+ text_encoder: CLIPTextModel,
239
+ tokenizer: CLIPTokenizer,
240
+ text_encoder_2: T5EncoderModel,
241
+ tokenizer_2: T5TokenizerFast,
242
+ transformer: FluxTransformer2DModel,
243
+ image_encoder: CLIPVisionModelWithProjection = None,
244
+ feature_extractor: CLIPImageProcessor = None,
245
+ ):
246
+ super().__init__()
247
+
248
+ self.register_modules(
249
+ vae=vae,
250
+ text_encoder=text_encoder,
251
+ text_encoder_2=text_encoder_2,
252
+ tokenizer=tokenizer,
253
+ tokenizer_2=tokenizer_2,
254
+ transformer=transformer,
255
+ scheduler=scheduler,
256
+ image_encoder=None,
257
+ feature_extractor=None,
258
+ )
259
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
260
+ # Flux latents are packed into 2x2 patches, so use VAE factor multiplied by patch size for image processing
261
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
262
+ self.tokenizer_max_length = (
263
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
264
+ )
265
+ self.default_sample_size = 128
266
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
267
+ # EasyControl: cache multiple control LoRA processors
268
+ self.control_lora_processors: Dict[str, Dict[str, Any]] = {}
269
+ self.control_lora_cond_sizes: Dict[str, Any] = {}
270
+ self.current_control_type: Optional[str] = None
271
+
272
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
273
+ def _get_t5_prompt_embeds(
274
+ self,
275
+ prompt: Union[str, List[str]] = None,
276
+ num_images_per_prompt: int = 1,
277
+ max_sequence_length: int = 512,
278
+ device: Optional[torch.device] = None,
279
+ dtype: Optional[torch.dtype] = None,
280
+ ):
281
+ device = device or self._execution_device
282
+ dtype = dtype or self.text_encoder.dtype
283
+
284
+ prompt = [prompt] if isinstance(prompt, str) else prompt
285
+ batch_size = len(prompt)
286
+
287
+ if isinstance(self, TextualInversionLoaderMixin):
288
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
289
+
290
+ text_inputs = self.tokenizer_2(
291
+ prompt,
292
+ padding="max_length",
293
+ max_length=max_sequence_length,
294
+ truncation=True,
295
+ return_length=False,
296
+ return_overflowing_tokens=False,
297
+ return_tensors="pt",
298
+ )
299
+ text_input_ids = text_inputs.input_ids
300
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
301
+
302
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
303
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
304
+ logger.warning(
305
+ "The following part of your input was truncated because `max_sequence_length` is set to "
306
+ f" {max_sequence_length} tokens: {removed_text}"
307
+ )
308
+
309
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
310
+
311
+ dtype = self.text_encoder_2.dtype
312
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
313
+
314
+ _, seq_len, _ = prompt_embeds.shape
315
+
316
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
317
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
318
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
319
+
320
+ return prompt_embeds
321
+
322
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
323
+ def _get_clip_prompt_embeds(
324
+ self,
325
+ prompt: Union[str, List[str]],
326
+ num_images_per_prompt: int = 1,
327
+ device: Optional[torch.device] = None,
328
+ ):
329
+ device = device or self._execution_device
330
+
331
+ prompt = [prompt] if isinstance(prompt, str) else prompt
332
+ batch_size = len(prompt)
333
+
334
+ if isinstance(self, TextualInversionLoaderMixin):
335
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
336
+
337
+ text_inputs = self.tokenizer(
338
+ prompt,
339
+ padding="max_length",
340
+ max_length=self.tokenizer_max_length,
341
+ truncation=True,
342
+ return_overflowing_tokens=False,
343
+ return_length=False,
344
+ return_tensors="pt",
345
+ )
346
+
347
+ text_input_ids = text_inputs.input_ids
348
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
349
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
350
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
351
+ logger.warning(
352
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
353
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
354
+ )
355
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
356
+
357
+ # Use pooled output of CLIPTextModel
358
+ prompt_embeds = prompt_embeds.pooler_output
359
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
360
+
361
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
362
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
363
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
364
+
365
+ return prompt_embeds
366
+
367
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
368
+ def encode_prompt(
369
+ self,
370
+ prompt: Union[str, List[str]],
371
+ prompt_2: Union[str, List[str]],
372
+ device: Optional[torch.device] = None,
373
+ num_images_per_prompt: int = 1,
374
+ prompt_embeds: Optional[torch.FloatTensor] = None,
375
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
376
+ max_sequence_length: int = 512,
377
+ lora_scale: Optional[float] = None,
378
+ ):
379
+ r"""
380
+
381
+ Args:
382
+ prompt (`str` or `List[str]`, *optional*):
383
+ prompt to be encoded
384
+ prompt_2 (`str` or `List[str]`, *optional*):
385
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
386
+ used in all text-encoders
387
+ device: (`torch.device`):
388
+ torch device
389
+ num_images_per_prompt (`int`):
390
+ number of images that should be generated per prompt
391
+ prompt_embeds (`torch.FloatTensor`, *optional*):
392
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
393
+ provided, text embeddings will be generated from `prompt` input argument.
394
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
395
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
396
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
397
+ lora_scale (`float`, *optional*):
398
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
399
+ """
400
+ device = device or self._execution_device
401
+
402
+ # set lora scale so that monkey patched LoRA
403
+ # function of text encoder can correctly access it
404
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
405
+ self._lora_scale = lora_scale
406
+
407
+ # dynamically adjust the LoRA scale
408
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
409
+ scale_lora_layers(self.text_encoder, lora_scale)
410
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
411
+ scale_lora_layers(self.text_encoder_2, lora_scale)
412
+
413
+ prompt = [prompt] if isinstance(prompt, str) else prompt
414
+
415
+ if prompt_embeds is None:
416
+ prompt_2 = prompt_2 or prompt
417
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
418
+
419
+ # We only use the pooled prompt output from the CLIPTextModel
420
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
421
+ prompt=prompt,
422
+ device=device,
423
+ num_images_per_prompt=num_images_per_prompt,
424
+ )
425
+ prompt_embeds = self._get_t5_prompt_embeds(
426
+ prompt=prompt_2,
427
+ num_images_per_prompt=num_images_per_prompt,
428
+ max_sequence_length=max_sequence_length,
429
+ device=device,
430
+ )
431
+
432
+ if self.text_encoder is not None:
433
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
434
+ # Retrieve the original scale by scaling back the LoRA layers
435
+ unscale_lora_layers(self.text_encoder, lora_scale)
436
+
437
+ if self.text_encoder_2 is not None:
438
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
439
+ # Retrieve the original scale by scaling back the LoRA layers
440
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
441
+
442
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
443
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
444
+
445
+ return prompt_embeds, pooled_prompt_embeds, text_ids
446
+
447
+ # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
448
+ def check_inputs(
449
+ self,
450
+ prompt,
451
+ prompt_2,
452
+ height,
453
+ width,
454
+ prompt_embeds=None,
455
+ pooled_prompt_embeds=None,
456
+ callback_on_step_end_tensor_inputs=None,
457
+ max_sequence_length=None,
458
+ ):
459
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
460
+ raise ValueError(
461
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
462
+ )
463
+
464
+ if callback_on_step_end_tensor_inputs is not None and not all(
465
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
466
+ ):
467
+ raise ValueError(
468
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
469
+ )
470
+
471
+ if prompt is not None and prompt_embeds is not None:
472
+ raise ValueError(
473
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
474
+ " only forward one of the two."
475
+ )
476
+ elif prompt_2 is not None and prompt_embeds is not None:
477
+ raise ValueError(
478
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
479
+ " only forward one of the two."
480
+ )
481
+ elif prompt is None and prompt_embeds is None:
482
+ raise ValueError(
483
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
484
+ )
485
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
486
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
487
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
488
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
489
+
490
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
491
+ raise ValueError(
492
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
493
+ )
494
+
495
+ if max_sequence_length is not None and max_sequence_length > 512:
496
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
497
+
498
+ @staticmethod
499
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
500
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
501
+ latent_image_ids = torch.zeros(height, width, 3)
502
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
503
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
504
+
505
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
506
+
507
+ latent_image_ids = latent_image_ids.reshape(
508
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
509
+ )
510
+
511
+ return latent_image_ids.to(device=device, dtype=dtype)
512
+
513
+ @staticmethod
514
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
515
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
516
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
517
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
518
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
519
+
520
+ return latents
521
+
522
+ @staticmethod
523
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
524
+ def _unpack_latents(latents, height, width, vae_scale_factor):
525
+ batch_size, num_patches, channels = latents.shape
526
+
527
+ # VAE applies 8x compression on images but we must also account for packing which requires
528
+ # latent height and width to be divisible by 2.
529
+ height = 2 * (int(height) // (vae_scale_factor * 2))
530
+ width = 2 * (int(width) // (vae_scale_factor * 2))
531
+
532
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
533
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
534
+
535
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
536
+
537
+ return latents
538
+
539
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
540
+ if isinstance(generator, list):
541
+ image_latents = [
542
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
543
+ for i in range(image.shape[0])
544
+ ]
545
+ image_latents = torch.cat(image_latents, dim=0)
546
+ else:
547
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
548
+
549
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
550
+
551
+ return image_latents
552
+
553
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
554
+ def enable_vae_slicing(self):
555
+ r"""
556
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
557
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
558
+ """
559
+ self.vae.enable_slicing()
560
+
561
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
562
+ def disable_vae_slicing(self):
563
+ r"""
564
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
565
+ computing decoding in one step.
566
+ """
567
+ self.vae.disable_slicing()
568
+
569
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
570
+ def enable_vae_tiling(self):
571
+ r"""
572
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
573
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
574
+ processing larger images.
575
+ """
576
+ self.vae.enable_tiling()
577
+
578
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
579
+ def disable_vae_tiling(self):
580
+ r"""
581
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
582
+ computing decoding in one step.
583
+ """
584
+ self.vae.disable_tiling()
585
+
586
+ def prepare_latents(
587
+ self,
588
+ batch_size,
589
+ num_channels_latents,
590
+ height,
591
+ width,
592
+ dtype,
593
+ device,
594
+ generator,
595
+ image,
596
+ subject_images,
597
+ spatial_images,
598
+ latents=None,
599
+ cond_size=512,
600
+ ):
601
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
602
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
603
+ height_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
604
+ width_cond = 2 * (cond_size // (self.vae_scale_factor * 2))
605
+
606
+ image_latents = image_ids = None
607
+ # Prepare noise latents
608
+ shape = (batch_size, num_channels_latents, height, width)
609
+ if latents is None:
610
+ noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
611
+ else:
612
+ noise_latents = latents.to(device=device, dtype=dtype)
613
+
614
+ noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
615
+ # print(noise_latents.shape)
616
+ noise_latent_image_ids, cond_latent_image_ids_resized = resize_position_encoding(
617
+ batch_size, height, width, height_cond, width_cond, device, dtype
618
+ )
619
+ # noise IDs are marked with 0 in the first channel
620
+ noise_latent_image_ids[..., 0] = 0
621
+
622
+ cond_latents_to_concat = []
623
+ latents_ids_to_concat = [noise_latent_image_ids]
624
+
625
+ # 1. Prepare `image` (Kontext) latents
626
+ if image is not None:
627
+ image = image.to(device=device, dtype=dtype)
628
+ if image.shape[1] != self.latent_channels:
629
+ image_latents = self._encode_vae_image(image=image, generator=generator)
630
+ else:
631
+ image_latents = image
632
+
633
+ image_latent_h, image_latent_w = image_latents.shape[2:]
634
+ image_latents = self._pack_latents(
635
+ image_latents, batch_size, num_channels_latents, image_latent_h, image_latent_w
636
+ )
637
+ image_ids = self._prepare_latent_image_ids(
638
+ batch_size, image_latent_h // 2, image_latent_w // 2, device, dtype
639
+ )
640
+ image_ids[..., 0] = 1 # Mark as condition
641
+ latents_ids_to_concat.append(image_ids)
642
+
643
+ # 2. Prepare `subject_images` latents
644
+ if subject_images is not None and len(subject_images) > 0:
645
+ subject_images = subject_images.to(device=device, dtype=dtype)
646
+ subject_image_latents = self._encode_vae_image(image=subject_images, generator=generator)
647
+ subject_latents = self._pack_latents(
648
+ subject_image_latents, batch_size, num_channels_latents, height_cond * len(subject_images), width_cond
649
+ )
650
+
651
+ latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, device, dtype)
652
+ latent_subject_ids[..., 0] = 1
653
+ latent_subject_ids[:, 1] += image_latent_h // 2
654
+ subject_latent_image_ids = torch.cat([latent_subject_ids for _ in range(len(subject_images))], dim=0)
655
+
656
+ cond_latents_to_concat.append(subject_latents)
657
+ latents_ids_to_concat.append(subject_latent_image_ids)
658
+
659
+ # 3. Prepare `spatial_images` latents
660
+ if spatial_images is not None and len(spatial_images) > 0:
661
+ spatial_images = spatial_images.to(device=device, dtype=dtype)
662
+ spatial_image_latents = self._encode_vae_image(image=spatial_images, generator=generator)
663
+ cond_latents = self._pack_latents(
664
+ spatial_image_latents, batch_size, num_channels_latents, height_cond * len(spatial_images), width_cond
665
+ )
666
+ cond_latent_image_ids_resized[..., 0] = 2
667
+ cond_latent_image_ids = torch.cat(
668
+ [cond_latent_image_ids_resized for _ in range(len(spatial_images))], dim=0
669
+ )
670
+
671
+ cond_latents_to_concat.append(cond_latents)
672
+ latents_ids_to_concat.append(cond_latent_image_ids)
673
+
674
+ cond_latents = torch.cat(cond_latents_to_concat, dim=1) if cond_latents_to_concat else None
675
+ latent_image_ids = torch.cat(latents_ids_to_concat, dim=0)
676
+
677
+ return noise_latents, image_latents, cond_latents, latent_image_ids
678
+
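`prepare_latents` keeps the streams distinguishable through the first channel of their positional IDs: 0 for the noise latents, 1 for the Kontext image and subject latents, and 2 for the resized spatial-condition latents. The toy sketch below illustrates only that tagging scheme; it does not reproduce the library helpers (`resize_position_encoding`, `prepare_latent_subject_ids`), and all shapes are illustrative:

import torch

def make_ids(h, w, stream_type):
    # Build (type, row, col) IDs for an h x w latent grid, flattened to tokens.
    ids = torch.zeros(h, w, 3)
    ids[..., 0] = stream_type                # 0 = noise, 1 = image/subject, 2 = spatial condition
    ids[..., 1] = torch.arange(h)[:, None]   # row position
    ids[..., 2] = torch.arange(w)[None, :]   # column position
    return ids.reshape(h * w, 3)

noise_ids   = make_ids(64, 64, stream_type=0)
image_ids   = make_ids(64, 64, stream_type=1)
spatial_ids = make_ids(32, 32, stream_type=2)
latent_image_ids = torch.cat([noise_ids, image_ids, spatial_ids], dim=0)
print(latent_image_ids.shape)                # torch.Size([9216, 3])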
679
+ @property
680
+ def guidance_scale(self):
681
+ return self._guidance_scale
682
+
683
+ @property
684
+ def joint_attention_kwargs(self):
685
+ return self._joint_attention_kwargs
686
+
687
+ @property
688
+ def num_timesteps(self):
689
+ return self._num_timesteps
690
+
691
+ @property
692
+ def current_timestep(self):
693
+ return self._current_timestep
694
+
695
+ @property
696
+ def interrupt(self):
697
+ return self._interrupt
698
+
699
+ @torch.no_grad()
700
+ def __call__(
701
+ self,
702
+ image: Optional[PipelineImageInput] = None,
703
+ prompt: Union[str, List[str]] = None,
704
+ prompt_2: Optional[Union[str, List[str]]] = None,
705
+ height: Optional[int] = None,
706
+ width: Optional[int] = None,
707
+ num_inference_steps: int = 28,
708
+ sigmas: Optional[List[float]] = None,
709
+ guidance_scale: float = 3.5,
710
+ num_images_per_prompt: Optional[int] = 1,
711
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
712
+ latents: Optional[torch.FloatTensor] = None,
713
+ prompt_embeds: Optional[torch.FloatTensor] = None,
714
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
715
+ output_type: Optional[str] = "pil",
716
+ return_dict: bool = True,
717
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
718
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
719
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
720
+ max_sequence_length: int = 512,
721
+ cond_size: int = 512,
722
+ control_dict: Optional[Dict[str, Any]] = None,
723
+ ):
724
+ r"""
725
+ Function invoked when calling the pipeline for generation.
726
+
727
+ Args:
728
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
729
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
730
+ numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's a tensor or a list
731
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
732
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
733
+ latents as `image`; latents passed directly are not encoded again.
734
+ prompt (`str` or `List[str]`, *optional*):
735
+ The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed
736
+ instead.
737
+ prompt_2 (`str` or `List[str]`, *optional*):
738
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
739
+ will be used instead.
740
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
741
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
742
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
743
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
744
+ num_inference_steps (`int`, *optional*, defaults to 28):
745
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
746
+ expense of slower inference.
747
+ sigmas (`List[float]`, *optional*):
748
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
749
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
750
+ will be used.
751
+ guidance_scale (`float`, *optional*, defaults to 3.5):
752
+ Guidance scale as defined in [Classifier-Free Diffusion
753
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
754
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
755
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
756
+ the text `prompt`, usually at the expense of lower image quality.
757
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
758
+ The number of images to generate per prompt.
759
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
760
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
761
+ to make generation deterministic.
762
+ latents (`torch.FloatTensor`, *optional*):
763
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
764
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
765
+ tensor will be generated by sampling using the supplied random `generator`.
766
+ prompt_embeds (`torch.FloatTensor`, *optional*):
767
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
768
+ provided, text embeddings will be generated from `prompt` input argument.
769
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
770
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
771
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
772
+ output_type (`str`, *optional*, defaults to `"pil"`):
773
+ The output format of the generated image. Choose between
774
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
775
+ return_dict (`bool`, *optional*, defaults to `True`):
776
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
777
+ joint_attention_kwargs (`dict`, *optional*):
778
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
779
+ `self.processor` in
780
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
781
+ callback_on_step_end (`Callable`, *optional*):
782
+ A function that is called at the end of each denoising step during inference. The function is called
783
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
784
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
785
+ `callback_on_step_end_tensor_inputs`.
786
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
787
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
788
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
789
+ `._callback_tensor_inputs` attribute of your pipeline class.
790
+ max_sequence_length (`int`, *optional*, defaults to 512):
791
+ Maximum sequence length to use with the `prompt`.
792
+ cond_size (`int`, *optional*, defaults to 512):
793
+ The size for conditioning images.
794
+
795
+ Examples:
796
+
797
+ Returns:
798
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
799
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
800
+ images.
801
+ """
802
+
803
+ height = height or self.default_sample_size * self.vae_scale_factor
804
+ width = width or self.default_sample_size * self.vae_scale_factor
805
+
806
+ # 1. Check inputs. Raise error if not correct
807
+ self.check_inputs(
808
+ prompt,
809
+ prompt_2,
810
+ height,
811
+ width,
812
+ prompt_embeds=prompt_embeds,
813
+ pooled_prompt_embeds=pooled_prompt_embeds,
814
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
815
+ max_sequence_length=max_sequence_length,
816
+ )
817
+
818
+ self._guidance_scale = guidance_scale
819
+ self._joint_attention_kwargs = joint_attention_kwargs
820
+ self._current_timestep = None
821
+ self._interrupt = False
822
+
823
+ spatial_images = control_dict.get("spatial_images", [])
824
+ subject_images = control_dict.get("subject_images", [])
825
+
826
+ # 2. Define call parameters
827
+ if prompt is not None and isinstance(prompt, str):
828
+ batch_size = 1
829
+ elif prompt is not None and isinstance(prompt, list):
830
+ batch_size = len(prompt)
831
+ else:
832
+ batch_size = prompt_embeds.shape[0]
833
+
834
+ device = self._execution_device
835
+
836
+ lora_scale = (
837
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
838
+ )
839
+ (
840
+ prompt_embeds,
841
+ pooled_prompt_embeds,
842
+ text_ids,
843
+ ) = self.encode_prompt(
844
+ prompt=prompt,
845
+ prompt_2=prompt_2,
846
+ prompt_embeds=prompt_embeds,
847
+ pooled_prompt_embeds=pooled_prompt_embeds,
848
+ device=device,
849
+ num_images_per_prompt=num_images_per_prompt,
850
+ max_sequence_length=max_sequence_length,
851
+ lora_scale=lora_scale,
852
+ )
853
+
854
+ # 3. Preprocess images
855
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
856
+ img = image[0] if isinstance(image, list) else image
857
+ image_height, image_width = self.image_processor.get_default_height_width(img)
858
+ aspect_ratio = image_width / image_height
859
+ # Kontext is trained on specific resolutions, using one of them is recommended
860
+ _, image_width, image_height = min(
861
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
862
+ )
863
+ multiple_of = self.vae_scale_factor * 2
864
+ image_width = image_width // multiple_of * multiple_of
865
+ image_height = image_height // multiple_of * multiple_of
866
+ image = self.image_processor.resize(image, image_height, image_width)
867
+ image = self.image_processor.preprocess(image, image_height, image_width)
868
+ image = image.to(dtype=self.vae.dtype)
869
+
870
+ if len(subject_images) > 0:
871
+ subject_image_ls = []
872
+ for subject_image in subject_images:
873
+ w, h = subject_image.size[:2]
874
+ scale = cond_size / max(h, w)
875
+ new_h, new_w = int(h * scale), int(w * scale)
876
+ subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
877
+ subject_image = subject_image.to(dtype=self.vae.dtype)
878
+ pad_h = cond_size - subject_image.shape[-2]
879
+ pad_w = cond_size - subject_image.shape[-1]
880
+ subject_image = pad(
881
+ subject_image, padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)), fill=0
882
+ )
883
+ subject_image_ls.append(subject_image)
884
+ subject_images = torch.cat(subject_image_ls, dim=-2)
885
+ else:
886
+ subject_images = None
887
+
888
+ if len(spatial_images) > 0:
889
+ condition_image_ls = []
890
+ for img in spatial_images:
891
+ condition_image = self.image_processor.preprocess(img, height=cond_size, width=cond_size)
892
+ condition_image = condition_image.to(dtype=self.vae.dtype)
893
+ condition_image_ls.append(condition_image)
894
+ spatial_images = torch.cat(condition_image_ls, dim=-2)
895
+ else:
896
+ spatial_images = None
897
+
898
+ # 4. Prepare latent variables
899
+ num_channels_latents = self.transformer.config.in_channels // 4
900
+ latents, image_latents, cond_latents, latent_image_ids = self.prepare_latents(
901
+ batch_size * num_images_per_prompt,
902
+ num_channels_latents,
903
+ height,
904
+ width,
905
+ prompt_embeds.dtype,
906
+ device,
907
+ generator,
908
+ image,
909
+ subject_images,
910
+ spatial_images,
911
+ latents,
912
+ cond_size,
913
+ )
914
+
915
+ # 5. Prepare timesteps
916
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
917
+ image_seq_len = latents.shape[1]
918
+ mu = calculate_shift(
919
+ image_seq_len,
920
+ self.scheduler.config.get("base_image_seq_len", 256),
921
+ self.scheduler.config.get("max_image_seq_len", 4096),
922
+ self.scheduler.config.get("base_shift", 0.5),
923
+ self.scheduler.config.get("max_shift", 1.15),
924
+ )
925
+ timesteps, num_inference_steps = retrieve_timesteps(
926
+ self.scheduler,
927
+ num_inference_steps,
928
+ device,
929
+ sigmas=sigmas,
930
+ mu=mu,
931
+ )
932
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
933
+ self._num_timesteps = len(timesteps)
934
+
935
+ # handle guidance
936
+ if self.transformer.config.guidance_embeds:
937
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
938
+ guidance = guidance.expand(latents.shape[0])
939
+ else:
940
+ guidance = None
941
+
942
+ # 6. Denoising loop
943
+ self.scheduler.set_begin_index(0)
944
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
945
+ for i, t in enumerate(timesteps):
946
+ if self.interrupt:
947
+ continue
948
+
949
+ latent_model_input = latents
950
+ if image_latents is not None:
951
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
952
+
953
+ self._current_timestep = t
954
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
955
+ noise_pred = self.transformer(
956
+ hidden_states=latent_model_input,
957
+ cond_hidden_states=cond_latents,
958
+ timestep=timestep / 1000,
959
+ guidance=guidance,
960
+ pooled_projections=pooled_prompt_embeds,
961
+ encoder_hidden_states=prompt_embeds,
962
+ txt_ids=text_ids,
963
+ img_ids=latent_image_ids,
964
+ joint_attention_kwargs=self.joint_attention_kwargs,
965
+ return_dict=False,
966
+ )[0]
967
+
968
+ noise_pred = noise_pred[:, : latents.size(1)]
969
+
970
+ # compute the previous noisy sample x_t -> x_t-1
971
+ latents_dtype = latents.dtype
972
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
973
+
974
+ if latents.dtype != latents_dtype:
975
+ latents = latents.to(latents_dtype)
976
+
977
+ if callback_on_step_end is not None:
978
+ callback_kwargs = {}
979
+ for k in callback_on_step_end_tensor_inputs:
980
+ callback_kwargs[k] = locals()[k]
981
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
982
+
983
+ latents = callback_outputs.pop("latents", latents)
984
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
985
+
986
+ # call the callback, if provided
987
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
988
+ progress_bar.update()
989
+
990
+ if XLA_AVAILABLE:
991
+ xm.mark_step()
992
+
993
+ self._current_timestep = None
994
+
995
+ if output_type == "latent":
996
+ image = latents
997
+ else:
998
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
999
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1000
+ image = self.vae.decode(latents, return_dict=False)[0]
1001
+ image = self.image_processor.postprocess(image, output_type=output_type)
1002
+
1003
+ # Offload all models
1004
+ self.maybe_free_model_hooks()
1005
+
1006
+ if not return_dict:
1007
+ return (image,)
1008
+
1009
+ return FluxPipelineOutput(images=image)
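For reference, a minimal invocation sketch of the call signature above. The class name matches the import used by the training scripts (`FluxKontextControlPipeline`); the checkpoint id, the `from_pretrained` loading path, and the input file names are assumptions for illustration, not taken from this commit:

import torch
from PIL import Image
from src.pipeline_flux_kontext_control import FluxKontextControlPipeline

# Assumed base checkpoint and loading path; LoRA/control weights would be attached separately.
pipe = FluxKontextControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

source = Image.open("input.png").convert("RGB")       # image to edit
edge_map = Image.open("edge.png").convert("RGB")      # hypothetical spatial condition (e.g. an edge map)

result = pipe(
    image=source,
    prompt="add a red scarf to the cat",
    num_inference_steps=28,
    guidance_scale=3.5,
    cond_size=512,
    control_dict={"spatial_images": [edge_map], "subject_images": []},
).images[0]
result.save("output.png")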
train/src/prompt_helper.py ADDED
@@ -0,0 +1,205 @@
1
+ import torch
2
+
3
+
4
+ def load_text_encoders(args, class_one, class_two):
5
+ text_encoder_one = class_one.from_pretrained(
6
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
7
+ )
8
+ text_encoder_two = class_two.from_pretrained(
9
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
10
+ )
11
+ return text_encoder_one, text_encoder_two
12
+
13
+
14
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
15
+ text_inputs = tokenizer(
16
+ prompt,
17
+ padding="max_length",
18
+ max_length=max_sequence_length,
19
+ truncation=True,
20
+ return_length=False,
21
+ return_overflowing_tokens=False,
22
+ return_tensors="pt",
23
+ )
24
+ text_input_ids = text_inputs.input_ids
25
+ return text_input_ids
26
+
27
+
28
+ def tokenize_prompt_clip(tokenizer, prompt):
29
+ text_inputs = tokenizer(
30
+ prompt,
31
+ padding="max_length",
32
+ max_length=77,
33
+ truncation=True,
34
+ return_length=False,
35
+ return_overflowing_tokens=False,
36
+ return_tensors="pt",
37
+ )
38
+ text_input_ids = text_inputs.input_ids
39
+ return text_input_ids
40
+
41
+
42
+ def tokenize_prompt_t5(tokenizer, prompt):
43
+ text_inputs = tokenizer(
44
+ prompt,
45
+ padding="max_length",
46
+ max_length=512,
47
+ truncation=True,
48
+ return_length=False,
49
+ return_overflowing_tokens=False,
50
+ return_tensors="pt",
51
+ )
52
+ text_input_ids = text_inputs.input_ids
53
+ return text_input_ids
54
+
55
+
56
+ def _encode_prompt_with_t5(
57
+ text_encoder,
58
+ tokenizer,
59
+ max_sequence_length=512,
60
+ prompt=None,
61
+ num_images_per_prompt=1,
62
+ device=None,
63
+ text_input_ids=None,
64
+ ):
65
+ prompt = [prompt] if isinstance(prompt, str) else prompt
66
+ batch_size = len(prompt)
67
+
68
+ if tokenizer is not None:
69
+ text_inputs = tokenizer(
70
+ prompt,
71
+ padding="max_length",
72
+ max_length=max_sequence_length,
73
+ truncation=True,
74
+ return_length=False,
75
+ return_overflowing_tokens=False,
76
+ return_tensors="pt",
77
+ )
78
+ text_input_ids = text_inputs.input_ids
79
+ else:
80
+ if text_input_ids is None:
81
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
82
+
83
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
84
+
85
+ dtype = text_encoder.dtype
86
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87
+
88
+ _, seq_len, _ = prompt_embeds.shape
89
+
90
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93
+
94
+ return prompt_embeds
95
+
96
+
97
+ def _encode_prompt_with_clip(
98
+ text_encoder,
99
+ tokenizer,
100
+ prompt: str,
101
+ device=None,
102
+ text_input_ids=None,
103
+ num_images_per_prompt: int = 1,
104
+ ):
105
+ prompt = [prompt] if isinstance(prompt, str) else prompt
106
+ batch_size = len(prompt)
107
+
108
+ if tokenizer is not None:
109
+ text_inputs = tokenizer(
110
+ prompt,
111
+ padding="max_length",
112
+ max_length=77,
113
+ truncation=True,
114
+ return_overflowing_tokens=False,
115
+ return_length=False,
116
+ return_tensors="pt",
117
+ )
118
+
119
+ text_input_ids = text_inputs.input_ids
120
+ else:
121
+ if text_input_ids is None:
122
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
123
+
124
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
125
+
126
+ # Use pooled output of CLIPTextModel
127
+ prompt_embeds = prompt_embeds.pooler_output
128
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
129
+
130
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
131
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
132
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
133
+
134
+ return prompt_embeds
135
+
136
+
137
+ def encode_prompt(
138
+ text_encoders,
139
+ tokenizers,
140
+ prompt: str,
141
+ max_sequence_length,
142
+ device=None,
143
+ num_images_per_prompt: int = 1,
144
+ text_input_ids_list=None,
145
+ ):
146
+ prompt = [prompt] if isinstance(prompt, str) else prompt
147
+ dtype = text_encoders[0].dtype
148
+
149
+ pooled_prompt_embeds = _encode_prompt_with_clip(
150
+ text_encoder=text_encoders[0],
151
+ tokenizer=tokenizers[0],
152
+ prompt=prompt,
153
+ device=device if device is not None else text_encoders[0].device,
154
+ num_images_per_prompt=num_images_per_prompt,
155
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
156
+ )
157
+
158
+ prompt_embeds = _encode_prompt_with_t5(
159
+ text_encoder=text_encoders[1],
160
+ tokenizer=tokenizers[1],
161
+ max_sequence_length=max_sequence_length,
162
+ prompt=prompt,
163
+ num_images_per_prompt=num_images_per_prompt,
164
+ device=device if device is not None else text_encoders[1].device,
165
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
166
+ )
167
+
168
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
169
+
170
+ return prompt_embeds, pooled_prompt_embeds, text_ids
171
+
172
+
173
+ def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
174
+ text_encoder_clip = text_encoders[0]
175
+ text_encoder_t5 = text_encoders[1]
176
+ tokens_clip, tokens_t5 = tokens[0], tokens[1]
177
+ batch_size = tokens_clip.shape[0]
178
+
179
+ # Keep "cpu" when explicitly requested; otherwise use the accelerator device
180
+ device = accelerator.device if device != "cpu" else "cpu"
183
+
184
+ # clip
185
+ prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
186
+ # Use pooled output of CLIPTextModel
187
+ prompt_embeds = prompt_embeds.pooler_output
188
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
189
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
190
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
+ pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
192
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
193
+
194
+ # t5
195
+ prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
196
+ dtype = text_encoder_t5.dtype
197
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
198
+ _, seq_len, _ = prompt_embeds.shape
199
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
200
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
201
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
202
+
203
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
204
+
205
+ return prompt_embeds, pooled_prompt_embeds, text_ids
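These helpers mirror Flux's dual text-encoder setup: CLIP supplies the pooled embedding, T5 the token-level sequence embedding. A hedged wiring sketch follows (the base checkpoint id and subfolder layout are assumptions, and the import path assumes the script is run from `train/` as in the training entry points):

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from src.prompt_helper import encode_prompt

base = "black-forest-labs/FLUX.1-Kontext-dev"   # assumed checkpoint with the usual Flux subfolders
tok_clip = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tok_t5 = T5TokenizerFast.from_pretrained(base, subfolder="tokenizer_2")
enc_clip = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
enc_t5 = T5EncoderModel.from_pretrained(base, subfolder="text_encoder_2")

prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[enc_clip, enc_t5],
    tokenizers=[tok_clip, tok_t5],
    prompt="a watercolor fox",
    max_sequence_length=512,
    device="cpu",
)
# prompt_embeds: [1, 512, hidden dim], pooled_prompt_embeds: [1, 768], text_ids: [512, 3]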
train/src/transformer_flux.py ADDED
@@ -0,0 +1,625 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ FluxAttnProcessor2_0,
15
+ FluxAttnProcessor2_0_NPU,
16
+ FusedFluxAttnProcessor2_0,
17
+ )
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
21
+ from diffusers.utils.import_utils import is_torch_npu_available
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+ @maybe_allow_in_graph
29
+ class FluxSingleTransformerBlock(nn.Module):
30
+
31
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
32
+ super().__init__()
33
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
34
+
35
+ self.norm = AdaLayerNormZeroSingle(dim)
36
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
37
+ self.act_mlp = nn.GELU(approximate="tanh")
38
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
39
+
40
+ if is_torch_npu_available():
41
+ processor = FluxAttnProcessor2_0_NPU()
42
+ else:
43
+ processor = FluxAttnProcessor2_0()
44
+ self.attn = Attention(
45
+ query_dim=dim,
46
+ cross_attention_dim=None,
47
+ dim_head=attention_head_dim,
48
+ heads=num_attention_heads,
49
+ out_dim=dim,
50
+ bias=True,
51
+ processor=processor,
52
+ qk_norm="rms_norm",
53
+ eps=1e-6,
54
+ pre_only=True,
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ hidden_states: torch.Tensor,
60
+ cond_hidden_states: torch.Tensor,
61
+ temb: torch.Tensor,
62
+ cond_temb: torch.Tensor,
63
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
64
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65
+ ) -> torch.Tensor:
66
+ use_cond = cond_hidden_states is not None
67
+
68
+ residual = hidden_states
69
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
70
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
71
+
72
+ if use_cond:
73
+ residual_cond = cond_hidden_states
74
+ norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
75
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
76
+
77
+ norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
78
+
79
+ joint_attention_kwargs = joint_attention_kwargs or {}
80
+ attn_output = self.attn(
81
+ hidden_states=norm_hidden_states_concat,
82
+ image_rotary_emb=image_rotary_emb,
83
+ use_cond=use_cond,
84
+ **joint_attention_kwargs,
85
+ )
86
+ if use_cond:
87
+ attn_output, cond_attn_output = attn_output
88
+
89
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
90
+ gate = gate.unsqueeze(1)
91
+ hidden_states = gate * self.proj_out(hidden_states)
92
+ hidden_states = residual + hidden_states
93
+
94
+ if use_cond:
95
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
96
+ cond_gate = cond_gate.unsqueeze(1)
97
+ condition_latents = cond_gate * self.proj_out(condition_latents)
98
+ condition_latents = residual_cond + condition_latents
99
+
100
+ if hidden_states.dtype == torch.float16:
101
+ hidden_states = hidden_states.clip(-65504, 65504)
102
+
103
+ return hidden_states, condition_latents if use_cond else None
104
+
105
+
106
+ @maybe_allow_in_graph
107
+ class FluxTransformerBlock(nn.Module):
108
+ def __init__(
109
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
110
+ ):
111
+ super().__init__()
112
+
113
+ self.norm1 = AdaLayerNormZero(dim)
114
+
115
+ self.norm1_context = AdaLayerNormZero(dim)
116
+
117
+ if hasattr(F, "scaled_dot_product_attention"):
118
+ processor = FluxAttnProcessor2_0()
119
+ else:
120
+ raise ValueError(
121
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
122
+ )
123
+ self.attn = Attention(
124
+ query_dim=dim,
125
+ cross_attention_dim=None,
126
+ added_kv_proj_dim=dim,
127
+ dim_head=attention_head_dim,
128
+ heads=num_attention_heads,
129
+ out_dim=dim,
130
+ context_pre_only=False,
131
+ bias=True,
132
+ processor=processor,
133
+ qk_norm=qk_norm,
134
+ eps=eps,
135
+ )
136
+
137
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139
+
140
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
141
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
142
+
143
+ # let chunk size default to None
144
+ self._chunk_size = None
145
+ self._chunk_dim = 0
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ cond_hidden_states: torch.Tensor,
151
+ encoder_hidden_states: torch.Tensor,
152
+ temb: torch.Tensor,
153
+ cond_temb: torch.Tensor,
154
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
156
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
157
+ use_cond = cond_hidden_states is not None
158
+
159
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
160
+ if use_cond:
161
+ (
162
+ norm_cond_hidden_states,
163
+ cond_gate_msa,
164
+ cond_shift_mlp,
165
+ cond_scale_mlp,
166
+ cond_gate_mlp,
167
+ ) = self.norm1(cond_hidden_states, emb=cond_temb)
168
+
169
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
170
+ encoder_hidden_states, emb=temb
171
+ )
172
+
173
+ norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
174
+
175
+ joint_attention_kwargs = joint_attention_kwargs or {}
176
+ # Attention.
177
+ attention_outputs = self.attn(
178
+ hidden_states=norm_hidden_states,
179
+ encoder_hidden_states=norm_encoder_hidden_states,
180
+ image_rotary_emb=image_rotary_emb,
181
+ use_cond=use_cond,
182
+ **joint_attention_kwargs,
183
+ )
184
+
185
+ attn_output, context_attn_output = attention_outputs[:2]
186
+ cond_attn_output = attention_outputs[2] if use_cond else None
187
+
188
+ # Process attention outputs for the `hidden_states`.
189
+ attn_output = gate_msa.unsqueeze(1) * attn_output
190
+ hidden_states = hidden_states + attn_output
191
+
192
+ if use_cond:
193
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
194
+ cond_hidden_states = cond_hidden_states + cond_attn_output
195
+
196
+ norm_hidden_states = self.norm2(hidden_states)
197
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
198
+
199
+ if use_cond:
200
+ norm_cond_hidden_states = self.norm2(cond_hidden_states)
201
+ norm_cond_hidden_states = (
202
+ norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
203
+ + cond_shift_mlp[:, None]
204
+ )
205
+
206
+ ff_output = self.ff(norm_hidden_states)
207
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
208
+ hidden_states = hidden_states + ff_output
209
+
210
+ if use_cond:
211
+ cond_ff_output = self.ff(norm_cond_hidden_states)
212
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
213
+ cond_hidden_states = cond_hidden_states + cond_ff_output
214
+
215
+ # Process attention outputs for the `encoder_hidden_states`.
216
+
217
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
218
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
219
+
220
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
221
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
222
+
223
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
224
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
225
+ if encoder_hidden_states.dtype == torch.float16:
226
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
227
+
228
+ return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
229
+
230
+
231
+ class FluxTransformer2DModel(
232
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
233
+ ):
234
+ _supports_gradient_checkpointing = True
235
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
236
+
237
+ @register_to_config
238
+ def __init__(
239
+ self,
240
+ patch_size: int = 1,
241
+ in_channels: int = 64,
242
+ out_channels: Optional[int] = None,
243
+ num_layers: int = 19,
244
+ num_single_layers: int = 38,
245
+ attention_head_dim: int = 128,
246
+ num_attention_heads: int = 24,
247
+ joint_attention_dim: int = 4096,
248
+ pooled_projection_dim: int = 768,
249
+ guidance_embeds: bool = False,
250
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
251
+ ):
252
+ super().__init__()
253
+ self.out_channels = out_channels or in_channels
254
+ self.inner_dim = num_attention_heads * attention_head_dim
255
+
256
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
257
+
258
+ text_time_guidance_cls = (
259
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
260
+ )
261
+ self.time_text_embed = text_time_guidance_cls(
262
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
263
+ )
264
+
265
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
266
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
267
+
268
+ self.transformer_blocks = nn.ModuleList(
269
+ [
270
+ FluxTransformerBlock(
271
+ dim=self.inner_dim,
272
+ num_attention_heads=num_attention_heads,
273
+ attention_head_dim=attention_head_dim,
274
+ )
275
+ for _ in range(num_layers)
276
+ ]
277
+ )
278
+
279
+ self.single_transformer_blocks = nn.ModuleList(
280
+ [
281
+ FluxSingleTransformerBlock(
282
+ dim=self.inner_dim,
283
+ num_attention_heads=num_attention_heads,
284
+ attention_head_dim=attention_head_dim,
285
+ )
286
+ for _ in range(num_single_layers)
287
+ ]
288
+ )
289
+
290
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
291
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
292
+
293
+ self.gradient_checkpointing = False
294
+
295
+ @property
296
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
297
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
298
+ r"""
299
+ Returns:
300
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
301
+ indexed by its weight name.
302
+ """
303
+ # set recursively
304
+ processors = {}
305
+
306
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
307
+ if hasattr(module, "get_processor"):
308
+ processors[f"{name}.processor"] = module.get_processor()
309
+
310
+ for sub_name, child in module.named_children():
311
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
312
+
313
+ return processors
314
+
315
+ for name, module in self.named_children():
316
+ fn_recursive_add_processors(name, module, processors)
317
+
318
+ return processors
319
+
320
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
321
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
322
+ r"""
323
+ Sets the attention processor to use to compute attention.
324
+
325
+ Parameters:
326
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
327
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
328
+ for **all** `Attention` layers.
329
+
330
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
331
+ processor. This is strongly recommended when setting trainable attention processors.
332
+
333
+ """
334
+ count = len(self.attn_processors.keys())
335
+
336
+ if isinstance(processor, dict) and len(processor) != count:
337
+ raise ValueError(
338
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
339
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
340
+ )
341
+
342
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
343
+ if hasattr(module, "set_processor"):
344
+ if not isinstance(processor, dict):
345
+ module.set_processor(processor)
346
+ else:
347
+ module.set_processor(processor.pop(f"{name}.processor"))
348
+
349
+ for sub_name, child in module.named_children():
350
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
351
+
352
+ for name, module in self.named_children():
353
+ fn_recursive_attn_processor(name, module, processor)
354
+
355
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
356
+ def fuse_qkv_projections(self):
357
+ """
358
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
359
+ are fused. For cross-attention modules, key and value projection matrices are fused.
360
+
361
+ <Tip warning={true}>
362
+
363
+ This API is 🧪 experimental.
364
+
365
+ </Tip>
366
+ """
367
+ self.original_attn_processors = None
368
+
369
+ for _, attn_processor in self.attn_processors.items():
370
+ if "Added" in str(attn_processor.__class__.__name__):
371
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
372
+
373
+ self.original_attn_processors = self.attn_processors
374
+
375
+ for module in self.modules():
376
+ if isinstance(module, Attention):
377
+ module.fuse_projections(fuse=True)
378
+
379
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
380
+
381
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
382
+ def unfuse_qkv_projections(self):
383
+ """Disables the fused QKV projection if enabled.
384
+
385
+ <Tip warning={true}>
386
+
387
+ This API is 🧪 experimental.
388
+
389
+ </Tip>
390
+
391
+ """
392
+ if self.original_attn_processors is not None:
393
+ self.set_attn_processor(self.original_attn_processors)
394
+
395
+ def _set_gradient_checkpointing(self, module=None, enable=False, gradient_checkpointing_func=None):
396
+ # Align with diffusers' enable_gradient_checkpointing API which may call
397
+ # without a `module` argument and pass only keyword args.
398
+ # Toggle on both the provided module (if any) and on self for safety.
399
+ if module is not None and hasattr(module, "gradient_checkpointing"):
400
+ module.gradient_checkpointing = enable
401
+ if hasattr(self, "gradient_checkpointing"):
402
+ self.gradient_checkpointing = enable
403
+ # Optionally store the provided function for future use.
404
+ if gradient_checkpointing_func is not None:
405
+ setattr(self, "_gradient_checkpointing_func", gradient_checkpointing_func)
406
+
407
+ def forward(
408
+ self,
409
+ hidden_states: torch.Tensor,
410
+ cond_hidden_states: torch.Tensor = None,
411
+ encoder_hidden_states: torch.Tensor = None,
412
+ pooled_projections: torch.Tensor = None,
413
+ timestep: torch.LongTensor = None,
414
+ img_ids: torch.Tensor = None,
415
+ txt_ids: torch.Tensor = None,
416
+ guidance: torch.Tensor = None,
417
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
418
+ controlnet_block_samples=None,
419
+ controlnet_single_block_samples=None,
420
+ return_dict: bool = True,
421
+ controlnet_blocks_repeat: bool = False,
422
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
423
+ if cond_hidden_states is not None:
424
+ use_condition = True
425
+ else:
426
+ use_condition = False
427
+
428
+ if joint_attention_kwargs is not None:
429
+ joint_attention_kwargs = joint_attention_kwargs.copy()
430
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
431
+ else:
432
+ lora_scale = 1.0
433
+
434
+ if USE_PEFT_BACKEND:
435
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
436
+ scale_lora_layers(self, lora_scale)
437
+ else:
438
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
439
+ logger.warning(
440
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
441
+ )
442
+
443
+ hidden_states = self.x_embedder(hidden_states)
444
+ if cond_hidden_states is not None:
445
+ if cond_hidden_states.shape[-1] == self.x_embedder.in_features:
446
+ cond_hidden_states = self.x_embedder(cond_hidden_states)
447
+ elif cond_hidden_states.shape[-1] == 64:
448
+ # Use only the first 64 input columns of the weight, plus the bias
449
+ weight = self.x_embedder.weight[:, :64] # [inner_dim, 64]
450
+ bias = self.x_embedder.bias
451
+ cond_hidden_states = torch.nn.functional.linear(cond_hidden_states, weight, bias)
452
+
453
+ timestep = timestep.to(hidden_states.dtype) * 1000
454
+ if guidance is not None:
455
+ guidance = guidance.to(hidden_states.dtype) * 1000
456
+ else:
457
+ guidance = None
458
+
459
+ temb = (
460
+ self.time_text_embed(timestep, pooled_projections)
461
+ if guidance is None
462
+ else self.time_text_embed(timestep, guidance, pooled_projections)
463
+ )
464
+
465
+ cond_temb = (
466
+ self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
467
+ if guidance is None
468
+ else self.time_text_embed(
469
+ torch.ones_like(timestep) * 0, guidance, pooled_projections
470
+ )
471
+ )
472
+
473
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
474
+
475
+ if txt_ids.ndim == 3:
476
+ logger.warning(
477
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
478
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
479
+ )
480
+ txt_ids = txt_ids[0]
481
+ if img_ids.ndim == 3:
482
+ logger.warning(
483
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
484
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
485
+ )
486
+ img_ids = img_ids[0]
487
+
488
+ ids = torch.cat((txt_ids, img_ids), dim=0)
489
+ image_rotary_emb = self.pos_embed(ids)
490
+
491
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
492
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
493
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
494
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
495
+
496
+ for index_block, block in enumerate(self.transformer_blocks):
497
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
498
+
499
+ def create_custom_forward(module, return_dict=None):
500
+ def custom_forward(*inputs):
501
+ if return_dict is not None:
502
+ return module(*inputs, return_dict=return_dict)
503
+ else:
504
+ return module(*inputs)
505
+
506
+ return custom_forward
507
+
508
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
509
+ if use_condition:
510
+ encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
511
+ create_custom_forward(block),
512
+ hidden_states,
513
+ cond_hidden_states,
514
+ encoder_hidden_states,
515
+ temb,
516
+ cond_temb,
517
+ image_rotary_emb,
518
+ joint_attention_kwargs,
519
+ **ckpt_kwargs,
520
+ )
521
+ else:
522
+ encoder_hidden_states, hidden_states, _ = torch.utils.checkpoint.checkpoint(
523
+ create_custom_forward(block),
524
+ hidden_states,
525
+ None,
526
+ encoder_hidden_states,
527
+ temb,
528
+ None,
529
+ image_rotary_emb,
530
+ joint_attention_kwargs,
531
+ **ckpt_kwargs,
532
+ )
533
+
534
+ else:
535
+ encoder_hidden_states, hidden_states, cond_hidden_states = block(
536
+ hidden_states=hidden_states,
537
+ encoder_hidden_states=encoder_hidden_states,
538
+ cond_hidden_states=cond_hidden_states if use_condition else None,
539
+ temb=temb,
540
+ cond_temb=cond_temb if use_condition else None,
541
+ image_rotary_emb=image_rotary_emb,
542
+ joint_attention_kwargs=joint_attention_kwargs,
543
+ )
544
+
545
+ # controlnet residual
546
+ if controlnet_block_samples is not None:
547
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
548
+ interval_control = int(np.ceil(interval_control))
549
+ # For Xlabs ControlNet.
550
+ if controlnet_blocks_repeat:
551
+ hidden_states = (
552
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
553
+ )
554
+ else:
555
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
556
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
557
+
558
+ for index_block, block in enumerate(self.single_transformer_blocks):
559
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
560
+
561
+ def create_custom_forward(module, return_dict=None):
562
+ def custom_forward(*inputs):
563
+ if return_dict is not None:
564
+ return module(*inputs, return_dict=return_dict)
565
+ else:
566
+ return module(*inputs)
567
+
568
+ return custom_forward
569
+
570
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
571
+ if use_condition:
572
+ hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
573
+ create_custom_forward(block),
574
+ hidden_states,
575
+ cond_hidden_states,
576
+ temb,
577
+ cond_temb,
578
+ image_rotary_emb,
579
+ joint_attention_kwargs,
580
+ **ckpt_kwargs,
581
+ )
582
+ else:
583
+ hidden_states, _ = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(block),
585
+ hidden_states,
586
+ None,
587
+ temb,
588
+ None,
589
+ image_rotary_emb,
590
+ joint_attention_kwargs,
591
+ **ckpt_kwargs,
592
+ )
593
+
594
+ else:
595
+ hidden_states, cond_hidden_states = block(
596
+ hidden_states=hidden_states,
597
+ cond_hidden_states=cond_hidden_states if use_condition else None,
598
+ temb=temb,
599
+ cond_temb=cond_temb if use_condition else None,
600
+ image_rotary_emb=image_rotary_emb,
601
+ joint_attention_kwargs=joint_attention_kwargs,
602
+ )
603
+
604
+ # controlnet residual
605
+ if controlnet_single_block_samples is not None:
606
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
607
+ interval_control = int(np.ceil(interval_control))
608
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
609
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
610
+ + controlnet_single_block_samples[index_block // interval_control]
611
+ )
612
+
613
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
614
+
615
+ hidden_states = self.norm_out(hidden_states, temb)
616
+ output = self.proj_out(hidden_states)
617
+
618
+ if USE_PEFT_BACKEND:
619
+ # remove `lora_scale` from each PEFT layer
620
+ unscale_lora_layers(self, lora_scale)
621
+
622
+ if not return_dict:
623
+ return (output,)
624
+
625
+ return Transformer2DModelOutput(sample=output)
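A noteworthy detail in `forward` above is the fallback projection for 64-channel condition latents, which reuses the first 64 input columns of `x_embedder` rather than adding a separate linear layer. The standalone sketch below reproduces just that slicing; the dimensions are illustrative only (with the stock config's `in_channels=64` the `elif` branch is never reached):

import torch
import torch.nn as nn
import torch.nn.functional as F

inner_dim, in_channels = 3072, 384                  # illustrative sizes, not the model's actual config
x_embedder = nn.Linear(in_channels, inner_dim)

cond_hidden_states = torch.randn(1, 1024, 64)       # packed condition latents with 64 channels
weight = x_embedder.weight[:, :64]                  # [inner_dim, 64] slice of the shared projection
cond_embedded = F.linear(cond_hidden_states, weight, x_embedder.bias)
print(cond_embedded.shape)                          # torch.Size([1, 1024, 3072])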
train/train_kontext_color.py ADDED
@@ -0,0 +1,858 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import shutil
7
+ from contextlib import nullcontext
8
+ from pathlib import Path
9
+ import re
10
+
11
+ from safetensors.torch import save_file
12
+ from PIL import Image
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ import diffusers
23
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
24
+ from diffusers.optimization import get_scheduler
25
+ from diffusers.training_utils import (
26
+ cast_training_params,
27
+ compute_density_for_timestep_sampling,
28
+ compute_loss_weighting_for_sd3,
29
+ )
30
+ from diffusers.utils.torch_utils import is_compiled_module
31
+ from diffusers.utils import (
32
+ check_min_version,
33
+ is_wandb_available,
34
+ )
35
+
36
+ from src.prompt_helper import *
37
+ from src.lora_helper import *
38
+ from src.jsonl_datasets_kontext_color import make_train_dataset_inpaint_mask, collate_fn
39
+ from src.pipeline_flux_kontext_control import (
40
+ FluxKontextControlPipeline,
41
+ resize_position_encoding,
42
+ prepare_latent_subject_ids,
43
+ PREFERRED_KONTEXT_RESOLUTIONS
44
+ )
45
+ from src.transformer_flux import FluxTransformer2DModel
46
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
47
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
48
+ from tqdm.auto import tqdm
49
+
50
+ if is_wandb_available():
51
+ import wandb
52
+
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.31.0.dev0")
56
+
57
+ logger = get_logger(__name__)
58
+
59
+
60
+ def log_validation(
61
+ pipeline,
62
+ args,
63
+ accelerator,
64
+ pipeline_args,
65
+ step,
66
+ torch_dtype,
67
+ is_final_validation=False,
68
+ ):
69
+ logger.info(
70
+ f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
71
+ )
72
+ pipeline = pipeline.to(accelerator.device)
73
+ pipeline.set_progress_bar_config(disable=True)
74
+
75
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
76
+ autocast_ctx = nullcontext()
77
+
78
+ # Build per-case evaluation: require equal lengths for image, spatial image, and prompt
79
+ if args.validation_images is None or args.validation_images == ['None']:
80
+ raise ValueError("validation_images must be provided and non-empty")
81
+ if args.validation_prompt is None:
82
+ raise ValueError("validation_prompt must be provided and non-empty")
83
+
84
+ control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
85
+ spatial_ls = control_dict_root.get("spatial_images", []) or []
86
+
87
+ val_imgs = args.validation_images
88
+ prompts = args.validation_prompt
89
+
90
+ if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
91
+ raise ValueError(
92
+ f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
93
+ )
94
+
95
+ results = []
96
+
97
+ def _resize_to_preferred(img: Image.Image) -> Image.Image:
98
+ w, h = img.size
99
+ aspect_ratio = w / h if h != 0 else 1.0
100
+ _, target_w, target_h = min(
101
+ (abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
102
+ for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
103
+ )
104
+ return img.resize((target_w, target_h), Image.BICUBIC)
105
+
106
+ # Distributed per-rank assignment: each process handles its own slice of cases
107
+ num_cases = len(prompts)
108
+ logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
109
+
110
+ # Indices assigned to this rank
111
+ rank = accelerator.process_index
112
+ world_size = accelerator.num_processes
113
+ local_indices = list(range(rank, num_cases, world_size))
114
+
115
+ local_images = []
116
+ with autocast_ctx:
117
+ for idx in local_indices:
118
+ try:
119
+ base_img = Image.open(val_imgs[idx]).convert("RGB")
120
+ resized_img = _resize_to_preferred(base_img)
121
+ except Exception as e:
122
+ raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
123
+
124
+ case_args = dict(pipeline_args) if pipeline_args is not None else {}
125
+ case_args.pop("height", None)
126
+ case_args.pop("width", None)
127
+ if resized_img is not None:
128
+ tw, th = resized_img.size
129
+ case_args["height"] = th
130
+ case_args["width"] = tw
131
+
132
+ case_control = dict(case_args.get("control_dict", {}))
133
+ spatial_case = spatial_ls[idx]
134
+
135
+ # Load spatial image if it's a path; else assume it's already an image
136
+ if isinstance(spatial_case, str):
137
+ spatial_img = Image.open(spatial_case).convert("RGB")
138
+ else:
139
+ spatial_img = spatial_case
140
+
141
+ case_control["spatial_images"] = [spatial_img]
142
+ case_control["subject_images"] = []
143
+ case_args["control_dict"] = case_control
144
+
145
+ case_args["prompt"] = prompts[idx]
146
+ img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
147
+ local_images.append(img)
148
+
149
+ # Gather all images per rank (pad to equal count) to main process
150
+ fixed_size = (1024, 1024)
151
+ max_local = int(math.ceil(num_cases / world_size)) if world_size > 0 else len(local_images)
152
+ # Build per-rank batch tensors
153
+ imgs_rank = []
154
+ idx_rank = []
155
+ has_rank = []
156
+ for j in range(max_local):
157
+ if j < len(local_images):
158
+ resized = local_images[j].resize(fixed_size, Image.BICUBIC)
159
+ img_np = np.asarray(resized).astype(np.uint8)
160
+ imgs_rank.append(torch.from_numpy(img_np))
161
+ idx_rank.append(local_indices[j])
162
+ has_rank.append(1)
163
+ else:
164
+ imgs_rank.append(torch.from_numpy(np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)))
165
+ idx_rank.append(-1)
166
+ has_rank.append(0)
167
+ imgs_rank_tensor = torch.stack([t.to(device=accelerator.device) for t in imgs_rank], dim=0) # [max_local, H, W, C]
168
+ idx_rank_tensor = torch.tensor(idx_rank, device=accelerator.device, dtype=torch.long) # [max_local]
169
+ has_rank_tensor = torch.tensor(has_rank, device=accelerator.device, dtype=torch.int) # [max_local]
170
+
171
+ gathered_has = accelerator.gather(has_rank_tensor) # [world * max_local]
172
+ gathered_idx = accelerator.gather(idx_rank_tensor) # [world * max_local]
173
+ gathered_imgs = accelerator.gather(imgs_rank_tensor) # [world * max_local, H, W, C]
174
+
175
+ if accelerator.is_main_process:
176
+ world = int(world_size)
177
+ slots = int(max_local)
178
+ try:
179
+ gathered_has = gathered_has.view(world, slots)
180
+ gathered_idx = gathered_idx.view(world, slots)
181
+ gathered_imgs = gathered_imgs.view(world, slots, fixed_size[1], fixed_size[0], 3)
182
+ except Exception:
183
+ # Fallback: treat as flat if reshape fails
184
+ gathered_has = gathered_has.view(-1, 1)
185
+ gathered_idx = gathered_idx.view(-1, 1)
186
+ gathered_imgs = gathered_imgs.view(-1, 1, fixed_size[1], fixed_size[0], 3)
187
+ world = int(gathered_has.shape[0])
188
+ slots = 1
189
+ for i in range(world):
190
+ for j in range(slots):
191
+ if int(gathered_has[i, j].item()) == 1:
192
+ idx = int(gathered_idx[i, j].item())
193
+ arr = gathered_imgs[i, j].cpu().numpy()
194
+ pil_img = Image.fromarray(arr.astype(np.uint8))
195
+ # Resize back to original validation image size
196
+ try:
197
+ orig = Image.open(val_imgs[idx]).convert("RGB")
198
+ pil_img = pil_img.resize(orig.size, Image.BICUBIC)
199
+ except Exception:
200
+ pass
201
+ results.append(pil_img)
202
+
203
+ # Log results (resize to 1024x1024 for saving or external trackers). Skip TensorBoard per request.
204
+ resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
205
+ for tracker in accelerator.trackers:
206
+ phase_name = "test" if is_final_validation else "validation"
207
+ if tracker.name == "tensorboard":
208
+ continue
209
+ if tracker.name == "wandb":
210
+ tracker.log({
211
+ phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
212
+ })
213
+
214
+ del pipeline
215
+ if torch.cuda.is_available():
216
+ torch.cuda.empty_cache()
217
+
218
+ return results
219
+
220
+
221
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
222
+ text_encoder_config = transformers.PretrainedConfig.from_pretrained(
223
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
224
+ )
225
+ model_class = text_encoder_config.architectures[0]
226
+ if model_class == "CLIPTextModel":
227
+ from transformers import CLIPTextModel
228
+
229
+ return CLIPTextModel
230
+ elif model_class == "T5EncoderModel":
231
+ from transformers import T5EncoderModel
232
+
233
+ return T5EncoderModel
234
+ else:
235
+ raise ValueError(f"{model_class} is not supported.")
236
+
237
+
238
+ def parse_args(input_args=None):
239
+ parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
240
+ parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
241
+ parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
242
+ parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
243
+
244
+ parser.add_argument("--train_data_dir", type=str, default="", help="Path to JSONL dataset.")
245
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
246
+ parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
247
+ parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
248
+ parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
249
+
250
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
251
+ parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
252
+ parser.add_argument("--kontext", type=str, default="disable")
253
+ parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
254
+ parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
255
+ parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
256
+ parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
257
+ parser.add_argument("--num_validation_images", type=int, default=4)
258
+ parser.add_argument("--validation_steps", type=int, default=20)
259
+
260
+ parser.add_argument("--ranks", type=int, nargs="+", default=[128], help="LoRA ranks")
261
+ parser.add_argument("--network_alphas", type=int, nargs="+", default=[128], help="LoRA network alphas")
262
+ parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
263
+ parser.add_argument("--seed", type=int, default=None)
264
+ parser.add_argument("--train_batch_size", type=int, default=1)
265
+ parser.add_argument("--num_train_epochs", type=int, default=50)
266
+ parser.add_argument("--max_train_steps", type=int, default=None)
267
+ parser.add_argument("--checkpointing_steps", type=int, default=1000)
268
+ parser.add_argument("--checkpoints_total_limit", type=int, default=None)
269
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
270
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
271
+ parser.add_argument("--gradient_checkpointing", action="store_true")
272
+ parser.add_argument("--learning_rate", type=float, default=1e-4)
273
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
274
+ parser.add_argument("--scale_lr", action="store_true", default=False)
275
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
276
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
277
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
278
+ parser.add_argument("--lr_power", type=float, default=1.0)
279
+ parser.add_argument("--dataloader_num_workers", type=int, default=1)
280
+ parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
281
+ parser.add_argument("--logit_mean", type=float, default=0.0)
282
+ parser.add_argument("--logit_std", type=float, default=1.0)
283
+ parser.add_argument("--mode_scale", type=float, default=1.29)
284
+ parser.add_argument("--optimizer", type=str, default="AdamW")
285
+ parser.add_argument("--use_8bit_adam", action="store_true")
286
+ parser.add_argument("--adam_beta1", type=float, default=0.9)
287
+ parser.add_argument("--adam_beta2", type=float, default=0.999)
288
+ parser.add_argument("--prodigy_beta3", type=float, default=None)
289
+ parser.add_argument("--prodigy_decouple", type=bool, default=True)
290
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
291
+ parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
292
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08)
293
+ parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
294
+ parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
295
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
296
+ parser.add_argument("--logging_dir", type=str, default="logs")
297
+ parser.add_argument("--cache_latents", action="store_true", default=False)
298
+ parser.add_argument("--report_to", type=str, default="tensorboard")
299
+ parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
300
+ parser.add_argument("--upcast_before_saving", action="store_true", default=False)
301
+
302
+ if input_args is not None:
303
+ args = parser.parse_args(input_args)
304
+ else:
305
+ args = parser.parse_args()
306
+ return args
307
+
308
+
309
+ def main(args):
310
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
311
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
312
+
313
+ if args.output_dir is not None:
314
+ os.makedirs(args.output_dir, exist_ok=True)
315
+ os.makedirs(args.logging_dir, exist_ok=True)
316
+ logging_dir = Path(args.output_dir, args.logging_dir)
317
+
318
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
319
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
320
+ accelerator = Accelerator(
321
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
322
+ mixed_precision=args.mixed_precision,
323
+ log_with=args.report_to,
324
+ project_config=accelerator_project_config,
325
+ kwargs_handlers=[kwargs],
326
+ )
327
+
328
+ if torch.backends.mps.is_available():
329
+ accelerator.native_amp = False
330
+
331
+ if args.report_to == "wandb":
332
+ if not is_wandb_available():
333
+ raise ImportError("Install wandb for logging during training.")
334
+
335
+ logging.basicConfig(
336
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
337
+ datefmt="%m/%d/%Y %H:%M:%S",
338
+ level=logging.INFO,
339
+ )
340
+ logger.info(accelerator.state, main_process_only=False)
341
+ if accelerator.is_local_main_process:
342
+ transformers.utils.logging.set_verbosity_warning()
343
+ diffusers.utils.logging.set_verbosity_info()
344
+ else:
345
+ transformers.utils.logging.set_verbosity_error()
346
+ diffusers.utils.logging.set_verbosity_error()
347
+
348
+ if args.seed is not None:
349
+ set_seed(args.seed)
350
+
351
+ if accelerator.is_main_process and args.output_dir is not None:
352
+ os.makedirs(args.output_dir, exist_ok=True)
353
+
354
+ # Tokenizers
355
+ tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
356
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
357
+ )
358
+ tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
359
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
360
+ )
361
+
362
+ # Text encoders
363
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
364
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
365
+
366
+ # Scheduler and models
367
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
368
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
369
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
370
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
371
+ transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
372
+
373
+ # Train only LoRA adapters
374
+ transformer.requires_grad_(True)
375
+ vae.requires_grad_(False)
376
+ text_encoder_one.requires_grad_(False)
377
+ text_encoder_two.requires_grad_(False)
378
+
379
+ weight_dtype = torch.float32
380
+ if accelerator.mixed_precision == "fp16":
381
+ weight_dtype = torch.float16
382
+ elif accelerator.mixed_precision == "bf16":
383
+ weight_dtype = torch.bfloat16
384
+
385
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
386
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
387
+
388
+ vae.to(accelerator.device, dtype=weight_dtype)
389
+ transformer.to(accelerator.device, dtype=weight_dtype)
390
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
391
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
392
+
393
+ if args.gradient_checkpointing:
394
+ transformer.enable_gradient_checkpointing()
395
+
396
+ # Setup LoRA attention processors
397
+ if args.pretrained_lora_path is not None:
398
+ lora_path = args.pretrained_lora_path
399
+ checkpoint = load_checkpoint(lora_path)
400
+ lora_attn_procs = {}
401
+ double_blocks_idx = list(range(19))
402
+ single_blocks_idx = list(range(38))
403
+ number = 1
404
+ for name, attn_processor in transformer.attn_processors.items():
405
+ match = re.search(r'\.(\d+)\.', name)
406
+ if match:
407
+ layer_index = int(match.group(1))
408
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
409
+ lora_state_dicts = {}
410
+ for key, value in checkpoint.items():
411
+ if re.search(r'\.(\d+)\.', key):
412
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
413
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
414
+ lora_state_dicts[key] = value
415
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
416
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
417
+ )
418
+ for n in range(number):
419
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
420
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
421
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
422
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
423
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
424
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
425
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
426
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
427
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
428
+ lora_state_dicts = {}
429
+ for key, value in checkpoint.items():
430
+ if re.search(r'\.(\d+)\.', key):
431
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
432
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
433
+ lora_state_dicts[key] = value
434
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
435
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
436
+ )
437
+ for n in range(number):
438
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
439
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
440
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
441
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
442
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
443
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
444
+ else:
445
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
446
+ else:
447
+ lora_attn_procs = {}
448
+ double_blocks_idx = list(range(19))
449
+ single_blocks_idx = list(range(38))
450
+ for name, attn_processor in transformer.attn_processors.items():
451
+ match = re.search(r'\.(\d+)\.', name)
452
+ if match:
453
+ layer_index = int(match.group(1))
454
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
455
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
456
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
457
+ )
458
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
459
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
460
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
461
+ )
462
+ else:
463
+ lora_attn_procs[name] = attn_processor
464
+
465
+ transformer.set_attn_processor(lora_attn_procs)
466
+ transformer.train()
467
+ for n, param in transformer.named_parameters():
468
+ if '_lora' not in n:
469
+ param.requires_grad = False
470
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
471
+
472
+ def unwrap_model(model):
473
+ model = accelerator.unwrap_model(model)
474
+ model = model._orig_mod if is_compiled_module(model) else model
475
+ return model
476
+
477
+ if args.resume_from_checkpoint:
478
+ path = args.resume_from_checkpoint
479
+ global_step = int(path.split("-")[-1])
480
+ initial_global_step = global_step
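+ # Note: only the step counter is recovered from the checkpoint folder name here; optimizer/scheduler state is not reloaded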
481
+ else:
482
+ initial_global_step = 0
483
+ global_step = 0
484
+ first_epoch = 0
485
+
486
+ if args.scale_lr:
487
+ args.learning_rate = (
488
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
489
+ )
490
+
491
+ if args.mixed_precision == "fp16":
492
+ models = [transformer]
493
+ cast_training_params(models, dtype=torch.float32)
494
+
495
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
496
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
497
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
498
+
499
+ optimizer_class = torch.optim.AdamW
500
+ optimizer = optimizer_class(
501
+ [transformer_parameters_with_lr],
502
+ betas=(args.adam_beta1, args.adam_beta2),
503
+ weight_decay=args.adam_weight_decay,
504
+ eps=args.adam_epsilon,
505
+ )
506
+
507
+ tokenizers = [tokenizer_one, tokenizer_two]
508
+ text_encoders = [text_encoder_one, text_encoder_two]
509
+
510
+ train_dataset = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
511
+ train_dataloader = torch.utils.data.DataLoader(
512
+ train_dataset,
513
+ batch_size=args.train_batch_size,
514
+ shuffle=True,
515
+ collate_fn=collate_fn,
516
+ num_workers=args.dataloader_num_workers,
517
+ )
518
+
519
+ vae_config_shift_factor = vae.config.shift_factor
520
+ vae_config_scaling_factor = vae.config.scaling_factor
521
+
522
+ overrode_max_train_steps = False
523
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
524
+ if args.resume_from_checkpoint:
525
+ first_epoch = global_step // num_update_steps_per_epoch
526
+ if args.max_train_steps is None:
527
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
528
+ overrode_max_train_steps = True
529
+
530
+ lr_scheduler = get_scheduler(
531
+ args.lr_scheduler,
532
+ optimizer=optimizer,
533
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
534
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
535
+ num_cycles=args.lr_num_cycles,
536
+ power=args.lr_power,
537
+ )
538
+
539
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
540
+ transformer, optimizer, train_dataloader, lr_scheduler
541
+ )
542
+
543
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
544
+ if overrode_max_train_steps:
545
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
546
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
547
+
548
+ # Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
549
+ def _sanitize_hparams(config_dict):
550
+ sanitized = {}
551
+ for key, value in dict(config_dict).items():
552
+ try:
553
+ if value is None:
554
+ continue
555
+ # numpy scalar types
556
+ if isinstance(value, (np.integer,)):
557
+ sanitized[key] = int(value)
558
+ elif isinstance(value, (np.floating,)):
559
+ sanitized[key] = float(value)
560
+ elif isinstance(value, (int, float, bool, str)):
561
+ sanitized[key] = value
562
+ elif isinstance(value, Path):
563
+ sanitized[key] = str(value)
564
+ elif isinstance(value, (list, tuple)):
565
+ # stringify simple sequences; skip if fails
566
+ sanitized[key] = str(value)
567
+ else:
568
+ # best-effort stringify
569
+ sanitized[key] = str(value)
570
+ except Exception:
571
+ # skip unconvertible entries
572
+ continue
573
+ return sanitized
574
+
575
+ if accelerator.is_main_process:
576
+ tracker_name = "Easy_Control_Kontext"
577
+ accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
578
+
579
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
580
+ logger.info("***** Running training *****")
581
+ logger.info(f" Num examples = {len(train_dataset)}")
582
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
583
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
584
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
585
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
586
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
587
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
588
+
589
+ progress_bar = tqdm(
590
+ range(0, args.max_train_steps),
591
+ initial=initial_global_step,
592
+ desc="Steps",
593
+ disable=not accelerator.is_local_main_process,
594
+ )
595
+
596
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
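+ # Maps each sampled timestep to its sigma from the copied scheduler and broadcasts it to the latent tensor's rank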
597
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
598
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
599
+ timesteps = timesteps.to(accelerator.device)
600
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
601
+ sigma = sigmas[step_indices].flatten()
602
+ while len(sigma.shape) < n_dim:
603
+ sigma = sigma.unsqueeze(-1)
604
+ return sigma
605
+
606
+ # Kontext specifics
607
+ vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
608
+ # Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
609
+ height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
610
+ width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
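+ # e.g., cond_size=512 with vae_scale_factor=8 gives height_cond = width_cond = 2 * (512 // 16) = 64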
611
+ offset = 64
612
+
613
+ for epoch in range(first_epoch, args.num_train_epochs):
614
+ transformer.train()
615
+ for step, batch in enumerate(train_dataloader):
616
+ models_to_accumulate = [transformer]
617
+ with accelerator.accumulate(models_to_accumulate):
618
+ tokens = [batch["text_ids_1"], batch["text_ids_2"]]
619
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
620
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
621
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
622
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
623
+
624
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
625
+ height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
626
+ width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
627
+
628
+ model_input = vae.encode(pixel_values).latent_dist.sample()
629
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
630
+ model_input = model_input.to(dtype=weight_dtype)
631
+
632
+ latent_image_ids, cond_latent_image_ids = resize_position_encoding(
633
+ model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
634
+ )
635
+
636
+ noise = torch.randn_like(model_input)
637
+ bsz = model_input.shape[0]
638
+
639
+ u = compute_density_for_timestep_sampling(
640
+ weighting_scheme=args.weighting_scheme,
641
+ batch_size=bsz,
642
+ logit_mean=args.logit_mean,
643
+ logit_std=args.logit_std,
644
+ mode_scale=args.mode_scale,
645
+ )
646
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
647
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
648
+
649
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
650
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
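+ # Rectified-flow interpolation: sigma=1 yields pure noise, sigma=0 the clean latents; the loss target below is noise - model_input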
651
+
652
+ packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
653
+ noisy_model_input,
654
+ batch_size=model_input.shape[0],
655
+ num_channels_latents=model_input.shape[1],
656
+ height=model_input.shape[2],
657
+ width=model_input.shape[3],
658
+ )
659
+
660
+ latent_image_ids_to_concat = [latent_image_ids]
661
+ packed_cond_model_input_to_concat = []
662
+
663
+ if args.kontext == "enable":
664
+ source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
665
+ source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
666
+ source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
667
+ image_latent_h, image_latent_w = source_image_latents.shape[2:]
668
+ packed_image_latents = FluxKontextControlPipeline._pack_latents(
669
+ source_image_latents,
670
+ batch_size=source_image_latents.shape[0],
671
+ num_channels_latents=source_image_latents.shape[1],
672
+ height=image_latent_h,
673
+ width=image_latent_w,
674
+ )
675
+ source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
676
+ batch_size=source_image_latents.shape[0],
677
+ height=image_latent_h // 2,
678
+ width=image_latent_w // 2,
679
+ device=accelerator.device,
680
+ dtype=weight_dtype,
681
+ )
682
+ source_image_ids[..., 0] = 1 # Mark as condition
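+ # The leading channel of the position ids distinguishes streams: the Kontext source is tagged 1 here, while spatial/subject conditions are tagged 2 below and target latents keep their default id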
683
+ latent_image_ids_to_concat.append(source_image_ids)
684
+
685
+
686
+ subject_pixel_values = batch.get("subject_pixel_values")
687
+ if subject_pixel_values is not None:
688
+ subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
689
+ subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
690
+ subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
691
+ subject_input = subject_input.to(dtype=weight_dtype)
692
+ sub_number = subject_pixel_values.shape[-2] // args.cond_size
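+ # Multiple subject conditions appear stacked along the height axis, so their count is inferred from height // cond_size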
693
+ latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
694
+ latent_subject_ids[..., 0] = 2
695
+ latent_subject_ids[:, 1] += offset
696
+ sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
697
+ latent_image_ids_to_concat.append(sub_latent_image_ids)
698
+
699
+ packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
700
+ subject_input,
701
+ batch_size=subject_input.shape[0],
702
+ num_channels_latents=subject_input.shape[1],
703
+ height=subject_input.shape[2],
704
+ width=subject_input.shape[3],
705
+ )
706
+ packed_cond_model_input_to_concat.append(packed_subject_model_input)
707
+
708
+ cond_pixel_values = batch.get("cond_pixel_values")
709
+ if cond_pixel_values is not None:
710
+ cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
711
+ cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
712
+ cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
713
+ cond_input = cond_input.to(dtype=weight_dtype)
714
+ cond_number = cond_pixel_values.shape[-2] // args.cond_size
715
+ cond_latent_image_ids[..., 0] = 2
716
+ cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
717
+ latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
718
+
719
+ packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
720
+ cond_input,
721
+ batch_size=cond_input.shape[0],
722
+ num_channels_latents=cond_input.shape[1],
723
+ height=cond_input.shape[2],
724
+ width=cond_input.shape[3],
725
+ )
726
+ packed_cond_model_input_to_concat.append(packed_cond_model_input)
727
+
728
+ latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
729
+ cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
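+ # Condition latents travel through the separate cond_hidden_states stream, while the Kontext source latents are appended to hidden_states below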
730
+
731
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
732
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
733
+ guidance = guidance.expand(model_input.shape[0])
734
+ else:
735
+ guidance = None
736
+
737
+ latent_model_input = packed_noisy_model_input
738
+ if args.kontext == "enable":
739
+ latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
740
+ model_pred = transformer(
741
+ hidden_states=latent_model_input,
742
+ cond_hidden_states=cond_packed_noisy_model_input,
743
+ timestep=timesteps / 1000,
744
+ guidance=guidance,
745
+ pooled_projections=pooled_prompt_embeds,
746
+ encoder_hidden_states=prompt_embeds,
747
+ txt_ids=text_ids,
748
+ img_ids=latent_image_ids,
749
+ return_dict=False,
750
+ )[0]
751
+
752
+ model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
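+ # Keep only the prediction over the target latent tokens; tokens for the appended Kontext image are discarded before unpacking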
753
+
754
+ model_pred = FluxKontextControlPipeline._unpack_latents(
755
+ model_pred,
756
+ height=int(pixel_values.shape[-2]),
757
+ width=int(pixel_values.shape[-1]),
758
+ vae_scale_factor=vae_scale_factor,
759
+ )
760
+
761
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
762
+ target = noise - model_input
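+ # Flow-matching velocity target: the model is trained to predict noise minus the clean latents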
763
+
764
+ loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
765
+ loss = loss.mean()
766
+ accelerator.backward(loss)
767
+ if accelerator.sync_gradients:
768
+ params_to_clip = (transformer.parameters())
769
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
770
+
771
+ optimizer.step()
772
+ lr_scheduler.step()
773
+ optimizer.zero_grad()
774
+
775
+ if accelerator.sync_gradients:
776
+ progress_bar.update(1)
777
+ global_step += 1
778
+
779
+ if accelerator.is_main_process:
780
+ if global_step % args.checkpointing_steps == 0:
781
+ if args.checkpoints_total_limit is not None:
782
+ checkpoints = os.listdir(args.output_dir)
783
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
784
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
785
+ if len(checkpoints) >= args.checkpoints_total_limit:
786
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
787
+ removing_checkpoints = checkpoints[0:num_to_remove]
788
+ logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
789
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
790
+ for removing_checkpoint in removing_checkpoints:
791
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
792
+ shutil.rmtree(removing_checkpoint)
793
+
794
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
795
+ os.makedirs(save_path, exist_ok=True)
796
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
797
+ lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
798
+ save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
799
+ logger.info(f"Saved state to {save_path}")
800
+
801
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
802
+ progress_bar.set_postfix(**logs)
803
+ accelerator.log(logs, step=global_step)
804
+
805
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
806
+ # Create pipeline on every rank to run validation in parallel
807
+ pipeline = FluxKontextControlPipeline.from_pretrained(
808
+ args.pretrained_model_name_or_path,
809
+ vae=vae,
810
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
811
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
812
+ transformer=accelerator.unwrap_model(transformer),
813
+ revision=args.revision,
814
+ variant=args.variant,
815
+ torch_dtype=weight_dtype,
816
+ )
817
+
818
+ if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
819
+ spatial_paths = args.spatial_test_images
820
+ else:
821
+ spatial_paths = []
822
+
823
+ pipeline_args = {
824
+ "prompt": args.validation_prompt,
825
+ "cond_size": args.cond_size,
826
+ "guidance_scale": 3.5,
827
+ "num_inference_steps": 20,
828
+ "max_sequence_length": 128,
829
+ "control_dict": {"spatial_images": spatial_paths},
830
+ }
831
+
832
+ images = log_validation(
833
+ pipeline=pipeline,
834
+ args=args,
835
+ accelerator=accelerator,
836
+ pipeline_args=pipeline_args,
837
+ step=global_step,
838
+ torch_dtype=weight_dtype,
839
+ )
840
+
841
+ # Only main process saves/logs
842
+ if accelerator.is_main_process:
843
+ save_path = os.path.join(args.output_dir, "validation")
844
+ os.makedirs(save_path, exist_ok=True)
845
+ save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
846
+ os.makedirs(save_folder, exist_ok=True)
847
+ for idx, img in enumerate(images):
848
+ img.save(os.path.join(save_folder, f"{idx}.jpg"))
849
+ del pipeline
850
+
851
+ accelerator.wait_for_everyone()
852
+ accelerator.end_training()
853
+
854
+
855
+ if __name__ == "__main__":
856
+ args = parse_args()
857
+ main(args)
858
+
train/train_kontext_color.sh ADDED
@@ -0,0 +1,25 @@
1
+ export MODEL_DIR="" # your flux path
2
+ export OUTPUT_DIR="" # your save path
3
+ export CONFIG="./default_config.yaml"
4
+ export TRAIN_DATA="" # your data jsonl file
5
+ export LOG_PATH="$OUTPUT_DIR/log"
6
+
7
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_color.py \
8
+ --pretrained_model_name_or_path $MODEL_DIR \
9
+ --lora_num=1 \
10
+ --cond_size=512 \
11
+ --ranks 128 \
12
+ --network_alphas 128 \
13
+ --output_dir=$OUTPUT_DIR \
14
+ --logging_dir=$LOG_PATH \
15
+ --mixed_precision="bf16" \
16
+ --train_data_dir=$TRAIN_DATA \
17
+ --learning_rate=1e-4 \
18
+ --train_batch_size=1 \
19
+ --num_train_epochs=1 \
20
+ --validation_steps=100 \
21
+ --checkpointing_steps=1000 \
22
+ --validation_images "./kontext_color_test/img_1.png" \
23
+ --spatial_test_images "./kontext_color_test/color_1.png" \
24
+ --validation_prompt "Let this woman have red purple and blue hair" \
25
+ --num_validation_images=1
train/train_kontext_complete_lora.sh ADDED
@@ -0,0 +1,20 @@
1
+ export MODEL_DIR="" # your flux path
2
+ export OUTPUT_DIR="" # your save path
3
+ export CONFIG="./default_config.yaml"
4
+ export LOG_PATH="$OUTPUT_DIR/log"
5
+
6
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_lora.py \
7
+ --train_data_jsonl "" \
8
+ --pretrained_model_name_or_path $MODEL_DIR \
9
+ --output_dir=$OUTPUT_DIR \
10
+ --logging_dir=$LOG_PATH \
11
+ --mixed_precision="bf16" \
12
+ --learning_rate=1e-4 \
13
+ --train_batch_size=1 \
14
+ --num_train_epochs=5 \
15
+ --validation_steps=100 \
16
+ --checkpointing_steps=500 \
17
+ --validation_images "./kontext_complete_test/img_1.png" \
18
+ --validation_prompt "" \
19
+ --gradient_checkpointing \
20
+ --num_validation_images=1
train/train_kontext_edge.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import shutil
7
+ from contextlib import nullcontext
8
+ from pathlib import Path
9
+ import re
10
+
11
+ from safetensors.torch import save_file
12
+ from PIL import Image
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ import diffusers
23
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
24
+ from diffusers.optimization import get_scheduler
25
+ from diffusers.training_utils import (
26
+ cast_training_params,
27
+ compute_density_for_timestep_sampling,
28
+ compute_loss_weighting_for_sd3,
29
+ )
30
+ from diffusers.utils.torch_utils import is_compiled_module
31
+ from diffusers.utils import (
32
+ check_min_version,
33
+ is_wandb_available,
34
+ )
35
+
36
+ from src.prompt_helper import *
37
+ from src.lora_helper import *
38
+ from src.jsonl_datasets_kontext_edge import make_train_dataset_inpaint_mask, collate_fn
39
+ from src.pipeline_flux_kontext_control import (
40
+ FluxKontextControlPipeline,
41
+ resize_position_encoding,
42
+ prepare_latent_subject_ids,
43
+ PREFERRED_KONTEXT_RESOLUTIONS
44
+ )
45
+ from src.transformer_flux import FluxTransformer2DModel
46
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
47
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
48
+ from tqdm.auto import tqdm
49
+
50
+ if is_wandb_available():
51
+ import wandb
52
+
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.31.0.dev0")
56
+
57
+ logger = get_logger(__name__)
58
+
59
+
60
+ def log_validation(
61
+ pipeline,
62
+ args,
63
+ accelerator,
64
+ pipeline_args,
65
+ step,
66
+ torch_dtype,
67
+ is_final_validation=False,
68
+ ):
69
+ logger.info(
70
+ f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
71
+ )
72
+ pipeline = pipeline.to(accelerator.device)
73
+ pipeline.set_progress_bar_config(disable=True)
74
+
75
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
76
+ autocast_ctx = nullcontext()
77
+
78
+ # Build per-case evaluation: require equal lengths for image, spatial image, and prompt
79
+ if args.validation_images is None or args.validation_images == ['None']:
80
+ raise ValueError("validation_images must be provided and non-empty")
81
+ if args.validation_prompt is None:
82
+ raise ValueError("validation_prompt must be provided and non-empty")
83
+
84
+ control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
85
+ spatial_ls = control_dict_root.get("spatial_images", []) or []
86
+
87
+ val_imgs = args.validation_images
88
+ prompts = args.validation_prompt
89
+
90
+ if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
91
+ raise ValueError(
92
+ f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
93
+ )
94
+
95
+ results = []
96
+
97
+ def _resize_to_preferred(img: Image.Image) -> Image.Image:
98
+ w, h = img.size
99
+ aspect_ratio = w / h if h != 0 else 1.0
100
+ _, target_w, target_h = min(
101
+ (abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
102
+ for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
103
+ )
104
+ return img.resize((target_w, target_h), Image.BICUBIC)
105
+
106
+ # Strict per-case loop
107
+ num_cases = len(prompts)
108
+ logger.info(f"Paired validation: {num_cases} (image, spatial, prompt) cases")
109
+ with autocast_ctx:
110
+ for idx in range(num_cases):
111
+ resized_img = None
112
+ # If validation image path is a non-empty string, load and resize; otherwise, skip passing image
113
+ if isinstance(val_imgs[idx], str) and val_imgs[idx] != "":
114
+ try:
115
+ base_img = Image.open(val_imgs[idx]).convert("RGB")
116
+ resized_img = _resize_to_preferred(base_img)
117
+ except Exception as e:
118
+ raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
119
+
120
+ case_args = dict(pipeline_args) if pipeline_args is not None else {}
121
+ case_args.pop("height", None)
122
+ case_args.pop("width", None)
123
+ if resized_img is not None:
124
+ tw, th = resized_img.size
125
+ case_args["height"] = th
126
+ case_args["width"] = tw
127
+ else:
128
+ # When no image is provided, default to 1024x1024
129
+ case_args["height"] = 1024
130
+ case_args["width"] = 1024
131
+
132
+ # Bind single spatial control image per case; pass it directly (no masking)
133
+ case_control = dict(case_args.get("control_dict", {}))
134
+ spatial_case = spatial_ls[idx]
135
+
136
+ # Load spatial image if it's a path; else assume it's already an image
137
+ try:
138
+ spatial_img = Image.open(spatial_case).convert("RGB") if isinstance(spatial_case, str) else spatial_case
139
+ except Exception:
140
+ spatial_img = spatial_case
141
+
142
+ case_control["spatial_images"] = [spatial_img]
143
+ case_control["subject_images"] = []
144
+ case_args["control_dict"] = case_control
145
+
146
+ # Override prompt per case
147
+ case_args["prompt"] = prompts[idx]
148
+
149
+ if resized_img is not None:
150
+ img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
151
+ else:
152
+ img = pipeline(**case_args, generator=generator).images[0]
153
+ results.append(img)
154
+
155
+ # Log results (resize to 1024x1024 for logging only)
156
+ resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
157
+ for tracker in accelerator.trackers:
158
+ phase_name = "test" if is_final_validation else "validation"
159
+ if tracker.name == "tensorboard":
160
+ np_images = np.stack([np.asarray(img) for img in resized_for_log])
161
+ tracker.writer.add_images(phase_name, np_images, step, dataformats="NHWC")
162
+ if tracker.name == "wandb":
163
+ tracker.log({
164
+ phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
165
+ })
166
+
167
+ del pipeline
168
+ if torch.cuda.is_available():
169
+ torch.cuda.empty_cache()
170
+
171
+ return results
172
+
173
+
174
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
175
+ text_encoder_config = transformers.PretrainedConfig.from_pretrained(
176
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
177
+ )
178
+ model_class = text_encoder_config.architectures[0]
179
+ if model_class == "CLIPTextModel":
180
+ from transformers import CLIPTextModel
181
+
182
+ return CLIPTextModel
183
+ elif model_class == "T5EncoderModel":
184
+ from transformers import T5EncoderModel
185
+
186
+ return T5EncoderModel
187
+ else:
188
+ raise ValueError(f"{model_class} is not supported.")
189
+
190
+
191
+ def parse_args(input_args=None):
192
+ parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
193
+ parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
194
+ parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
195
+ parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
196
+
197
+ parser.add_argument("--train_data_dir", type=str, default="", help="Path to JSONL dataset.")
198
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
199
+ parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
200
+ parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
201
+ parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
202
+
203
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
204
+ parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
205
+ parser.add_argument("--kontext", type=str, default="disable")
206
+ parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
207
+ parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
208
+ parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
209
+ parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
210
+ parser.add_argument("--num_validation_images", type=int, default=4)
211
+ parser.add_argument("--validation_steps", type=int, default=20)
212
+
213
+ parser.add_argument("--ranks", type=int, nargs="+", default=[128], help="LoRA ranks")
214
+ parser.add_argument("--network_alphas", type=int, nargs="+", default=[128], help="LoRA network alphas")
215
+ parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
216
+ parser.add_argument("--seed", type=int, default=None)
217
+ parser.add_argument("--train_batch_size", type=int, default=1)
218
+ parser.add_argument("--num_train_epochs", type=int, default=50)
219
+ parser.add_argument("--max_train_steps", type=int, default=None)
220
+ parser.add_argument("--checkpointing_steps", type=int, default=1000)
221
+ parser.add_argument("--checkpoints_total_limit", type=int, default=None)
222
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
223
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
224
+ parser.add_argument("--gradient_checkpointing", action="store_true")
225
+ parser.add_argument("--learning_rate", type=float, default=1e-4)
226
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
227
+ parser.add_argument("--scale_lr", action="store_true", default=False)
228
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
229
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
230
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
231
+ parser.add_argument("--lr_power", type=float, default=1.0)
232
+ parser.add_argument("--dataloader_num_workers", type=int, default=1)
233
+ parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
234
+ parser.add_argument("--logit_mean", type=float, default=0.0)
235
+ parser.add_argument("--logit_std", type=float, default=1.0)
236
+ parser.add_argument("--mode_scale", type=float, default=1.29)
237
+ parser.add_argument("--optimizer", type=str, default="AdamW")
238
+ parser.add_argument("--use_8bit_adam", action="store_true")
239
+ parser.add_argument("--adam_beta1", type=float, default=0.9)
240
+ parser.add_argument("--adam_beta2", type=float, default=0.999)
241
+ parser.add_argument("--prodigy_beta3", type=float, default=None)
242
+ parser.add_argument("--prodigy_decouple", type=bool, default=True)
243
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
244
+ parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
245
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08)
246
+ parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
247
+ parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
248
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
249
+ parser.add_argument("--logging_dir", type=str, default="logs")
250
+ parser.add_argument("--cache_latents", action="store_true", default=False)
251
+ parser.add_argument("--report_to", type=str, default="tensorboard")
252
+ parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
253
+ parser.add_argument("--upcast_before_saving", action="store_true", default=False)
254
+
255
+ if input_args is not None:
256
+ args = parser.parse_args(input_args)
257
+ else:
258
+ args = parser.parse_args()
259
+ return args
260
+
261
+
262
+ def main(args):
263
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
264
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
265
+
266
+ if args.output_dir is not None:
267
+ os.makedirs(args.output_dir, exist_ok=True)
268
+ os.makedirs(args.logging_dir, exist_ok=True)
269
+ logging_dir = Path(args.output_dir, args.logging_dir)
270
+
271
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
272
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
273
+ accelerator = Accelerator(
274
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
275
+ mixed_precision=args.mixed_precision,
276
+ log_with=args.report_to,
277
+ project_config=accelerator_project_config,
278
+ kwargs_handlers=[kwargs],
279
+ )
280
+
281
+ if torch.backends.mps.is_available():
282
+ accelerator.native_amp = False
283
+
284
+ if args.report_to == "wandb":
285
+ if not is_wandb_available():
286
+ raise ImportError("Install wandb for logging during training.")
287
+
288
+ logging.basicConfig(
289
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
290
+ datefmt="%m/%d/%Y %H:%M:%S",
291
+ level=logging.INFO,
292
+ )
293
+ logger.info(accelerator.state, main_process_only=False)
294
+ if accelerator.is_local_main_process:
295
+ transformers.utils.logging.set_verbosity_warning()
296
+ diffusers.utils.logging.set_verbosity_info()
297
+ else:
298
+ transformers.utils.logging.set_verbosity_error()
299
+ diffusers.utils.logging.set_verbosity_error()
300
+
301
+ if args.seed is not None:
302
+ set_seed(args.seed)
303
+
304
+ if accelerator.is_main_process and args.output_dir is not None:
305
+ os.makedirs(args.output_dir, exist_ok=True)
306
+
307
+ # Tokenizers
308
+ tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
309
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
310
+ )
311
+ tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
312
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
313
+ )
314
+
315
+ # Text encoders
316
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
317
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
318
+
319
+ # Scheduler and models
320
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
321
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
322
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
323
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
324
+ transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
325
+
326
+ # Train only LoRA adapters
327
+ transformer.requires_grad_(True)
328
+ vae.requires_grad_(False)
329
+ text_encoder_one.requires_grad_(False)
330
+ text_encoder_two.requires_grad_(False)
331
+
332
+ weight_dtype = torch.float32
333
+ if accelerator.mixed_precision == "fp16":
334
+ weight_dtype = torch.float16
335
+ elif accelerator.mixed_precision == "bf16":
336
+ weight_dtype = torch.bfloat16
337
+
338
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
339
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
340
+
341
+ vae.to(accelerator.device, dtype=weight_dtype)
342
+ transformer.to(accelerator.device, dtype=weight_dtype)
343
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
344
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
345
+
346
+ if args.gradient_checkpointing:
347
+ transformer.enable_gradient_checkpointing()
348
+
349
+ # Setup LoRA attention processors
350
+ if args.pretrained_lora_path is not None:
351
+ lora_path = args.pretrained_lora_path
352
+ checkpoint = load_checkpoint(lora_path)
353
+ lora_attn_procs = {}
354
+ double_blocks_idx = list(range(19))
355
+ single_blocks_idx = list(range(38))
356
+ number = 1
357
+ for name, attn_processor in transformer.attn_processors.items():
358
+ match = re.search(r'\.(\d+)\.', name)
359
+ if match:
360
+ layer_index = int(match.group(1))
361
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
362
+ lora_state_dicts = {}
363
+ for key, value in checkpoint.items():
364
+ if re.search(r'\.(\d+)\.', key):
365
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
366
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
367
+ lora_state_dicts[key] = value
368
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
369
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
370
+ )
371
+ for n in range(number):
372
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
373
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
374
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
375
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
376
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
377
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
378
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
379
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
380
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
381
+ lora_state_dicts = {}
382
+ for key, value in checkpoint.items():
383
+ if re.search(r'\.(\d+)\.', key):
384
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
385
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
386
+ lora_state_dicts[key] = value
387
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
388
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
389
+ )
390
+ for n in range(number):
391
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
392
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
393
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
394
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
395
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
396
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
397
+ else:
398
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
399
+ else:
400
+ lora_attn_procs = {}
401
+ double_blocks_idx = list(range(19))
402
+ single_blocks_idx = list(range(38))
403
+ for name, attn_processor in transformer.attn_processors.items():
404
+ match = re.search(r'\.(\d+)\.', name)
405
+ if match:
406
+ layer_index = int(match.group(1))
407
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
408
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
409
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
410
+ )
411
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
412
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
413
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
414
+ )
415
+ else:
416
+ lora_attn_procs[name] = attn_processor
417
+
418
+ transformer.set_attn_processor(lora_attn_procs)
419
+ transformer.train()
420
+ for n, param in transformer.named_parameters():
421
+ if '_lora' not in n:
422
+ param.requires_grad = False
423
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
424
+
425
+ def unwrap_model(model):
426
+ model = accelerator.unwrap_model(model)
427
+ model = model._orig_mod if is_compiled_module(model) else model
428
+ return model
429
+
430
+ if args.resume_from_checkpoint:
431
+ path = args.resume_from_checkpoint
432
+ global_step = int(path.split("-")[-1])
433
+ initial_global_step = global_step
434
+ else:
435
+ initial_global_step = 0
436
+ global_step = 0
437
+ first_epoch = 0
438
+
439
+ if args.scale_lr:
440
+ args.learning_rate = (
441
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
442
+ )
443
+
444
+ if args.mixed_precision == "fp16":
445
+ models = [transformer]
446
+ cast_training_params(models, dtype=torch.float32)
447
+
448
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
449
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
450
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
451
+
452
+ optimizer_class = torch.optim.AdamW
453
+ optimizer = optimizer_class(
454
+ [transformer_parameters_with_lr],
455
+ betas=(args.adam_beta1, args.adam_beta2),
456
+ weight_decay=args.adam_weight_decay,
457
+ eps=args.adam_epsilon,
458
+ )
459
+
460
+ tokenizers = [tokenizer_one, tokenizer_two]
461
+ text_encoders = [text_encoder_one, text_encoder_two]
462
+
463
+ train_dataset = make_train_dataset_inpaint_mask(args, tokenizers, accelerator)
464
+ train_dataloader = torch.utils.data.DataLoader(
465
+ train_dataset,
466
+ batch_size=args.train_batch_size,
467
+ shuffle=True,
468
+ collate_fn=collate_fn,
469
+ num_workers=args.dataloader_num_workers,
470
+ )
471
+
472
+ vae_config_shift_factor = vae.config.shift_factor
473
+ vae_config_scaling_factor = vae.config.scaling_factor
474
+
475
+ overrode_max_train_steps = False
476
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
477
+ if args.resume_from_checkpoint:
478
+ first_epoch = global_step // num_update_steps_per_epoch
479
+ if args.max_train_steps is None:
480
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
481
+ overrode_max_train_steps = True
482
+
483
+ lr_scheduler = get_scheduler(
484
+ args.lr_scheduler,
485
+ optimizer=optimizer,
486
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
487
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
488
+ num_cycles=args.lr_num_cycles,
489
+ power=args.lr_power,
490
+ )
491
+
492
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
493
+ transformer, optimizer, train_dataloader, lr_scheduler
494
+ )
495
+
496
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
497
+ if overrode_max_train_steps:
498
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
499
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
500
+
501
+ # Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
502
+ def _sanitize_hparams(config_dict):
503
+ sanitized = {}
504
+ for key, value in dict(config_dict).items():
505
+ try:
506
+ if value is None:
507
+ continue
508
+ # numpy scalar types
509
+ if isinstance(value, (np.integer,)):
510
+ sanitized[key] = int(value)
511
+ elif isinstance(value, (np.floating,)):
512
+ sanitized[key] = float(value)
513
+ elif isinstance(value, (int, float, bool, str)):
514
+ sanitized[key] = value
515
+ elif isinstance(value, Path):
516
+ sanitized[key] = str(value)
517
+ elif isinstance(value, (list, tuple)):
518
+ # stringify simple sequences; skip if fails
519
+ sanitized[key] = str(value)
520
+ else:
521
+ # best-effort stringify
522
+ sanitized[key] = str(value)
523
+ except Exception:
524
+ # skip unconvertible entries
525
+ continue
526
+ return sanitized
527
+
528
+ if accelerator.is_main_process:
529
+ tracker_name = "Easy_Control_Kontext"
530
+ accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
531
+
532
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
533
+ logger.info("***** Running training *****")
534
+ logger.info(f" Num examples = {len(train_dataset)}")
535
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
536
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
537
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
538
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
539
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
540
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
541
+
542
+ progress_bar = tqdm(
543
+ range(0, args.max_train_steps),
544
+ initial=initial_global_step,
545
+ desc="Steps",
546
+ disable=not accelerator.is_local_main_process,
547
+ )
548
+
549
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
550
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
551
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
552
+ timesteps = timesteps.to(accelerator.device)
553
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
554
+ sigma = sigmas[step_indices].flatten()
555
+ while len(sigma.shape) < n_dim:
556
+ sigma = sigma.unsqueeze(-1)
557
+ return sigma
558
+
559
+ # Kontext specifics
560
+ vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
561
+ # Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
562
+ height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
563
+ width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
564
+ offset = 64
565
+
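# Editor's note (illustrative, not part of the commit): worked arithmetic for the condition
# resolution computed above. With the script defaults cond_size = 512 and vae_scale_factor = 8,
#   height_cond = width_cond = 2 * (512 // (8 * 2)) = 64 latent positions per side,
# and after _pack_latents' additional 2x patchification the condition branch contributes
#   (64 // 2) * (64 // 2) = 1024 tokens to the transformer sequence.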
566
+ for epoch in range(first_epoch, args.num_train_epochs):
567
+ transformer.train()
568
+ for step, batch in enumerate(train_dataloader):
569
+ models_to_accumulate = [transformer]
570
+ with accelerator.accumulate(models_to_accumulate):
571
+ tokens = [batch["text_ids_1"], batch["text_ids_2"]]
572
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
573
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
574
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
575
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
576
+
577
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
578
+ height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
579
+ width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
580
+
581
+ model_input = vae.encode(pixel_values).latent_dist.sample()
582
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
583
+ model_input = model_input.to(dtype=weight_dtype)
584
+
585
+ latent_image_ids, cond_latent_image_ids = resize_position_encoding(
586
+ model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
587
+ )
588
+
589
+ noise = torch.randn_like(model_input)
590
+ bsz = model_input.shape[0]
591
+
592
+ u = compute_density_for_timestep_sampling(
593
+ weighting_scheme=args.weighting_scheme,
594
+ batch_size=bsz,
595
+ logit_mean=args.logit_mean,
596
+ logit_std=args.logit_std,
597
+ mode_scale=args.mode_scale,
598
+ )
599
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
600
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
601
+
602
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
603
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
604
+
605
+ packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
606
+ noisy_model_input,
607
+ batch_size=model_input.shape[0],
608
+ num_channels_latents=model_input.shape[1],
609
+ height=model_input.shape[2],
610
+ width=model_input.shape[3],
611
+ )
612
+
613
+ latent_image_ids_to_concat = [latent_image_ids]
614
+ packed_cond_model_input_to_concat = []
615
+
616
+ if args.kontext == "enable":
617
+ source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
618
+ source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
619
+ source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
620
+ image_latent_h, image_latent_w = source_image_latents.shape[2:]
621
+ packed_image_latents = FluxKontextControlPipeline._pack_latents(
622
+ source_image_latents,
623
+ batch_size=source_image_latents.shape[0],
624
+ num_channels_latents=source_image_latents.shape[1],
625
+ height=image_latent_h,
626
+ width=image_latent_w,
627
+ )
628
+ source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
629
+ batch_size=source_image_latents.shape[0],
630
+ height=image_latent_h // 2,
631
+ width=image_latent_w // 2,
632
+ device=accelerator.device,
633
+ dtype=weight_dtype,
634
+ )
635
+ source_image_ids[..., 0] = 1 # Mark as condition
636
+ latent_image_ids_to_concat.append(source_image_ids)
637
+
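# Editor's note (descriptive, not part of the commit): the first channel of these RoPE image ids
# acts as a stream tag: 0 for the denoised target latents, 1 for the Kontext source image (set
# just above), and 2 for the spatial/subject condition tokens (set further below), which lets the
# transformer tell the concatenated token groups apart positionally.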
638
+
639
+ subject_pixel_values = batch.get("subject_pixel_values")
640
+ if subject_pixel_values is not None:
641
+ subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
642
+ subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
643
+ subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
644
+ subject_input = subject_input.to(dtype=weight_dtype)
645
+ sub_number = subject_pixel_values.shape[-2] // args.cond_size
646
+ latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
647
+ latent_subject_ids[..., 0] = 2
648
+ latent_subject_ids[:, 1] += offset
649
+ sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
650
+ latent_image_ids_to_concat.append(sub_latent_image_ids)
651
+
652
+ packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
653
+ subject_input,
654
+ batch_size=subject_input.shape[0],
655
+ num_channels_latents=subject_input.shape[1],
656
+ height=subject_input.shape[2],
657
+ width=subject_input.shape[3],
658
+ )
659
+ packed_cond_model_input_to_concat.append(packed_subject_model_input)
660
+
661
+ cond_pixel_values = batch.get("cond_pixel_values")
662
+ if cond_pixel_values is not None:
663
+ cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
664
+ cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
665
+ cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
666
+ cond_input = cond_input.to(dtype=weight_dtype)
667
+ cond_number = cond_pixel_values.shape[-2] // args.cond_size
668
+ cond_latent_image_ids[..., 0] = 2
669
+ cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
670
+ latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
671
+
672
+ packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
673
+ cond_input,
674
+ batch_size=cond_input.shape[0],
675
+ num_channels_latents=cond_input.shape[1],
676
+ height=cond_input.shape[2],
677
+ width=cond_input.shape[3],
678
+ )
679
+ packed_cond_model_input_to_concat.append(packed_cond_model_input)
680
+
681
+ latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
682
+ cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
683
+
684
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
685
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
686
+ guidance = guidance.expand(model_input.shape[0])
687
+ else:
688
+ guidance = None
689
+
690
+ latent_model_input = packed_noisy_model_input
691
+ if args.kontext == "enable":
692
+ latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
693
+ model_pred = transformer(
694
+ hidden_states=latent_model_input,
695
+ cond_hidden_states=cond_packed_noisy_model_input,
696
+ timestep=timesteps / 1000,
697
+ guidance=guidance,
698
+ pooled_projections=pooled_prompt_embeds,
699
+ encoder_hidden_states=prompt_embeds,
700
+ txt_ids=text_ids,
701
+ img_ids=latent_image_ids,
702
+ return_dict=False,
703
+ )[0]
704
+
705
+ model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
706
+
707
+ model_pred = FluxKontextControlPipeline._unpack_latents(
708
+ model_pred,
709
+ height=int(pixel_values.shape[-2]),
710
+ width=int(pixel_values.shape[-1]),
711
+ vae_scale_factor=vae_scale_factor,
712
+ )
713
+
714
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
715
+ target = noise - model_input
716
+
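# Editor's note (illustrative, not part of the commit): the interpolation above follows the
# rectified-flow schedule x_sigma = (1 - sigma) * x_0 + sigma * noise, whose velocity
# d x_sigma / d sigma = noise - x_0 is constant; that velocity is exactly the
# `target = noise - model_input` regressed by the MSE below.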
717
+ loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
718
+ loss = loss.mean()
719
+ accelerator.backward(loss)
720
+ if accelerator.sync_gradients:
721
+ params_to_clip = (transformer.parameters())
722
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
723
+
724
+ optimizer.step()
725
+ lr_scheduler.step()
726
+ optimizer.zero_grad()
727
+
728
+ if accelerator.sync_gradients:
729
+ progress_bar.update(1)
730
+ global_step += 1
731
+
732
+ if accelerator.is_main_process:
733
+ if global_step % args.checkpointing_steps == 0:
734
+ if args.checkpoints_total_limit is not None:
735
+ checkpoints = os.listdir(args.output_dir)
736
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
737
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
738
+ if len(checkpoints) >= args.checkpoints_total_limit:
739
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
740
+ removing_checkpoints = checkpoints[0:num_to_remove]
741
+ logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
742
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
743
+ for removing_checkpoint in removing_checkpoints:
744
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
745
+ shutil.rmtree(removing_checkpoint)
746
+
747
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
748
+ os.makedirs(save_path, exist_ok=True)
749
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
750
+ lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
751
+ save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
752
+ logger.info(f"Saved state to {save_path}")
753
+
754
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
755
+ progress_bar.set_postfix(**logs)
756
+ accelerator.log(logs, step=global_step)
757
+
758
+ if accelerator.is_main_process:
759
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
760
+ pipeline = FluxKontextControlPipeline.from_pretrained(
761
+ args.pretrained_model_name_or_path,
762
+ vae=vae,
763
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
764
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
765
+ transformer=accelerator.unwrap_model(transformer),
766
+ revision=args.revision,
767
+ variant=args.variant,
768
+ torch_dtype=weight_dtype,
769
+ )
770
+
771
+ if args.subject_test_images is not None and len(args.subject_test_images) != 0 and args.subject_test_images != ['None']:
772
+ subject_paths = args.subject_test_images
773
+ subject_ls = [Image.open(image_path).convert("RGB") for image_path in subject_paths]
774
+ else:
775
+ subject_ls = []
776
+ if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
777
+ spatial_paths = args.spatial_test_images
778
+ spatial_ls = [Image.open(image_path).convert("RGB") for image_path in spatial_paths]
779
+ else:
780
+ spatial_ls = []
781
+
782
+ pipeline_args = {
783
+ "prompt": args.validation_prompt,
784
+ "cond_size": args.cond_size,
785
+ "guidance_scale": 3.5,
786
+ "num_inference_steps": 20,
787
+ "max_sequence_length": 128,
788
+ "control_dict": {"spatial_images": spatial_ls, "subject_images": subject_ls},
789
+ }
790
+
791
+ images = log_validation(
792
+ pipeline=pipeline,
793
+ args=args,
794
+ accelerator=accelerator,
795
+ pipeline_args=pipeline_args,
796
+ step=global_step,
797
+ torch_dtype=weight_dtype,
798
+ )
799
+ save_path = os.path.join(args.output_dir, "validation")
800
+ os.makedirs(save_path, exist_ok=True)
801
+ save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
802
+ os.makedirs(save_folder, exist_ok=True)
803
+ for idx, img in enumerate(images):
804
+ img.save(os.path.join(save_folder, f"{idx}.jpg"))
805
+ del pipeline
806
+
807
+ accelerator.wait_for_everyone()
808
+ accelerator.end_training()
809
+
810
+
811
+ if __name__ == "__main__":
812
+ args = parse_args()
813
+ main(args)
814
+
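Editor's note: the checkpointing branch above stores only the LoRA tensors (state-dict keys containing `_lora`) as `checkpoint-<step>/lora.safetensors`. A minimal sketch of reading such a file back, assuming a hypothetical checkpoint path; `load_file` is the standard safetensors loader:

from safetensors.torch import load_file

# Hypothetical path following the checkpoint-<step> convention used by the script above.
lora_state = load_file("output/checkpoint-1000/lora.safetensors")
assert all("_lora" in key for key in lora_state)  # only LoRA weights were saved
print(sum(t.numel() for t in lora_state.values()) / 1e6, "M LoRA parameters")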
train/train_kontext_edge.sh ADDED
@@ -0,0 +1,25 @@
1
+ export MODEL_DIR="/robby/share/Editing/lzc/FLUX.1-Kontext-dev" # your flux path
2
+ export OUTPUT_DIR="/robby/share/Editing/lzc/EasyControl_kontext_edge_test_hed" # your save path
3
+ export CONFIG="./default_config.yaml"
4
+ export TRAIN_DATA="/robby/share/MM/zkc/data/i2i_csv/pexel_Qwen2_5VL7BInstruct.csv" # your training data file (CSV)
5
+ export LOG_PATH="$OUTPUT_DIR/log"
6
+
7
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_edge.py \
8
+ --pretrained_model_name_or_path $MODEL_DIR \
9
+ --lora_num=1 \
10
+ --cond_size=512 \
11
+ --ranks 128 \
12
+ --network_alphas 128 \
13
+ --output_dir=$OUTPUT_DIR \
14
+ --logging_dir=$LOG_PATH \
15
+ --mixed_precision="bf16" \
16
+ --train_data_dir=$TRAIN_DATA \
17
+ --learning_rate=1e-4 \
18
+ --train_batch_size=1 \
19
+ --num_train_epochs=1 \
20
+ --validation_steps=500 \
21
+ --checkpointing_steps=1000 \
22
+ --validation_images "./kontext_edge_test/img_1.png" "./kontext_edge_test/img_2.png" "" "" "./kontext_edge_test/img_3.png" \
23
+ --spatial_test_images "./kontext_edge_test/edge_1.png" "./kontext_edge_test/edge_2.png" "./kontext_edge_test/edge_1.png" "./kontext_edge_test/edge_2.png" "./kontext_edge_test/edge_3.png" \
24
+ --validation_prompt "The cake was cut off a piece" "Let this black woman wearing a transparent sunglasses" "This image shows a beautifully decorated cake with golden-orange sides and white frosting on top, and a piece of cake is being cut. The cake is displayed on a rustic wooden slice that serves as a cake stand." "This is a striking portrait photograph featuring a person wearing an ornate golden crown and a heart-shape sunglasses. The subject has dramatic golden metallic eyeshadow that extends across their eyelids, complementing the warm tones of the crown." "move the cup to the left" \
25
+ --num_validation_images=1
train/train_kontext_interactive_lora.sh ADDED
@@ -0,0 +1,18 @@
1
+ export MODEL_DIR="" # your flux path
2
+ export OUTPUT_DIR="" # your save path
3
+ export CONFIG="./default_config.yaml"
4
+ export LOG_PATH="$OUTPUT_DIR/log"
5
+
6
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_lora.py \
7
+ --pretrained_model_name_or_path $MODEL_DIR \
8
+ --output_dir=$OUTPUT_DIR \
9
+ --logging_dir=$LOG_PATH \
10
+ --mixed_precision="bf16" \
11
+ --learning_rate=1e-4 \
12
+ --train_batch_size=1 \
13
+ --num_train_epochs=10 \
14
+ --validation_steps=100 \
15
+ --checkpointing_steps=500 \
16
+ --validation_images "./kontext_interactive_test/img_1.png" \
17
+ --validation_prompt "Let the man hold the AK47 using both hands." \
18
+ --num_validation_images=1
train/train_kontext_local.py ADDED
@@ -0,0 +1,876 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import shutil
7
+ from contextlib import nullcontext
8
+ from pathlib import Path
9
+ import re
10
+
11
+ from safetensors.torch import save_file
12
+ from PIL import Image
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ import diffusers
23
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
24
+ from diffusers.optimization import get_scheduler
25
+ from diffusers.training_utils import (
26
+ cast_training_params,
27
+ compute_density_for_timestep_sampling,
28
+ compute_loss_weighting_for_sd3,
29
+ )
30
+ from diffusers.utils.torch_utils import is_compiled_module
31
+ from diffusers.utils import (
32
+ check_min_version,
33
+ is_wandb_available,
34
+ )
35
+
36
+ from src.prompt_helper import *
37
+ from src.lora_helper import *
38
+ from src.jsonl_datasets_kontext_local import make_train_dataset_mixed, collate_fn
39
+ from src.pipeline_flux_kontext_control import (
40
+ FluxKontextControlPipeline,
41
+ resize_position_encoding,
42
+ prepare_latent_subject_ids,
43
+ PREFERRED_KONTEXT_RESOLUTIONS
44
+ )
45
+ from src.transformer_flux import FluxTransformer2DModel
46
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
47
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
48
+ from tqdm.auto import tqdm
49
+
50
+ if is_wandb_available():
51
+ import wandb
52
+
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.31.0.dev0")
56
+
57
+ logger = get_logger(__name__)
58
+
59
+
60
+ def compute_background_preserving_loss(model_pred, target, mask_values, weighting, background_weight: float = 3.0):
61
+ """
62
+ Compute loss with higher penalty on background (non-masked) regions to preserve them.
63
+ model_pred/target: [B, C, H, W]
64
+ mask_values: [B, 1, H_img, W_img] with values in {0,1} at image resolution
65
+ weighting: broadcastable to [B, C, H, W]
66
+ Returns per-pixel loss map [B, C, H, W]
67
+ """
68
+ base_loss = (weighting.float() * (model_pred.float() - target.float()) ** 2)
69
+ mask_latent = torch.nn.functional.interpolate(
70
+ mask_values,
71
+ size=(model_pred.shape[2], model_pred.shape[3]),
72
+ mode='bilinear',
73
+ align_corners=False,
74
+ )
75
+ foreground_mask = mask_latent
76
+ background_mask = 1.0 - mask_latent
77
+ foreground_mask = foreground_mask.expand_as(base_loss)
78
+ background_mask = background_mask.expand_as(base_loss)
79
+ foreground_loss = base_loss * foreground_mask
80
+ background_loss = base_loss * background_mask * float(background_weight)
81
+ total_loss = foreground_loss + background_loss
82
+ return total_loss
83
+
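# Editor's sketch (illustrative only, not part of the commit): reducing the per-pixel map
# returned by compute_background_preserving_loss above to a scalar, mirroring the
# reshape-and-mean reduction these training scripts apply to their flow-matching MSE.
# Shapes below are invented for the example.
#
#   import torch
#   pred = torch.randn(2, 16, 64, 64)
#   target = torch.randn(2, 16, 64, 64)
#   mask = (torch.rand(2, 1, 512, 512) > 0.5).float()   # 1 = edited region, 0 = background
#   weighting = torch.ones(2, 1, 1, 1)
#   loss_map = compute_background_preserving_loss(pred, target, mask, weighting, background_weight=3.0)
#   loss = loss_map.reshape(loss_map.shape[0], -1).mean(dim=1).mean()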
84
+ def log_validation(
85
+ pipeline,
86
+ args,
87
+ accelerator,
88
+ pipeline_args,
89
+ step,
90
+ torch_dtype,
91
+ is_final_validation=False,
92
+ ):
93
+ logger.info(
94
+ f"Running validation... Strict per-case evaluation for image, spatial image, and prompt."
95
+ )
96
+ pipeline = pipeline.to(accelerator.device)
97
+ pipeline.set_progress_bar_config(disable=True)
98
+
99
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
100
+ autocast_ctx = nullcontext()
101
+
102
+ # Build per-case evaluation: require equal lengths for image, spatial image, and prompt
103
+ if args.validation_images is None or args.validation_images == ['None']:
104
+ raise ValueError("validation_images must be provided and non-empty")
105
+ if args.validation_prompt is None:
106
+ raise ValueError("validation_prompt must be provided and non-empty")
107
+
108
+ control_dict_root = dict(pipeline_args.get("control_dict", {})) if pipeline_args is not None else {}
109
+ spatial_ls = control_dict_root.get("spatial_images", []) or []
110
+
111
+ val_imgs = args.validation_images
112
+ prompts = args.validation_prompt
113
+
114
+ if not (len(val_imgs) == len(prompts) == len(spatial_ls)):
115
+ raise ValueError(
116
+ f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}, spatial_images={len(spatial_ls)}"
117
+ )
118
+
119
+ results = []
120
+
121
+ def _resize_to_preferred(img: Image.Image) -> Image.Image:
122
+ w, h = img.size
123
+ aspect_ratio = w / h if h != 0 else 1.0
124
+ _, target_w, target_h = min(
125
+ (abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
126
+ for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
127
+ )
128
+ return img.resize((target_w, target_h), Image.BICUBIC)
129
+
130
+ # Distributed per-rank assignment: each process handles its own slice of cases
131
+ num_cases = len(prompts)
132
+ logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
133
+
134
+ rank = accelerator.process_index
135
+ world_size = accelerator.num_processes
136
+ local_indices = list(range(rank, num_cases, world_size))
137
+
138
+ local_images = []
139
+ with autocast_ctx:
140
+ for idx in local_indices:
141
+ try:
142
+ base_img = Image.open(val_imgs[idx]).convert("RGB")
143
+ resized_img = _resize_to_preferred(base_img)
144
+ except Exception as e:
145
+ raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
146
+
147
+ case_args = dict(pipeline_args) if pipeline_args is not None else {}
148
+ case_args.pop("height", None)
149
+ case_args.pop("width", None)
150
+ if resized_img is not None:
151
+ tw, th = resized_img.size
152
+ case_args["height"] = th
153
+ case_args["width"] = tw
154
+
155
+ case_control = dict(case_args.get("control_dict", {}))
156
+ spatial_case = spatial_ls[idx]
157
+
158
+ # Compose masked image cond: resized_img * (1 - binary_mask)
159
+ try:
160
+ mask_img = Image.open(spatial_case).convert("L") if isinstance(spatial_case, str) else spatial_case.convert("L")
161
+ except Exception:
162
+ mask_img = spatial_case.convert("L")
163
+ mask_img = mask_img.resize(resized_img.size, Image.NEAREST)
164
+ mask_np = np.array(mask_img)
165
+ mask_bin = (mask_np > 127).astype(np.uint8)
166
+ inv_mask = (1 - mask_bin).astype(np.uint8)
167
+ base_np = np.array(resized_img)
168
+ masked_np = base_np * inv_mask[..., None]
169
+ masked_img = Image.fromarray(masked_np.astype(np.uint8))
170
+
171
+ case_control["spatial_images"] = [masked_img]
172
+ case_args["control_dict"] = case_control
173
+
174
+ case_args["prompt"] = prompts[idx]
175
+ img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
176
+ local_images.append(img)
177
+
178
+ # Gather one image per rank (pad missing ranks with black images) to main process
179
+ fixed_size = (1024, 1024)
180
+ has_sample = torch.tensor([1 if len(local_images) > 0 else 0], device=accelerator.device, dtype=torch.int)
181
+ local_idx = torch.tensor([local_indices[0] if len(local_indices) > 0 else -1], device=accelerator.device, dtype=torch.long)
182
+ if len(local_images) > 0:
183
+ gathered_img = local_images[0].resize(fixed_size, Image.BICUBIC)
184
+ img_np = np.asarray(gathered_img).astype(np.uint8)
185
+ else:
186
+ img_np = np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
187
+ img_tensor = torch.from_numpy(img_np).to(device=accelerator.device)
188
+ if img_tensor.ndim == 3:
189
+ img_tensor = img_tensor.unsqueeze(0)
190
+
191
+ gathered_has = accelerator.gather(has_sample)
192
+ gathered_idx = accelerator.gather(local_idx)
193
+ gathered_imgs = accelerator.gather(img_tensor)
194
+
195
+ if accelerator.is_main_process:
196
+ for i in range(int(gathered_has.shape[0])):
197
+ if int(gathered_has[i].item()) == 1:
198
+ idx = int(gathered_idx[i].item())
199
+ arr = gathered_imgs[i].cpu().numpy()
200
+ pil_img = Image.fromarray(arr.astype(np.uint8))
201
+ # Resize back to original validation image size
202
+ try:
203
+ orig = Image.open(val_imgs[idx]).convert("RGB")
204
+ pil_img = pil_img.resize(orig.size, Image.BICUBIC)
205
+ except Exception:
206
+ pass
207
+ results.append(pil_img)
208
+
209
+ del pipeline
210
+ if torch.cuda.is_available():
211
+ torch.cuda.empty_cache()
212
+
213
+ return results
214
+
215
+
216
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
217
+ text_encoder_config = transformers.PretrainedConfig.from_pretrained(
218
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
219
+ )
220
+ model_class = text_encoder_config.architectures[0]
221
+ if model_class == "CLIPTextModel":
222
+ from transformers import CLIPTextModel
223
+
224
+ return CLIPTextModel
225
+ elif model_class == "T5EncoderModel":
226
+ from transformers import T5EncoderModel
227
+
228
+ return T5EncoderModel
229
+ else:
230
+ raise ValueError(f"{model_class} is not supported.")
231
+
232
+
233
+ def parse_args(input_args=None):
234
+ parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
235
+ parser.add_argument("--lora_num", type=int, default=1, help="number of the lora.")
236
+ parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
237
+ parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
238
+
239
+ # New dataset (local edits + inpaint JSONL) mixed 1:1
240
+ parser.add_argument("--local_edits_json", type=str, default="/robby/share/Editing/qingyan/InstructV2V/Qwen2_5_72B_instructs_10W.json", help="Path to local edits JSON")
241
+ parser.add_argument("--train_data_dir", type=str, default="/robby/share/Editing/lzc/data/pexel_final/inpaint_edit_outputs_merged.jsonl", help="Path to inpaint JSONL file for mixing 1:1")
242
+ parser.add_argument("--source_frames_dir", type=str, default="/robby/share/Editing/qingyan/InstructV2V/pexel-video-merged-1frame", help="Root dir containing group folders like 0139")
243
+ parser.add_argument("--target_frames_dir", type=str, default="/robby/share/Editing/qingyan/InstructV2V/pexel-video-1frame-kontext-edit/local", help="Root dir containing group folders like 0139")
244
+ parser.add_argument("--masks_dir", type=str, default="/robby/share/Editing/lzc/InstructV2V/diff_masks", help="Root dir of precomputed masks organized as <group>/<prefix>_{i}.png")
245
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
246
+ parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
247
+ parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
248
+ parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
249
+
250
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
251
+ parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
252
+ parser.add_argument("--kontext", type=str, default="enable")
253
+ parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
254
+ parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
255
+ parser.add_argument("--subject_test_images", type=str, nargs="+", default=None, help="List of subject test images")
256
+ parser.add_argument("--spatial_test_images", type=str, nargs="+", default=None, help="List of spatial test images")
257
+ parser.add_argument("--num_validation_images", type=int, default=4)
258
+ parser.add_argument("--validation_steps", type=int, default=20)
259
+
260
+ parser.add_argument("--ranks", type=int, nargs="+", default=[256], help="LoRA ranks")
261
+ parser.add_argument("--network_alphas", type=int, nargs="+", default=[256], help="LoRA network alphas")
262
+ parser.add_argument("--output_dir", type=str, default="/tiamat-NAS/zhangyuxuan/projects2/Easy_Control_0120/single_models/subject_model", help="Output directory")
263
+ parser.add_argument("--seed", type=int, default=None)
264
+ parser.add_argument("--train_batch_size", type=int, default=1)
265
+ parser.add_argument("--num_train_epochs", type=int, default=50)
266
+ parser.add_argument("--max_train_steps", type=int, default=None)
267
+ parser.add_argument("--checkpointing_steps", type=int, default=1000)
268
+ parser.add_argument("--checkpoints_total_limit", type=int, default=None)
269
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
270
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
271
+ parser.add_argument("--gradient_checkpointing", action="store_true")
272
+ parser.add_argument("--learning_rate", type=float, default=1e-4)
273
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
274
+ parser.add_argument("--scale_lr", action="store_true", default=False)
275
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
276
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
277
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
278
+ parser.add_argument("--lr_power", type=float, default=1.0)
279
+ parser.add_argument("--dataloader_num_workers", type=int, default=8)
280
+ parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
281
+ parser.add_argument("--logit_mean", type=float, default=0.0)
282
+ parser.add_argument("--logit_std", type=float, default=1.0)
283
+ parser.add_argument("--mode_scale", type=float, default=1.29)
284
+ parser.add_argument("--optimizer", type=str, default="AdamW")
285
+ parser.add_argument("--use_8bit_adam", action="store_true")
286
+ parser.add_argument("--adam_beta1", type=float, default=0.9)
287
+ parser.add_argument("--adam_beta2", type=float, default=0.999)
288
+ parser.add_argument("--prodigy_beta3", type=float, default=None)
289
+ parser.add_argument("--prodigy_decouple", type=bool, default=True)
290
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
291
+ parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
292
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08)
293
+ parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
294
+ parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
295
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
296
+ parser.add_argument("--logging_dir", type=str, default="logs")
297
+ parser.add_argument("--cache_latents", action="store_true", default=False)
298
+ parser.add_argument("--report_to", type=str, default="tensorboard")
299
+ parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
300
+ parser.add_argument("--upcast_before_saving", action="store_true", default=False)
301
+ parser.add_argument("--mix_ratio", type=float, default=0, help="Ratio of inpaint to local edits (B per A). 0=only local edits, 1=1:1, 2=1:2")
302
+ parser.add_argument("--background_weight", type=float, default=1.0, help="Background preserving loss weight multiplier")
303
+
304
+ # Blending options for dataset pixel_values
305
+ parser.add_argument("--blend_pixel_values", action="store_true", help="Blend target/source into pixel_values using mask")
306
+ parser.add_argument("--blend_kernel", type=int, default=21, help="Gaussian blur kernel size (must be odd)")
307
+ parser.add_argument("--blend_sigma", type=float, default=10.0, help="Gaussian blur sigma")
308
+
309
+ if input_args is not None:
310
+ args = parser.parse_args(input_args)
311
+ else:
312
+ args = parser.parse_args()
313
+ return args
314
+
315
+
316
+ def main(args):
317
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
318
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
319
+
320
+ if args.output_dir is not None:
321
+ os.makedirs(args.output_dir, exist_ok=True)
322
+ os.makedirs(args.logging_dir, exist_ok=True)
323
+ logging_dir = Path(args.output_dir, args.logging_dir)
324
+
325
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
326
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
327
+ accelerator = Accelerator(
328
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
329
+ mixed_precision=args.mixed_precision,
330
+ log_with=args.report_to,
331
+ project_config=accelerator_project_config,
332
+ kwargs_handlers=[kwargs],
333
+ )
334
+
335
+ if torch.backends.mps.is_available():
336
+ accelerator.native_amp = False
337
+
338
+ if args.report_to == "wandb":
339
+ if not is_wandb_available():
340
+ raise ImportError("Install wandb for logging during training.")
341
+
342
+ logging.basicConfig(
343
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
344
+ datefmt="%m/%d/%Y %H:%M:%S",
345
+ level=logging.INFO,
346
+ )
347
+ logger.info(accelerator.state, main_process_only=False)
348
+ if accelerator.is_local_main_process:
349
+ transformers.utils.logging.set_verbosity_warning()
350
+ diffusers.utils.logging.set_verbosity_info()
351
+ else:
352
+ transformers.utils.logging.set_verbosity_error()
353
+ diffusers.utils.logging.set_verbosity_error()
354
+
355
+ if args.seed is not None:
356
+ set_seed(args.seed)
357
+
358
+ if accelerator.is_main_process and args.output_dir is not None:
359
+ os.makedirs(args.output_dir, exist_ok=True)
360
+
361
+ # Tokenizers
362
+ tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
363
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
364
+ )
365
+ tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
366
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
367
+ )
368
+
369
+ # Text encoders
370
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
371
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
372
+
373
+ # Scheduler and models
374
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
375
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
376
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
377
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
378
+ transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
379
+
380
+ # Train only LoRA adapters
381
+ transformer.requires_grad_(True)
382
+ vae.requires_grad_(False)
383
+ text_encoder_one.requires_grad_(False)
384
+ text_encoder_two.requires_grad_(False)
385
+
386
+ weight_dtype = torch.float32
387
+ if accelerator.mixed_precision == "fp16":
388
+ weight_dtype = torch.float16
389
+ elif accelerator.mixed_precision == "bf16":
390
+ weight_dtype = torch.bfloat16
391
+
392
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
393
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
394
+
395
+ vae.to(accelerator.device, dtype=weight_dtype)
396
+ transformer.to(accelerator.device, dtype=weight_dtype)
397
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
398
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
399
+
400
+ if args.gradient_checkpointing:
401
+ transformer.enable_gradient_checkpointing()
402
+
403
+ # Setup LoRA attention processors
404
+ if args.pretrained_lora_path is not None:
405
+ lora_path = args.pretrained_lora_path
406
+ checkpoint = load_checkpoint(lora_path)
407
+ lora_attn_procs = {}
408
+ double_blocks_idx = list(range(19))
409
+ single_blocks_idx = list(range(38))
410
+ number = 1
411
+ for name, attn_processor in transformer.attn_processors.items():
412
+ match = re.search(r'\.(\d+)\.', name)
413
+ if match:
414
+ layer_index = int(match.group(1))
415
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
416
+ lora_state_dicts = {}
417
+ for key, value in checkpoint.items():
418
+ if re.search(r'\.(\d+)\.', key):
419
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
420
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
421
+ lora_state_dicts[key] = value
422
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
423
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
424
+ )
425
+ for n in range(number):
426
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
427
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
428
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
429
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
430
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
431
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
432
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
433
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
434
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
435
+ lora_state_dicts = {}
436
+ for key, value in checkpoint.items():
437
+ if re.search(r'\.(\d+)\.', key):
438
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
439
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
440
+ lora_state_dicts[key] = value
441
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
442
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
443
+ )
444
+ for n in range(number):
445
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
446
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
447
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
448
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
449
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
450
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
451
+ else:
452
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
453
+ else:
454
+ lora_attn_procs = {}
455
+ double_blocks_idx = list(range(19))
456
+ single_blocks_idx = list(range(38))
457
+ for name, attn_processor in transformer.attn_processors.items():
458
+ match = re.search(r'\.(\d+)\.', name)
459
+ if match:
460
+ layer_index = int(match.group(1))
461
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
462
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
463
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
464
+ )
465
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
466
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
467
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
468
+ )
469
+ else:
470
+ lora_attn_procs[name] = attn_processor
471
+
472
+ transformer.set_attn_processor(lora_attn_procs)
473
+ transformer.train()
474
+ for n, param in transformer.named_parameters():
475
+ if '_lora' not in n:
476
+ param.requires_grad = False
477
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
478
+
479
+ def unwrap_model(model):
480
+ model = accelerator.unwrap_model(model)
481
+ model = model._orig_mod if is_compiled_module(model) else model
482
+ return model
483
+
484
+ if args.resume_from_checkpoint:
485
+ path = args.resume_from_checkpoint
486
+ global_step = int(path.split("-")[-1])
487
+ initial_global_step = global_step
488
+ else:
489
+ initial_global_step = 0
490
+ global_step = 0
491
+ first_epoch = 0
492
+
493
+ if args.scale_lr:
494
+ args.learning_rate = (
495
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
496
+ )
497
+
498
+ if args.mixed_precision == "fp16":
499
+ models = [transformer]
500
+ cast_training_params(models, dtype=torch.float32)
501
+
502
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
503
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
504
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
505
+
506
+ optimizer_class = torch.optim.AdamW
507
+ optimizer = optimizer_class(
508
+ [transformer_parameters_with_lr],
509
+ betas=(args.adam_beta1, args.adam_beta2),
510
+ weight_decay=args.adam_weight_decay,
511
+ eps=args.adam_epsilon,
512
+ )
513
+
514
+ tokenizers = [tokenizer_one, tokenizer_two]
515
+ text_encoders = [text_encoder_one, text_encoder_two]
516
+
517
+ train_dataset = make_train_dataset_mixed(args, tokenizers, accelerator)
518
+ train_dataloader = torch.utils.data.DataLoader(
519
+ train_dataset,
520
+ batch_size=args.train_batch_size,
521
+ shuffle=True,
522
+ collate_fn=collate_fn,
523
+ num_workers=args.dataloader_num_workers,
524
+ )
525
+
526
+ vae_config_shift_factor = vae.config.shift_factor
527
+ vae_config_scaling_factor = vae.config.scaling_factor
528
+
529
+ overrode_max_train_steps = False
530
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
531
+ if args.resume_from_checkpoint:
532
+ first_epoch = global_step // num_update_steps_per_epoch
533
+ if args.max_train_steps is None:
534
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
535
+ overrode_max_train_steps = True
536
+
537
+ lr_scheduler = get_scheduler(
538
+ args.lr_scheduler,
539
+ optimizer=optimizer,
540
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
541
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
542
+ num_cycles=args.lr_num_cycles,
543
+ power=args.lr_power,
544
+ )
545
+
546
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
547
+ transformer, optimizer, train_dataloader, lr_scheduler
548
+ )
549
+
550
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
551
+ if overrode_max_train_steps:
552
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
553
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
554
+
555
+ # Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
556
+ def _sanitize_hparams(config_dict):
557
+ sanitized = {}
558
+ for key, value in dict(config_dict).items():
559
+ try:
560
+ if value is None:
561
+ continue
562
+ # numpy scalar types
563
+ if isinstance(value, (np.integer,)):
564
+ sanitized[key] = int(value)
565
+ elif isinstance(value, (np.floating,)):
566
+ sanitized[key] = float(value)
567
+ elif isinstance(value, (int, float, bool, str)):
568
+ sanitized[key] = value
569
+ elif isinstance(value, Path):
570
+ sanitized[key] = str(value)
571
+ elif isinstance(value, (list, tuple)):
572
+ # stringify simple sequences; skip if fails
573
+ sanitized[key] = str(value)
574
+ else:
575
+ # best-effort stringify
576
+ sanitized[key] = str(value)
577
+ except Exception:
578
+ # skip unconvertible entries
579
+ continue
580
+ return sanitized
581
+
582
+ if accelerator.is_main_process:
583
+ tracker_name = "Easy_Control_Kontext"
584
+ accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
585
+
586
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
587
+ logger.info("***** Running training *****")
588
+ logger.info(f" Num examples = {len(train_dataset)}")
589
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
590
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
591
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
592
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
593
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
594
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
595
+
596
+ progress_bar = tqdm(
597
+ range(0, args.max_train_steps),
598
+ initial=initial_global_step,
599
+ desc="Steps",
600
+ disable=not accelerator.is_local_main_process,
601
+ )
602
+
603
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
604
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
605
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
606
+ timesteps = timesteps.to(accelerator.device)
607
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
608
+ sigma = sigmas[step_indices].flatten()
609
+ while len(sigma.shape) < n_dim:
610
+ sigma = sigma.unsqueeze(-1)
611
+ return sigma
612
+
613
+ # Kontext specifics
614
+ vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
615
+ # Match pipeline's prepare_latents cond resolution: 2 * (cond_size // (vae_scale_factor * 2))
616
+ height_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
617
+ width_cond = 2 * (args.cond_size // (vae_scale_factor * 2))
618
+ offset = 64
619
+
620
+ for epoch in range(first_epoch, args.num_train_epochs):
621
+ transformer.train()
622
+ for step, batch in enumerate(train_dataloader):
623
+ models_to_accumulate = [transformer]
624
+ with accelerator.accumulate(models_to_accumulate):
625
+ tokens = [batch["text_ids_1"], batch["text_ids_2"]]
626
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
627
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
628
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
629
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
630
+
631
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
632
+ height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
633
+ width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
634
+
635
+ model_input = vae.encode(pixel_values).latent_dist.sample()
636
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
637
+ model_input = model_input.to(dtype=weight_dtype)
638
+
639
+ latent_image_ids, cond_latent_image_ids = resize_position_encoding(
640
+ model_input.shape[0], height_, width_, height_cond, width_cond, accelerator.device, weight_dtype
641
+ )
642
+
643
+ noise = torch.randn_like(model_input)
644
+ bsz = model_input.shape[0]
645
+
646
+ u = compute_density_for_timestep_sampling(
647
+ weighting_scheme=args.weighting_scheme,
648
+ batch_size=bsz,
649
+ logit_mean=args.logit_mean,
650
+ logit_std=args.logit_std,
651
+ mode_scale=args.mode_scale,
652
+ )
653
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
654
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
655
+
656
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
657
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
658
+
659
+ packed_noisy_model_input = FluxKontextControlPipeline._pack_latents(
660
+ noisy_model_input,
661
+ batch_size=model_input.shape[0],
662
+ num_channels_latents=model_input.shape[1],
663
+ height=model_input.shape[2],
664
+ width=model_input.shape[3],
665
+ )
666
+
667
+ latent_image_ids_to_concat = [latent_image_ids]
668
+ packed_cond_model_input_to_concat = []
669
+
670
+ if args.kontext == "enable":
671
+ source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
672
+ source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
673
+ source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
674
+ image_latent_h, image_latent_w = source_image_latents.shape[2:]
675
+ packed_image_latents = FluxKontextControlPipeline._pack_latents(
676
+ source_image_latents,
677
+ batch_size=source_image_latents.shape[0],
678
+ num_channels_latents=source_image_latents.shape[1],
679
+ height=image_latent_h,
680
+ width=image_latent_w,
681
+ )
682
+ source_image_ids = FluxKontextControlPipeline._prepare_latent_image_ids(
683
+ batch_size=source_image_latents.shape[0],
684
+ height=image_latent_h // 2,
685
+ width=image_latent_w // 2,
686
+ device=accelerator.device,
687
+ dtype=weight_dtype,
688
+ )
689
+ source_image_ids[..., 0] = 1 # Mark as condition
690
+ latent_image_ids_to_concat.append(source_image_ids)
691
+
692
+
693
+ subject_pixel_values = batch.get("subject_pixel_values")
694
+ if subject_pixel_values is not None:
695
+ subject_pixel_values = subject_pixel_values.to(dtype=vae.dtype)
696
+ subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
697
+ subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
698
+ subject_input = subject_input.to(dtype=weight_dtype)
699
+ sub_number = subject_pixel_values.shape[-2] // args.cond_size
700
+ latent_subject_ids = prepare_latent_subject_ids(height_cond // 2, width_cond // 2, accelerator.device, weight_dtype)
701
+ latent_subject_ids[..., 0] = 2
702
+ latent_subject_ids[:, 1] += offset
703
+ sub_latent_image_ids = torch.cat([latent_subject_ids for _ in range(sub_number)], dim=0)
704
+ latent_image_ids_to_concat.append(sub_latent_image_ids)
705
+
706
+ packed_subject_model_input = FluxKontextControlPipeline._pack_latents(
707
+ subject_input,
708
+ batch_size=subject_input.shape[0],
709
+ num_channels_latents=subject_input.shape[1],
710
+ height=subject_input.shape[2],
711
+ width=subject_input.shape[3],
712
+ )
713
+ packed_cond_model_input_to_concat.append(packed_subject_model_input)
714
+
715
+ cond_pixel_values = batch.get("cond_pixel_values")
716
+ if cond_pixel_values is not None:
717
+ cond_pixel_values = cond_pixel_values.to(dtype=vae.dtype)
718
+ cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
719
+ cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
720
+ cond_input = cond_input.to(dtype=weight_dtype)
721
+ cond_number = cond_pixel_values.shape[-2] // args.cond_size
722
+ cond_latent_image_ids[..., 0] = 2
723
+ cond_latent_image_ids_rep = torch.cat([cond_latent_image_ids for _ in range(cond_number)], dim=0)
724
+ latent_image_ids_to_concat.append(cond_latent_image_ids_rep)
725
+
726
+ packed_cond_model_input = FluxKontextControlPipeline._pack_latents(
727
+ cond_input,
728
+ batch_size=cond_input.shape[0],
729
+ num_channels_latents=cond_input.shape[1],
730
+ height=cond_input.shape[2],
731
+ width=cond_input.shape[3],
732
+ )
733
+ packed_cond_model_input_to_concat.append(packed_cond_model_input)
734
+
735
+ latent_image_ids = torch.cat(latent_image_ids_to_concat, dim=0)
736
+ cond_packed_noisy_model_input = torch.cat(packed_cond_model_input_to_concat, dim=1)
737
+
738
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
739
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
740
+ guidance = guidance.expand(model_input.shape[0])
741
+ else:
742
+ guidance = None
743
+
744
+ latent_model_input = packed_noisy_model_input
745
+ if args.kontext == "enable":
746
+ latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
747
+ model_pred = transformer(
748
+ hidden_states=latent_model_input,
749
+ cond_hidden_states=cond_packed_noisy_model_input,
750
+ timestep=timesteps / 1000,
751
+ guidance=guidance,
752
+ pooled_projections=pooled_prompt_embeds,
753
+ encoder_hidden_states=prompt_embeds,
754
+ txt_ids=text_ids,
755
+ img_ids=latent_image_ids,
756
+ return_dict=False,
757
+ )[0]
758
+
759
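+ # The transformer returns predictions for the whole packed token sequence (noisy target
+ # latents plus the source-image latents appended when kontext is enabled); keep only the
+ # leading block that corresponds to the target latents before unpacking for the loss.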
+ model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
760
+
761
+ model_pred = FluxKontextControlPipeline._unpack_latents(
762
+ model_pred,
763
+ height=int(pixel_values.shape[-2]),
764
+ width=int(pixel_values.shape[-1]),
765
+ vae_scale_factor=vae_scale_factor,
766
+ )
767
+
768
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
769
+ target = noise - model_input
770
+
771
+ # mask_values = batch.get("mask_values")
772
+ # if mask_values is not None:
773
+ # mask_values = mask_values.to(device=accelerator.device, dtype=model_pred.dtype)
774
+ # loss_map = compute_background_preserving_loss(
775
+ # model_pred=model_pred,
776
+ # target=target,
777
+ # mask_values=mask_values,
778
+ # weighting=weighting,
779
+ # background_weight=args.background_weight,
780
+ # )
781
+ # loss = torch.mean(loss_map.reshape(target.shape[0], -1), 1)
782
+ # loss = loss.mean()
783
+ # else:
784
+ loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
785
+ loss = loss.mean()
786
+ accelerator.backward(loss)
787
+ if accelerator.sync_gradients:
788
+ params_to_clip = (transformer.parameters())
789
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
790
+
791
+ optimizer.step()
792
+ lr_scheduler.step()
793
+ optimizer.zero_grad()
794
+
795
+ if accelerator.sync_gradients:
796
+ progress_bar.update(1)
797
+ global_step += 1
798
+
799
+ if accelerator.is_main_process:
800
+ if global_step % args.checkpointing_steps == 0:
801
+ if args.checkpoints_total_limit is not None:
802
+ checkpoints = os.listdir(args.output_dir)
803
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
804
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
805
+ if len(checkpoints) >= args.checkpoints_total_limit:
806
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
807
+ removing_checkpoints = checkpoints[0:num_to_remove]
808
+ logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
809
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
810
+ for removing_checkpoint in removing_checkpoints:
811
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
812
+ shutil.rmtree(removing_checkpoint)
813
+
814
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
815
+ os.makedirs(save_path, exist_ok=True)
816
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
817
+ lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
818
+ save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
819
+ logger.info(f"Saved state to {save_path}")
820
+
821
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
822
+ progress_bar.set_postfix(**logs)
823
+ accelerator.log(logs, step=global_step)
824
+
825
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
826
+ pipeline = FluxKontextControlPipeline.from_pretrained(
827
+ args.pretrained_model_name_or_path,
828
+ vae=vae,
829
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
830
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
831
+ transformer=accelerator.unwrap_model(transformer),
832
+ revision=args.revision,
833
+ variant=args.variant,
834
+ torch_dtype=weight_dtype,
835
+ )
836
+
837
+ if args.spatial_test_images is not None and len(args.spatial_test_images) != 0 and args.spatial_test_images != ['None']:
838
+ spatial_paths = args.spatial_test_images
839
+ spatial_ls = [Image.open(image_path).convert("RGB") for image_path in spatial_paths]
840
+ else:
841
+ spatial_ls = []
842
+
843
+ pipeline_args = {
844
+ "prompt": args.validation_prompt,
845
+ "cond_size": args.cond_size,
846
+ "guidance_scale": 3.5,
847
+ "num_inference_steps": 20,
848
+ "max_sequence_length": 128,
849
+ "control_dict": {"spatial_images": spatial_ls},
850
+ }
851
+
852
+ images = log_validation(
853
+ pipeline=pipeline,
854
+ args=args,
855
+ accelerator=accelerator,
856
+ pipeline_args=pipeline_args,
857
+ step=global_step,
858
+ torch_dtype=weight_dtype,
859
+ )
860
+ if accelerator.is_main_process:
861
+ save_path = os.path.join(args.output_dir, "validation")
862
+ os.makedirs(save_path, exist_ok=True)
863
+ save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
864
+ os.makedirs(save_folder, exist_ok=True)
865
+ for idx, img in enumerate(images):
866
+ img.save(os.path.join(save_folder, f"{idx}.jpg"))
867
+ del pipeline
868
+
869
+ accelerator.wait_for_everyone()
870
+ accelerator.end_training()
871
+
872
+
873
+ if __name__ == "__main__":
874
+ args = parse_args()
875
+ main(args)
876
+
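
The trainer above optimizes a flow-matching objective: latents are linearly interpolated toward noise with a sampled sigma (noisy_model_input = (1 - sigma) * latents + sigma * noise) and the network regresses the velocity noise - latents under a per-sigma weighting. A minimal, self-contained sketch of that loss, with illustrative tensor names rather than the script's variables:

import torch

def flow_matching_loss(model_pred, latents, noise, weighting=None):
    # Velocity target: the direction from clean latents toward pure noise.
    target = noise - latents
    per_elem = (model_pred.float() - target.float()) ** 2
    if weighting is not None:
        per_elem = weighting.float() * per_elem
    # Average over all non-batch dimensions, then over the batch.
    return per_elem.reshape(latents.shape[0], -1).mean(dim=1).mean()

x0 = torch.randn(2, 16, 64, 64)        # clean VAE latents
eps = torch.randn_like(x0)             # Gaussian noise
sigma = torch.rand(2, 1, 1, 1)         # sampled noise level
xt = (1.0 - sigma) * x0 + sigma * eps  # what the transformer would actually be fed
print(flow_matching_loss(model_pred=eps - x0, latents=x0, noise=eps))  # tensor(0.) for an ideal prediction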
train/train_kontext_local.sh ADDED
@@ -0,0 +1,26 @@
1
+ export MODEL_DIR="" # your flux path
2
+ export OUTPUT_DIR="" # your save path
3
+ export CONFIG="./default_config.yaml"
4
+ export LOG_PATH="$OUTPUT_DIR/log"
5
+
6
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file $CONFIG train_kontext_local.py \
7
+ --pretrained_model_name_or_path $MODEL_DIR \
8
+ --pretrained_lora_path "" \
9
+ --lora_num=1 \
10
+ --cond_size=512 \
11
+ --ranks 128 \
12
+ --network_alphas 128 \
13
+ --output_dir=$OUTPUT_DIR \
14
+ --logging_dir=$LOG_PATH \
15
+ --mixed_precision="bf16" \
16
+ --learning_rate=1e-4 \
17
+ --train_batch_size=1 \
18
+ --num_train_epochs=1 \
19
+ --validation_steps=250 \
20
+ --checkpointing_steps=1000 \
21
+ --validation_images "./kontext_local_test/img_1.png" \
22
+ --spatial_test_images "./kontext_local_test/mask_1.png" \
23
+ --validation_prompt "convert the dinosaur into blue color" \
24
+ --gradient_checkpointing \
25
+ --blend_pixel_values \
26
+ --num_validation_images=1
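
With --ranks 128 and --network_alphas 128 the LoRA scale alpha/r equals 1.0, so the learned low-rank update is added to each frozen weight at full strength. A rough sketch of that composition under the standard LoRA formulation (the repo's control-LoRA layers may apply it differently in detail):

import torch

rank, alpha = 128, 128              # matches --ranks 128 --network_alphas 128 above
d_out, d_in = 3072, 3072            # an illustrative Flux projection size, not taken from the repo
W = torch.randn(d_out, d_in)        # frozen base weight
A = torch.randn(rank, d_in) * 0.01  # LoRA down-projection
B = torch.zeros(d_out, rank)        # LoRA up-projection, zero-initialized so training starts at W

scale = alpha / rank                # 128 / 128 = 1.0
W_effective = W + scale * (B @ A)   # what the adapted linear layer effectively applies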
train/train_kontext_lora.py ADDED
@@ -0,0 +1,871 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import shutil
7
+ from contextlib import nullcontext
8
+ from pathlib import Path
9
+ import re
10
+ import time
11
+
12
+ from safetensors.torch import save_file
13
+ from PIL import Image
14
+ import numpy as np
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ import transformers
18
+
19
+ from accelerate import Accelerator
20
+ from accelerate.logging import get_logger
21
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
22
+
23
+ import diffusers
24
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline
25
+ from diffusers.optimization import get_scheduler
26
+ from diffusers.training_utils import (
27
+ cast_training_params,
28
+ compute_density_for_timestep_sampling,
29
+ compute_loss_weighting_for_sd3,
30
+ )
31
+ from diffusers.utils.torch_utils import is_compiled_module
32
+ from diffusers.utils import (
33
+ check_min_version,
34
+ is_wandb_available,
35
+ )
36
+
37
+ from src.prompt_helper import *
38
+ from src.lora_helper import *
39
+ from src.jsonl_datasets_kontext_interactive_lora import make_interactive_dataset_subjects, make_placement_dataset_subjects, make_pexels_dataset_subjects, make_mixed_dataset, collate_fn
40
+ from diffusers import FluxKontextPipeline
41
+ from diffusers.models import FluxTransformer2DModel
42
+ from tqdm.auto import tqdm
43
+ from peft import LoraConfig
44
+ from peft.utils import get_peft_model_state_dict
45
+ from diffusers.utils import convert_state_dict_to_diffusers
46
+
47
+ if is_wandb_available():
48
+ import wandb
49
+
50
+
51
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
52
+ check_min_version("0.31.0.dev0")
53
+
54
+ logger = get_logger(__name__)
55
+
56
+
57
+ PREFERRED_KONTEXT_RESOLUTIONS = [
58
+ (672, 1568),
59
+ (688, 1504),
60
+ (720, 1456),
61
+ (752, 1392),
62
+ (832, 1248),
63
+ (880, 1184),
64
+ (944, 1104),
65
+ (1024, 1024),
66
+ (1104, 944),
67
+ (1184, 880),
68
+ (1248, 832),
69
+ (1392, 752),
70
+ (1456, 720),
71
+ (1504, 688),
72
+ (1568, 672),
73
+ ]
74
+
75
+
76
+ def log_validation(
77
+ pipeline,
78
+ args,
79
+ accelerator,
80
+ pipeline_args,
81
+ step,
82
+ torch_dtype,
83
+ is_final_validation=False,
84
+ ):
85
+ logger.info(
86
+ f"Running validation... Paired evaluation for image and prompt."
87
+ )
88
+ pipeline = pipeline.to(device=accelerator.device, dtype=torch_dtype)
89
+ pipeline.set_progress_bar_config(disable=True)
90
+
91
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
92
+ # Match compute dtype for validation to avoid dtype mismatches (e.g., VAE bf16 vs float latents)
93
+ if torch_dtype in (torch.float16, torch.bfloat16):
94
+ device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
95
+ autocast_ctx = torch.autocast(device_type=device_type, dtype=torch_dtype)
96
+ else:
97
+ autocast_ctx = nullcontext()
98
+
99
+ # Build per-case evaluation
100
+ if args.validation_images is None or args.validation_images == ['None']:
101
+ raise ValueError("validation_images must be provided and non-empty")
102
+ if args.validation_prompt is None:
103
+ raise ValueError("validation_prompt must be provided and non-empty")
104
+
105
+ val_imgs = args.validation_images
106
+ prompts = args.validation_prompt
107
+ # Prepend instruction to each prompt (same as dataset/test requirement)
108
+ instruction = "Fill in the white region naturally and adapt the foreground into the background. Fix the perspective of the foreground object if necessary."
109
+ try:
110
+ prompts = [f"{instruction} {p}".strip() if isinstance(p, str) and len(p.strip()) > 0 else instruction for p in prompts]
111
+ except Exception:
112
+ # Fallback: keep original prompts if unexpected
113
+ pass
114
+
115
+ if not (len(val_imgs) == len(prompts)):
116
+ raise ValueError(
117
+ f"Length mismatch: validation_images={len(val_imgs)}, validation_prompt={len(prompts)}"
118
+ )
119
+
120
+ results = []
121
+
122
+ def _resize_to_preferred(img: Image.Image) -> Image.Image:
123
+ w, h = img.size
124
+ aspect_ratio = w / h if h != 0 else 1.0
125
+ _, target_w, target_h = min(
126
+ (abs(aspect_ratio - (pref_w / pref_h)), pref_w, pref_h)
127
+ for (pref_h, pref_w) in PREFERRED_KONTEXT_RESOLUTIONS
128
+ )
129
+ return img.resize((target_w, target_h), Image.BICUBIC)
130
+
131
+ # Distributed per-rank assignment: each process handles its own slice of cases
132
+ num_cases = len(prompts)
133
+ logger.info(f"Paired validation (distributed): {num_cases} cases across {accelerator.num_processes} ranks")
134
+
135
+ # Indices assigned to this rank
136
+ rank = accelerator.process_index
137
+ world_size = accelerator.num_processes
138
+ local_indices = list(range(rank, num_cases, world_size))
139
+
140
+ local_images = []
141
+ with autocast_ctx:
142
+ for idx in local_indices:
143
+ try:
144
+ base_img = Image.open(val_imgs[idx]).convert("RGB")
145
+ resized_img = _resize_to_preferred(base_img)
146
+ except Exception as e:
147
+ raise ValueError(f"Failed to load/resize validation image idx={idx}: {e}")
148
+
149
+ case_args = dict(pipeline_args) if pipeline_args is not None else {}
150
+ case_args.pop("height", None)
151
+ case_args.pop("width", None)
152
+ if resized_img is not None:
153
+ tw, th = resized_img.size
154
+ case_args["height"] = th
155
+ case_args["width"] = tw
156
+
157
+ case_args["prompt"] = prompts[idx]
158
+ img = pipeline(image=resized_img, **case_args, generator=generator).images[0]
159
+ local_images.append(img)
160
+
161
+ # Gather all images per rank (pad to equal count) to main process
162
+ fixed_size = (1024, 1024)
163
+ max_local = int(math.ceil(num_cases / world_size)) if world_size > 0 else len(local_images)
164
+ # Build per-rank batch tensors
165
+ imgs_rank = []
166
+ idx_rank = []
167
+ has_rank = []
168
+ for j in range(max_local):
169
+ if j < len(local_images):
170
+ resized = local_images[j].resize(fixed_size, Image.BICUBIC)
171
+ img_np = np.asarray(resized).astype(np.uint8)
172
+ imgs_rank.append(torch.from_numpy(img_np))
173
+ idx_rank.append(local_indices[j])
174
+ has_rank.append(1)
175
+ else:
176
+ imgs_rank.append(torch.from_numpy(np.zeros((fixed_size[1], fixed_size[0], 3), dtype=np.uint8)))
177
+ idx_rank.append(-1)
178
+ has_rank.append(0)
179
+ imgs_rank_tensor = torch.stack([t.to(device=accelerator.device) for t in imgs_rank], dim=0) # [max_local, H, W, C]
180
+ idx_rank_tensor = torch.tensor(idx_rank, device=accelerator.device, dtype=torch.long) # [max_local]
181
+ has_rank_tensor = torch.tensor(has_rank, device=accelerator.device, dtype=torch.int) # [max_local]
182
+
183
+ gathered_has = accelerator.gather(has_rank_tensor) # [world * max_local]
184
+ gathered_idx = accelerator.gather(idx_rank_tensor) # [world * max_local]
185
+ gathered_imgs = accelerator.gather(imgs_rank_tensor) # [world * max_local, H, W, C]
186
+
187
+ if accelerator.is_main_process:
188
+ world = int(world_size)
189
+ slots = int(max_local)
190
+ try:
191
+ gathered_has = gathered_has.view(world, slots)
192
+ gathered_idx = gathered_idx.view(world, slots)
193
+ gathered_imgs = gathered_imgs.view(world, slots, fixed_size[1], fixed_size[0], 3)
194
+ except Exception:
195
+ # Fallback: treat as flat if reshape fails
196
+ gathered_has = gathered_has.view(-1, 1)
197
+ gathered_idx = gathered_idx.view(-1, 1)
198
+ gathered_imgs = gathered_imgs.view(-1, 1, fixed_size[1], fixed_size[0], 3)
199
+ world = int(gathered_has.shape[0])
200
+ slots = 1
201
+ for i in range(world):
202
+ for j in range(slots):
203
+ if int(gathered_has[i, j].item()) == 1:
204
+ idx = int(gathered_idx[i, j].item())
205
+ arr = gathered_imgs[i, j].cpu().numpy()
206
+ pil_img = Image.fromarray(arr.astype(np.uint8))
207
+ # Resize back to original validation image size
208
+ try:
209
+ orig = Image.open(val_imgs[idx]).convert("RGB")
210
+ pil_img = pil_img.resize(orig.size, Image.BICUBIC)
211
+ except Exception:
212
+ pass
213
+ results.append(pil_img)
214
+
215
+ # Log results (resize to 1024x1024 for saving or external trackers). Skip TensorBoard per request.
216
+ resized_for_log = [img.resize((1024, 1024), Image.BICUBIC) for img in results]
217
+ for tracker in accelerator.trackers:
218
+ phase_name = "test" if is_final_validation else "validation"
219
+ if tracker.name == "tensorboard":
220
+ continue
221
+ if tracker.name == "wandb":
222
+ tracker.log({
223
+ phase_name: [wandb.Image(image, caption=f"{i}: {prompts[i] if i < len(prompts) else ''}") for i, image in enumerate(resized_for_log)]
224
+ })
225
+
226
+ del pipeline
227
+ if torch.cuda.is_available():
228
+ torch.cuda.empty_cache()
229
+
230
+ return results
231
+
232
+
233
+ def save_with_retry(img: Image.Image, path: str, max_retries: int = 3) -> bool:
234
+ """Save PIL image with simple retry and exponential backoff to mitigate transient I/O errors."""
235
+ last_err = None
236
+ for attempt in range(max_retries):
237
+ try:
238
+ os.makedirs(os.path.dirname(path), exist_ok=True)
239
+ img.save(path)
240
+ return True
241
+ except OSError as e:
242
+ last_err = e
243
+ # Exponential backoff: 1.0, 1.5, 2.25 seconds ...
244
+ time.sleep(1.5 ** attempt)
245
+ logger.warning(f"Failed to save {path} after {max_retries} retries: {last_err}")
246
+ return False
247
+
248
+
249
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
250
+ text_encoder_config = transformers.PretrainedConfig.from_pretrained(
251
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
252
+ )
253
+ model_class = text_encoder_config.architectures[0]
254
+ if model_class == "CLIPTextModel":
255
+ from transformers import CLIPTextModel
256
+
257
+ return CLIPTextModel
258
+ elif model_class == "T5EncoderModel":
259
+ from transformers import T5EncoderModel
260
+
261
+ return T5EncoderModel
262
+ else:
263
+ raise ValueError(f"{model_class} is not supported.")
264
+
265
+
266
+ def parse_args(input_args=None):
267
+ parser = argparse.ArgumentParser(description="Training script for Flux Kontext with EasyControl.")
268
+ parser.add_argument("--mode", type=str, default=None, help="Controller mode; kept for compatibility.")
269
+
270
+ # Dataset arguments
271
+ parser.add_argument("--dataset_mode", type=str, default="mixed", choices=["interactive", "placement", "pexels", "mixed"],
272
+ help="Dataset mode: interactive, placement, pexels, or mixed")
273
+ parser.add_argument("--train_data_jsonl", type=str, default="/robby/share/Editing/lzc/HOI_v1/final_metadata.jsonl",
274
+ help="Path to interactive dataset JSONL")
275
+ parser.add_argument("--placement_data_jsonl", type=str, default="/robby/share/Editing/lzc/subject_placement/metadata_relight.jsonl",
276
+ help="Path to placement dataset JSONL")
277
+ parser.add_argument("--pexels_data_jsonl", type=str, default=None,
278
+ help="Path to pexels dataset JSONL")
279
+ parser.add_argument("--interactive_base_dir", type=str, default="/robby/share/Editing/lzc/HOI_v1",
280
+ help="Base directory for interactive dataset")
281
+ parser.add_argument("--placement_base_dir", type=str, default="/robby/share/Editing/lzc/subject_placement",
282
+ help="Base directory for placement dataset")
283
+ parser.add_argument("--pexels_base_dir", type=str, default=None,
284
+ help="Base directory for pexels dataset")
285
+ parser.add_argument("--pexels_relight_base_dir", type=str, default=None,
286
+ help="Base directory for pexels relighted images")
287
+ parser.add_argument("--seg_base_dir", type=str, default=None,
288
+ help="Directory containing segmentation maps for pexels dataset")
289
+ parser.add_argument("--interactive_weight", type=float, default=1.0,
290
+ help="Sampling weight for interactive dataset (default: 1.0)")
291
+ parser.add_argument("--placement_weight", type=float, default=1.0,
292
+ help="Sampling weight for placement dataset (default: 1.0)")
293
+ parser.add_argument("--pexels_weight", type=float, default=0.1,
294
+ help="Sampling weight for pexels dataset (default: 1.0)")
295
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="", required=False, help="Base model path")
296
+ parser.add_argument("--pretrained_lora_path", type=str, default=None, required=False, help="LoRA checkpoint to initialize from")
297
+ parser.add_argument("--revision", type=str, default=None, required=False, help="Revision of pretrained model")
298
+ parser.add_argument("--variant", type=str, default=None, help="Variant of the model files")
299
+
300
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
301
+ parser.add_argument("--max_sequence_length", type=int, default=128, help="Max sequence length for T5")
302
+ parser.add_argument("--kontext", type=str, default="enable")
303
+ parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
304
+ parser.add_argument("--validation_images", type=str, nargs="+", default=None, help="List of valiadation images")
305
+ parser.add_argument("--num_validation_images", type=int, default=4)
306
+ parser.add_argument("--validation_steps", type=int, default=20)
307
+
308
+ parser.add_argument("--ranks", type=int, nargs="+", default=[32], help="LoRA ranks")
309
+ parser.add_argument("--output_dir", type=str, default="", help="Output directory")
310
+ parser.add_argument("--seed", type=int, default=None)
311
+ parser.add_argument("--train_batch_size", type=int, default=1)
312
+ parser.add_argument("--num_train_epochs", type=int, default=50)
313
+ parser.add_argument("--max_train_steps", type=int, default=None)
314
+ parser.add_argument("--checkpointing_steps", type=int, default=1000)
315
+ parser.add_argument("--checkpoints_total_limit", type=int, default=None)
316
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
317
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
318
+ parser.add_argument("--gradient_checkpointing", action="store_true")
319
+ parser.add_argument("--learning_rate", type=float, default=1e-4)
320
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Flux Kontext is guidance distilled")
321
+ parser.add_argument("--scale_lr", action="store_true", default=False)
322
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
323
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
324
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
325
+ parser.add_argument("--lr_power", type=float, default=1.0)
326
+ parser.add_argument("--dataloader_num_workers", type=int, default=8)
327
+ parser.add_argument("--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"])
328
+ parser.add_argument("--logit_mean", type=float, default=0.0)
329
+ parser.add_argument("--logit_std", type=float, default=1.0)
330
+ parser.add_argument("--mode_scale", type=float, default=1.29)
331
+ parser.add_argument("--optimizer", type=str, default="AdamW")
332
+ parser.add_argument("--use_8bit_adam", action="store_true")
333
+ parser.add_argument("--adam_beta1", type=float, default=0.9)
334
+ parser.add_argument("--adam_beta2", type=float, default=0.999)
335
+ parser.add_argument("--prodigy_beta3", type=float, default=None)
336
+ parser.add_argument("--prodigy_decouple", type=bool, default=True)
337
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04)
338
+ parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03)
339
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08)
340
+ parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True)
341
+ parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True)
342
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
343
+ parser.add_argument("--logging_dir", type=str, default="logs")
344
+ parser.add_argument("--cache_latents", action="store_true", default=False)
345
+ parser.add_argument("--report_to", type=str, default="tensorboard")
346
+ parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
347
+ parser.add_argument("--upcast_before_saving", action="store_true", default=False)
348
+
349
+ # Blending options for dataset pixel_values
350
+ parser.add_argument("--blend_pixel_values", action="store_true", help="Blend target/source into pixel_values using mask")
351
+ parser.add_argument("--blend_kernel", type=int, default=21, help="Gaussian blur kernel size (must be odd)")
352
+ parser.add_argument("--blend_sigma", type=float, default=10.0, help="Gaussian blur sigma")
353
+
354
+ if input_args is not None:
355
+ args = parser.parse_args(input_args)
356
+ else:
357
+ args = parser.parse_args()
358
+ return args
359
+
360
+
361
+ def main(args):
362
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
363
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
364
+
365
+ if args.output_dir is not None:
366
+ os.makedirs(args.output_dir, exist_ok=True)
367
+ os.makedirs(args.logging_dir, exist_ok=True)
368
+ logging_dir = Path(args.output_dir, args.logging_dir)
369
+
370
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
371
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
372
+ accelerator = Accelerator(
373
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
374
+ mixed_precision=args.mixed_precision,
375
+ log_with=args.report_to,
376
+ project_config=accelerator_project_config,
377
+ kwargs_handlers=[kwargs],
378
+ )
379
+
380
+ if torch.backends.mps.is_available():
381
+ accelerator.native_amp = False
382
+
383
+ if args.report_to == "wandb":
384
+ if not is_wandb_available():
385
+ raise ImportError("Install wandb for logging during training.")
386
+
387
+ logging.basicConfig(
388
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
389
+ datefmt="%m/%d/%Y %H:%M:%S",
390
+ level=logging.INFO,
391
+ )
392
+ logger.info(accelerator.state, main_process_only=False)
393
+ if accelerator.is_local_main_process:
394
+ transformers.utils.logging.set_verbosity_warning()
395
+ diffusers.utils.logging.set_verbosity_info()
396
+ else:
397
+ transformers.utils.logging.set_verbosity_error()
398
+ diffusers.utils.logging.set_verbosity_error()
399
+
400
+ if args.seed is not None:
401
+ set_seed(args.seed)
402
+
403
+ if accelerator.is_main_process and args.output_dir is not None:
404
+ os.makedirs(args.output_dir, exist_ok=True)
405
+
406
+ # Tokenizers
407
+ tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
408
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
409
+ )
410
+ tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
411
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision
412
+ )
413
+
414
+ # Text encoders
415
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder")
416
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2")
417
+
418
+ # Scheduler and models
419
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
420
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
421
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
422
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
423
+ transformer = FluxTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant)
424
+
425
+ # Train only LoRA adapters: freeze base transformer/text encoders/vae
426
+ transformer.requires_grad_(False)
427
+ vae.requires_grad_(False)
428
+ text_encoder_one.requires_grad_(False)
429
+ text_encoder_two.requires_grad_(False)
430
+
431
+ weight_dtype = torch.float32
432
+ if accelerator.mixed_precision == "fp16":
433
+ weight_dtype = torch.float16
434
+ elif accelerator.mixed_precision == "bf16":
435
+ weight_dtype = torch.bfloat16
436
+
437
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
438
+ raise ValueError("Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 or fp32 instead.")
439
+
440
+ vae.to(accelerator.device, dtype=weight_dtype)
441
+ transformer.to(accelerator.device, dtype=weight_dtype)
442
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
443
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
444
+
445
+ if args.gradient_checkpointing:
446
+ transformer.enable_gradient_checkpointing()
447
+
448
+ # Setup standard PEFT LoRA on FluxTransformer2DModel
449
+ # target_modules = [
450
+ # "attn.to_k",
451
+ # "attn.to_q",
452
+ # "attn.to_v",
453
+ # "attn.to_out.0",
454
+ # "attn.add_k_proj",
455
+ # "attn.add_q_proj",
456
+ # "attn.add_v_proj",
457
+ # "attn.to_add_out",
458
+ # "ff.net.0.proj",
459
+ # "ff.net.2",
460
+ # "ff_context.net.0.proj",
461
+ # "ff_context.net.2",
462
+ # ]
463
+ target_modules = [
464
+ "attn.to_k",
465
+ "attn.to_q",
466
+ "attn.to_v",
467
+ "attn.to_out.0",
468
+ "attn.add_k_proj",
469
+ "attn.add_q_proj",
470
+ "attn.add_v_proj",
471
+ "attn.to_add_out",
472
+ "ff.net.0.proj",
473
+ "ff.net.2",
474
+ "ff_context.net.0.proj",
475
+ "ff_context.net.2",
476
+ # ===========================================================
477
+ # [Addition 1]: layers specific to the single-stream blocks (single_transformer_blocks)
478
+ # ===========================================================
479
+ # Note: the attention layers of the single-stream blocks (to_q, to_k, to_v) are already covered by the generic names above.
480
+ # What is added here are their dedicated MLP and output layers.
481
+ "proj_mlp",
482
+ "proj_out", # 这个名称也会匹配单流块各自的输出层和模型总输出层
483
+
484
+ # ===========================================================
485
+ # [Addition 2]: all normalization (Norm) layers
486
+ # ===========================================================
487
+ # Note: these layers adjust the feature distribution and matter a lot for style learning.
488
+ # 使用 "linear" 可以一次性匹配所有以 ".linear" 结尾的Norm层。
489
+ "linear", # 匹配 norm1.linear, norm1_context.linear, norm.linear, norm_out.linear
490
+ ]
491
+ lora_rank = int(args.ranks[0]) if isinstance(args.ranks, list) and len(args.ranks) > 0 else 256
492
+ lora_config = LoraConfig(
493
+ r=lora_rank,
494
+ lora_alpha=lora_rank,
495
+ init_lora_weights="gaussian",
496
+ target_modules=target_modules,
497
+ )
498
+ transformer.add_adapter(lora_config)
499
+ transformer.train()
500
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
501
+
502
+ def unwrap_model(model):
503
+ model = accelerator.unwrap_model(model)
504
+ model = model._orig_mod if is_compiled_module(model) else model
505
+ return model
506
+
507
+ if args.resume_from_checkpoint:
508
+ path = args.resume_from_checkpoint
509
+ global_step = int(path.split("-")[-1])
510
+ initial_global_step = global_step
511
+ else:
512
+ initial_global_step = 0
513
+ global_step = 0
514
+ first_epoch = 0
515
+
516
+ if args.scale_lr:
517
+ args.learning_rate = (
518
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
519
+ )
520
+
521
+ if args.mixed_precision == "fp16":
522
+ models = [transformer]
523
+ cast_training_params(models, dtype=torch.float32)
524
+
525
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
526
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
527
+ # print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'parameters')
528
+
529
+ optimizer_class = torch.optim.AdamW
530
+ optimizer = optimizer_class(
531
+ [transformer_parameters_with_lr],
532
+ betas=(args.adam_beta1, args.adam_beta2),
533
+ weight_decay=args.adam_weight_decay,
534
+ eps=args.adam_epsilon,
535
+ )
536
+
537
+ tokenizers = [tokenizer_one, tokenizer_two]
538
+ text_encoders = [text_encoder_one, text_encoder_two]
539
+
540
+ # Create dataset based on mode
541
+ if args.dataset_mode == "mixed":
542
+ # Mixed mode: combine all available datasets
543
+ train_dataset = make_mixed_dataset(
544
+ args,
545
+ tokenizers,
546
+ interactive_jsonl_path=args.train_data_jsonl,
547
+ placement_jsonl_path=args.placement_data_jsonl,
548
+ pexels_jsonl_path=args.pexels_data_jsonl,
549
+ interactive_base_dir=args.interactive_base_dir,
550
+ placement_base_dir=args.placement_base_dir,
551
+ pexels_base_dir=args.pexels_base_dir,
552
+ interactive_weight=args.interactive_weight,
553
+ placement_weight=args.placement_weight,
554
+ pexels_weight=args.pexels_weight,
555
+ accelerator=accelerator
556
+ )
557
+ weights_str = []
558
+ if args.train_data_jsonl:
559
+ weights_str.append(f"Interactive: {args.interactive_weight:.2f}")
560
+ if args.placement_data_jsonl:
561
+ weights_str.append(f"Placement: {args.placement_weight:.2f}")
562
+ if args.pexels_data_jsonl:
563
+ weights_str.append(f"Pexels: {args.pexels_weight:.2f}")
564
+ logger.info(f"Mixed dataset created with weights - {', '.join(weights_str)}")
565
+ elif args.dataset_mode == "pexels":
566
+ if not args.pexels_data_jsonl:
567
+ raise ValueError("pexels_data_jsonl must be provided for pexels mode")
568
+ train_dataset = make_pexels_dataset_subjects(args, tokenizers, accelerator)
569
+ elif args.dataset_mode == "placement":
570
+ if not args.placement_data_jsonl:
571
+ raise ValueError("placement_data_jsonl must be provided for placement mode")
572
+ train_dataset = make_placement_dataset_subjects(args, tokenizers, accelerator)
573
+ else: # interactive mode
574
+ train_dataset = make_interactive_dataset_subjects(args, tokenizers, accelerator)
575
+
576
+ train_dataloader = torch.utils.data.DataLoader(
577
+ train_dataset,
578
+ batch_size=args.train_batch_size,
579
+ shuffle=True,
580
+ collate_fn=collate_fn,
581
+ num_workers=args.dataloader_num_workers,
582
+ )
583
+
584
+ vae_config_shift_factor = vae.config.shift_factor
585
+ vae_config_scaling_factor = vae.config.scaling_factor
586
+
587
+ overrode_max_train_steps = False
588
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
589
+ if args.resume_from_checkpoint:
590
+ first_epoch = global_step // num_update_steps_per_epoch
591
+ if args.max_train_steps is None:
592
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
593
+ overrode_max_train_steps = True
594
+
595
+ lr_scheduler = get_scheduler(
596
+ args.lr_scheduler,
597
+ optimizer=optimizer,
598
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
599
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
600
+ num_cycles=args.lr_num_cycles,
601
+ power=args.lr_power,
602
+ )
603
+
604
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
605
+ transformer, optimizer, train_dataloader, lr_scheduler
606
+ )
607
+
608
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
609
+ if overrode_max_train_steps:
610
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
611
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
612
+
613
+ # Sanitize config for TensorBoard hparams (only allow int/float/bool/str/tensor). Others are stringified if possible; otherwise dropped
614
+ def _sanitize_hparams(config_dict):
615
+ sanitized = {}
616
+ for key, value in dict(config_dict).items():
617
+ try:
618
+ if value is None:
619
+ continue
620
+ # numpy scalar types
621
+ if isinstance(value, (np.integer,)):
622
+ sanitized[key] = int(value)
623
+ elif isinstance(value, (np.floating,)):
624
+ sanitized[key] = float(value)
625
+ elif isinstance(value, (int, float, bool, str)):
626
+ sanitized[key] = value
627
+ elif isinstance(value, Path):
628
+ sanitized[key] = str(value)
629
+ elif isinstance(value, (list, tuple)):
630
+ # stringify simple sequences; skip if fails
631
+ sanitized[key] = str(value)
632
+ else:
633
+ # best-effort stringify
634
+ sanitized[key] = str(value)
635
+ except Exception:
636
+ # skip unconvertible entries
637
+ continue
638
+ return sanitized
639
+
640
+ if accelerator.is_main_process:
641
+ tracker_name = "Easy_Control_Kontext"
642
+ accelerator.init_trackers(tracker_name, config=_sanitize_hparams(vars(args)))
643
+
644
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
645
+ logger.info("***** Running training *****")
646
+ logger.info(f" Num examples = {len(train_dataset)}")
647
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
648
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
649
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
650
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
651
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
652
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
653
+
654
+ progress_bar = tqdm(
655
+ range(0, args.max_train_steps),
656
+ initial=initial_global_step,
657
+ desc="Steps",
658
+ disable=not accelerator.is_local_main_process,
659
+ )
660
+
661
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
662
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
663
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
664
+ timesteps = timesteps.to(accelerator.device)
665
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
666
+ sigma = sigmas[step_indices].flatten()
667
+ while len(sigma.shape) < n_dim:
668
+ sigma = sigma.unsqueeze(-1)
669
+ return sigma
670
+
671
+ # Kontext specifics
672
+ vae_scale_factor = 8 # Kontext uses 8x VAE factor; pack/unpack uses additional 2x in methods
673
+
674
+ for epoch in range(first_epoch, args.num_train_epochs):
675
+ transformer.train()
676
+ for step, batch in enumerate(train_dataloader):
677
+ models_to_accumulate = [transformer]
678
+ with accelerator.accumulate(models_to_accumulate):
679
+ tokens = [batch["text_ids_1"], batch["text_ids_2"]]
680
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
681
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
682
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
683
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
684
+
685
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
686
+ height_ = 2 * (int(pixel_values.shape[-2]) // (vae_scale_factor * 2))
687
+ width_ = 2 * (int(pixel_values.shape[-1]) // (vae_scale_factor * 2))
688
+
689
+ model_input = vae.encode(pixel_values).latent_dist.sample()
690
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
691
+ model_input = model_input.to(dtype=weight_dtype)
692
+
693
+ # Prepare latent ids for transformer (positional encodings)
694
+ latent_image_ids = FluxKontextPipeline._prepare_latent_image_ids(
695
+ batch_size=model_input.shape[0],
696
+ height=model_input.shape[2] // 2,
697
+ width=model_input.shape[3] // 2,
698
+ device=accelerator.device,
699
+ dtype=weight_dtype,
700
+ )
701
+
702
+ noise = torch.randn_like(model_input)
703
+ bsz = model_input.shape[0]
704
+
705
+ u = compute_density_for_timestep_sampling(
706
+ weighting_scheme=args.weighting_scheme,
707
+ batch_size=bsz,
708
+ logit_mean=args.logit_mean,
709
+ logit_std=args.logit_std,
710
+ mode_scale=args.mode_scale,
711
+ )
712
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
713
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
714
+
715
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
716
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
717
+
718
+ packed_noisy_model_input = FluxKontextPipeline._pack_latents(
719
+ noisy_model_input,
720
+ batch_size=model_input.shape[0],
721
+ num_channels_latents=model_input.shape[1],
722
+ height=model_input.shape[2],
723
+ width=model_input.shape[3],
724
+ )
725
+
726
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
727
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
728
+ guidance = guidance.expand(model_input.shape[0])
729
+ else:
730
+ guidance = None
731
+
732
+ # If kontext editing is enabled, append source image latents to the sequence
733
+ latent_model_input = packed_noisy_model_input
734
+ if args.kontext == "enable":
735
+ source_pixel_values = batch["source_pixel_values"].to(dtype=vae.dtype)
736
+ source_image_latents = vae.encode(source_pixel_values).latent_dist.sample()
737
+ source_image_latents = (source_image_latents - vae_config_shift_factor) * vae_config_scaling_factor
738
+ image_latent_h, image_latent_w = source_image_latents.shape[2:]
739
+ packed_image_latents = FluxKontextPipeline._pack_latents(
740
+ source_image_latents,
741
+ batch_size=source_image_latents.shape[0],
742
+ num_channels_latents=source_image_latents.shape[1],
743
+ height=image_latent_h,
744
+ width=image_latent_w,
745
+ )
746
+ source_image_ids = FluxKontextPipeline._prepare_latent_image_ids(
747
+ batch_size=source_image_latents.shape[0],
748
+ height=image_latent_h // 2,
749
+ width=image_latent_w // 2,
750
+ device=accelerator.device,
751
+ dtype=weight_dtype,
752
+ )
753
+ source_image_ids[..., 0] = 1
754
+ latent_model_input = torch.cat([latent_model_input, packed_image_latents], dim=1)
755
+ latent_image_ids = torch.cat([latent_image_ids, source_image_ids], dim=0)
756
+
757
+ # Forward transformer with packed latents and ids
758
+ model_pred = transformer(
759
+ hidden_states=latent_model_input,
760
+ timestep=timesteps / 1000,
761
+ guidance=guidance,
762
+ pooled_projections=pooled_prompt_embeds,
763
+ encoder_hidden_states=prompt_embeds,
764
+ txt_ids=text_ids,
765
+ img_ids=latent_image_ids,
766
+ return_dict=False,
767
+ )[0]
768
+
769
+ model_pred = model_pred[:, : packed_noisy_model_input.size(1)]
770
+
771
+ model_pred = FluxKontextPipeline._unpack_latents(
772
+ model_pred,
773
+ height=int(pixel_values.shape[-2]),
774
+ width=int(pixel_values.shape[-1]),
775
+ vae_scale_factor=vae_scale_factor,
776
+ )
777
+
778
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
779
+ target = noise - model_input
780
+
781
+ loss = torch.mean((weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1)
782
+ loss = loss.mean()
783
+ accelerator.backward(loss)
784
+ if accelerator.sync_gradients:
785
+ params_to_clip = (transformer.parameters())
786
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
787
+
788
+ optimizer.step()
789
+ lr_scheduler.step()
790
+ optimizer.zero_grad()
791
+
792
+ if accelerator.sync_gradients:
793
+ progress_bar.update(1)
794
+ global_step += 1
795
+
796
+ if accelerator.is_main_process:
797
+ if global_step % args.checkpointing_steps == 0:
798
+ if args.checkpoints_total_limit is not None:
799
+ checkpoints = os.listdir(args.output_dir)
800
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
801
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
802
+ if len(checkpoints) >= args.checkpoints_total_limit:
803
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
804
+ removing_checkpoints = checkpoints[0:num_to_remove]
805
+ logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
806
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
807
+ for removing_checkpoint in removing_checkpoints:
808
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
809
+ shutil.rmtree(removing_checkpoint)
810
+
811
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
812
+ os.makedirs(save_path, exist_ok=True)
813
+ unwrapped = accelerator.unwrap_model(transformer)
814
+ peft_state = get_peft_model_state_dict(unwrapped)
815
+ # Convert PEFT state dict to diffusers LoRA format for transformer
816
+ diffusers_lora = convert_state_dict_to_diffusers(peft_state)
817
+ save_file(diffusers_lora, os.path.join(save_path, "pytorch_lora_weights.safetensors"))
818
+ logger.info(f"Saved state to {save_path}")
819
+
820
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
821
+ progress_bar.set_postfix(**logs)
822
+ accelerator.log(logs, step=global_step)
823
+
824
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
825
+ # Create pipeline on every rank to run validation in parallel
826
+ pipeline = FluxKontextPipeline.from_pretrained(
827
+ args.pretrained_model_name_or_path,
828
+ vae=vae,
829
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
830
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
831
+ transformer=accelerator.unwrap_model(transformer),
832
+ revision=args.revision,
833
+ variant=args.variant,
834
+ torch_dtype=weight_dtype,
835
+ )
836
+
837
+ pipeline_args = {
838
+ "prompt": args.validation_prompt,
839
+ "guidance_scale": 3.5,
840
+ "num_inference_steps": 20,
841
+ "max_sequence_length": 128,
842
+ }
843
+
844
+ images = log_validation(
845
+ pipeline=pipeline,
846
+ args=args,
847
+ accelerator=accelerator,
848
+ pipeline_args=pipeline_args,
849
+ step=global_step,
850
+ torch_dtype=weight_dtype,
851
+ )
852
+
853
+ # Only main process saves/logs
854
+ if accelerator.is_main_process:
855
+ save_path = os.path.join(args.output_dir, "validation")
856
+ os.makedirs(save_path, exist_ok=True)
857
+ save_folder = os.path.join(save_path, f"checkpoint-{global_step}")
858
+ os.makedirs(save_folder, exist_ok=True)
859
+ for idx, img in enumerate(images):
860
+ out_path = os.path.join(save_folder, f"{idx}.jpg")
861
+ save_with_retry(img, out_path)
862
+ del pipeline
863
+
864
+ accelerator.wait_for_everyone()
865
+ accelerator.end_training()
866
+
867
+
868
+ if __name__ == "__main__":
869
+ args = parse_args()
870
+ main(args)
871
+
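
The checkpoints written above use the diffusers LoRA layout (pytorch_lora_weights.safetensors produced via convert_state_dict_to_diffusers), so they should be loadable for inference through the standard diffusers LoRA loader. A hedged sketch, assuming FluxKontextPipeline exposes load_lora_weights and using placeholder paths:

import torch
from PIL import Image
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",   # or the local base-model path used for training
    torch_dtype=torch.bfloat16,
).to("cuda")

# Directory containing pytorch_lora_weights.safetensors saved by the training loop above.
pipe.load_lora_weights("output/checkpoint-1000", adapter_name="kontext_lora")

source = Image.open("input.png").convert("RGB")  # placeholder input image
result = pipe(
    image=source,
    prompt="convert the dinosaur into blue color",
    guidance_scale=3.5,
    num_inference_steps=20,
).images[0]
result.save("edited.png")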
util.py ADDED
@@ -0,0 +1,188 @@
1
+ import random
2
+ from collections import Counter
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ import cv2 # OpenCV
6
+ import torch
7
+ import re
8
+ import io
9
+ import base64
10
+ from PIL import Image, ImageOps
11
+ from src.pipeline_flux_kontext_control import PREFERRED_KONTEXT_RESOLUTIONS
12
+
13
+ def get_bounding_box_from_mask(mask, padded=False):
14
+ mask = mask.squeeze()
15
+ rows, cols = torch.where(mask > 0.5)
16
+ if len(rows) == 0 or len(cols) == 0:
17
+ return (0, 0, 0, 0)
18
+ height, width = mask.shape
19
+ if padded:
20
+ padded_size = max(width, height)
21
+ if width < height:
22
+ offset_x = (padded_size - width) / 2
23
+ offset_y = 0
24
+ else:
25
+ offset_y = (padded_size - height) / 2
26
+ offset_x = 0
27
+ top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
28
+ bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
29
+ top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
30
+ bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
31
+ else:
32
+ offset_x = 0
33
+ offset_y = 0
34
+
35
+ top_left_x = round(float(torch.min(cols).item() / width), 3)
36
+ bottom_right_x = round(float(torch.max(cols).item() / width), 3)
37
+ top_left_y = round(float(torch.min(rows).item() / height), 3)
38
+ bottom_right_y = round(float(torch.max(rows).item() / height), 3)
39
+
40
+
41
+ return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
42
+
43
+ def extract_bbox(text):
44
+ pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
45
+ match = re.search(pattern, text)
46
+ return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))
47
+
48
+ def resize_bbox(bbox, width_ratio, height_ratio):
49
+ x1, y1, x2, y2 = bbox
50
+ new_x1 = int(x1 * width_ratio)
51
+ new_y1 = int(y1 * height_ratio)
52
+ new_x2 = int(x2 * width_ratio)
53
+ new_y2 = int(y2 * height_ratio)
54
+
55
+ return (new_x1, new_y1, new_x2, new_y2)
56
+
57
+
58
+ def tensor_to_base64(tensor, quality=80, method=6):
59
+ tensor = tensor.squeeze(0).clone().detach().cpu()
60
+
61
+ if tensor.dtype == torch.float32 or tensor.dtype == torch.float64 or tensor.dtype == torch.float16:
62
+ tensor *= 255
63
+ tensor = tensor.to(torch.uint8)
64
+
65
+ if tensor.ndim == 2: # grayscale image
66
+ pil_image = Image.fromarray(tensor.numpy(), 'L')
67
+ pil_image = pil_image.convert('RGB')
68
+ elif tensor.ndim == 3:
69
+ if tensor.shape[2] == 1: # single channel
70
+ pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
71
+ pil_image = pil_image.convert('RGB')
72
+ elif tensor.shape[2] == 3: # RGB
73
+ pil_image = Image.fromarray(tensor.numpy(), 'RGB')
74
+ elif tensor.shape[2] == 4: # RGBA
75
+ pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
76
+ else:
77
+ raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
78
+ else:
79
+ raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
80
+
81
+ buffered = io.BytesIO()
82
+ pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
83
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
84
+ return img_str
85
+
86
+ def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
87
+ image = Image.open(image_path)
88
+ image = ImageOps.exif_transpose(image)
89
+
90
+ if image.mode == 'RGBA':
91
+ background = Image.new('RGBA', image.size, (255, 255, 255, 255))
92
+ image = Image.alpha_composite(background, image)
93
+ image = image.convert(convert_to)
94
+ image_array = np.array(image).astype(np.float32) / 255.0
95
+
96
+ if has_alpha and convert_to == 'RGBA':
97
+ image_tensor = torch.from_numpy(image_array)[None,]
98
+ else:
99
+ if len(image_array.shape) == 3 and image_array.shape[2] > 3:
100
+ image_array = image_array[:, :, :3]
101
+ image_tensor = torch.from_numpy(image_array)[None,]
102
+
103
+ return image_tensor
104
+
105
+ def process_background(base64_image, convert_to='RGB', size=None):
106
+ image_data = read_base64_image(base64_image)
107
+ image = Image.open(image_data)
108
+ image = ImageOps.exif_transpose(image)
109
+ image = image.convert(convert_to)
110
+
111
+ # Select preferred size by closest aspect ratio, then snap to multiple_of
112
+ w0, h0 = image.size
113
+ aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
114
+ # Choose the (w, h) whose aspect ratio is closest to the input
115
+ _, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
116
+ multiple_of = 16 # default: vae_scale_factor (8) * 2
117
+ tw = (tw // multiple_of) * multiple_of
118
+ th = (th // multiple_of) * multiple_of
119
+
120
+ if (w0, h0) != (tw, th):
121
+ image = image.resize((tw, th), resample=Image.BICUBIC)
122
+
123
+ image_array = np.array(image).astype(np.uint8)
124
+ image_tensor = torch.from_numpy(image_array)[None,]
125
+ return image_tensor
126
+
127
+ def read_base64_image(base64_image):
128
+ if base64_image.startswith("data:image/png;base64,"):
129
+ base64_image = base64_image.split(",")[1]
130
+ elif base64_image.startswith("data:image/jpeg;base64,"):
131
+ base64_image = base64_image.split(",")[1]
132
+ elif base64_image.startswith("data:image/webp;base64,"):
133
+ base64_image = base64_image.split(",")[1]
134
+ else:
135
+ raise ValueError("Unsupported image format.")
136
+ image_data = base64.b64decode(base64_image)
137
+ return io.BytesIO(image_data)
138
+
139
+ def create_alpha_mask(image_path):
+     """Create an alpha mask from the alpha channel of an image."""
+     image = Image.open(image_path)
+     image = ImageOps.exif_transpose(image)
+     mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
+     if 'A' in image.getbands():
+         alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
+         mask[0] = 1.0 - torch.from_numpy(alpha_channel)
+     return mask
+ 
+ def get_mask_bbox(mask_tensor, padding=10):
+     assert len(mask_tensor.shape) == 3 and mask_tensor.shape[0] == 1
+     _, H, W = mask_tensor.shape
+     mask_2d = mask_tensor.squeeze(0)
+ 
+     y_coords, x_coords = torch.where(mask_2d > 0)
+ 
+     if len(y_coords) == 0:
+         return None
+ 
+     x_min = int(torch.min(x_coords))
+     y_min = int(torch.min(y_coords))
+     x_max = int(torch.max(x_coords))
+     y_max = int(torch.max(y_coords))
+ 
+     x_min = max(0, x_min - padding)
+     y_min = max(0, y_min - padding)
+     x_max = min(W - 1, x_max + padding)
+     y_max = min(H - 1, y_max + padding)
+ 
+     return x_min, y_min, x_max, y_max
+ 
+ def tensor_to_pil(tensor):
+     tensor = tensor.squeeze(0).clone().detach().cpu()
+     if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
+         if tensor.max() <= 1.0:
+             tensor *= 255
+         tensor = tensor.to(torch.uint8)
+ 
+     if tensor.ndim == 2:  # grayscale image [H, W]
+         return Image.fromarray(tensor.numpy(), 'L')
+     elif tensor.ndim == 3:
+         if tensor.shape[2] == 1:  # single channel [H, W, 1]
+             return Image.fromarray(tensor.numpy().squeeze(2), 'L')
+         elif tensor.shape[2] >= 3:  # RGB [H, W, 3]; any extra channels (e.g. alpha) are dropped
+             return Image.fromarray(tensor.numpy()[:, :, :3], 'RGB')
+         else:
+             raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
+     else:
+         raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
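For orientation, here is a minimal sketch of how the util.py helpers above can be chained; the file path "input.png", the output path, and the import from util are illustrative assumptions, not part of the commit:

    # Illustrative sketch only; assumes the repo root is on PYTHONPATH and an RGBA test image exists.
    from util import load_and_preprocess_image, create_alpha_mask, get_mask_bbox, tensor_to_pil

    image_tensor = load_and_preprocess_image("input.png")   # [1, H, W, 3] float32 in [0, 1]
    mask = create_alpha_mask("input.png")                    # [1, H, W], 1.0 where the pixel was transparent
    bbox = get_mask_bbox(mask, padding=10)                   # (x_min, y_min, x_max, y_max), or None if the mask is empty
    if bbox is not None:
        x_min, y_min, x_max, y_max = bbox
        crop = image_tensor[:, y_min:y_max + 1, x_min:x_max + 1, :]
        tensor_to_pil(crop).save("crop.png")                 # round-trip the cropped tensor back to a PIL image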
utils_node.py ADDED
@@ -0,0 +1,199 @@
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ from tqdm import trange
+ import torchvision.transforms as T
+ from typing import Tuple
+ import scipy.ndimage
+ import cv2
+ from train.src.condition.util import HWC3, common_input_validate
+ 
+ def check_image_mask(image, mask, name):
+     if len(image.shape) < 4:
+         # image tensor shape should be [B, H, W, C]; the batch dimension is missing, so add it
+         image = image[None, :, :, :]
+ 
+     if len(mask.shape) > 3:
+         # mask tensor shape should be [B, H, W] but we got [B, H, W, C] (probably an image);
+         # keep only the red channel of each mask
+         mask = mask[:, :, :, 0]
+     elif len(mask.shape) < 3:
+         # mask tensor shape should be [B, H, W]; the batch dimension is missing, so add it
+         mask = mask[None, :, :]
+ 
+     if image.shape[0] > mask.shape[0]:
+         print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
+         if mask.shape[0] == 1:
+             print(name, "will copy the mask to fill batch")
+             mask = torch.cat([mask] * image.shape[0], dim=0)
+         else:
+             print(name, "will add empty masks to fill batch")
+             empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
+             mask = torch.cat([mask, empty_mask], dim=0)
+     elif image.shape[0] < mask.shape[0]:
+         print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
+         mask = mask[:image.shape[0], :, :]
+ 
+     return (image, mask)
+ 
+ 
+ def cv2_resize_shortest_edge(image, size):
+     h, w = image.shape[:2]
+     if h < w:
+         new_h = size
+         new_w = int(round(w / h * size))
+     else:
+         new_w = size
+         new_h = int(round(h / w * size))
+     resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
+     return resized_image
+ 
+ def apply_color(img, res=512):
+     img = cv2_resize_shortest_edge(img, res)
+     h, w = img.shape[:2]
+ 
+     input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
+     input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
+     return input_img_color
+ 
+ # Color block detector (T2I-Adapter "color" style): downsample to roughly 1/64 resolution, then
+ # upscale back with nearest-neighbor so the output is a grid of flat color blocks; the resize
+ # methods are fixed.
+ class ColorDetector:
+     def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
+         input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+         input_image = HWC3(input_image)
+         detected_map = HWC3(apply_color(input_image, detect_resolution))
+ 
+         if output_type == "pil":
+             detected_map = Image.fromarray(detected_map)
+ 
+         return detected_map
+ 
+ 
+ class InpaintPreprocessor:
+     def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
+         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
+         mask = mask.movedim(1, -1).expand((-1, -1, -1, 3))
+         image = image.clone()
+         if black_pixel_for_xinsir_cn:
+             masked_pixel = 0.0
+         else:
+             masked_pixel = -1.0
+         image[mask > 0.5] = masked_pixel
+         return (image,)
+ 
+ 
+ class BlendInpaint:
+     def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma: int, origin=None) -> Tuple[torch.Tensor]:
+ 
+         original, mask = check_image_mask(original, mask, 'Blend Inpaint')
+ 
+         if len(inpaint.shape) < 4:
+             # image tensor shape should be [B, H, W, C]; the batch dimension is missing, so add it
+             inpaint = inpaint[None, :, :, :]
+ 
+         if inpaint.shape[0] < original.shape[0]:
+             print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
+             original = original[:inpaint.shape[0], :, :]
+             mask = mask[:inpaint.shape[0], :, :]
+ 
+         if inpaint.shape[0] > original.shape[0]:
+             # batch over inpaint
+             count = 0
+             original_list = []
+             mask_list = []
+             origin_list = []
+             while count < inpaint.shape[0]:
+                 for i in range(original.shape[0]):
+                     original_list.append(original[i][None, :, :, :])
+                     mask_list.append(mask[i][None, :, :])
+                     if origin is not None:
+                         origin_list.append(origin[i][None, :])
+                     count += 1
+                     if count >= inpaint.shape[0]:
+                         break
+             original = torch.concat(original_list, dim=0)
+             mask = torch.concat(mask_list, dim=0)
+             if origin is not None:
+                 origin = torch.concat(origin_list, dim=0)
+ 
+         if kernel % 2 == 0:
+             kernel += 1
+         transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
+ 
+         ret = []
+         blurred = []
+         for i in range(inpaint.shape[0]):
+             if origin is None:
+                 blurred_mask = transform(mask[i][None, None, :, :]).to(original.device).to(original.dtype)
+                 blurred.append(blurred_mask[0])
+ 
+                 result = torch.nn.functional.interpolate(
+                     inpaint[i][None, :, :, :].permute(0, 3, 1, 2),
+                     size=(
+                         original[i].shape[0],
+                         original[i].shape[1],
+                     )
+                 ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
+             else:
+                 # got mask from CutForInpaint
+                 height, width, _ = original[i].shape
+                 x0 = origin[i][0].item()
+                 y0 = origin[i][1].item()
+ 
+                 if mask[i].shape[0] < height or mask[i].shape[1] < width:
+                     padded_mask = F.pad(input=mask[i], pad=(x0, width - x0 - mask[i].shape[1],
+                                                             y0, height - y0 - mask[i].shape[0]), mode='constant', value=0)
+                 else:
+                     padded_mask = mask[i]
+                 blurred_mask = transform(padded_mask[None, None, :, :]).to(original.device).to(original.dtype)
+                 blurred.append(blurred_mask[0][0])
+ 
+                 result = F.pad(input=inpaint[i], pad=(0, 0, x0, width - x0 - inpaint[i].shape[1],
+                                                       y0, height - y0 - inpaint[i].shape[0]), mode='constant', value=0)
+                 result = result[None, :, :, :].to(original.device).to(original.dtype)
+ 
+             ret.append(original[i] * (1.0 - blurred_mask[0][0][:, :, None]) + result[0] * blurred_mask[0][0][:, :, None])
+ 
+         return (torch.stack(ret), torch.stack(blurred), )
+ 
+ 
+ def resize_mask(mask, shape):
+     return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
+ 
+ class JoinImageWithAlpha:
+     def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
+         batch_size = min(len(image), len(alpha))
+         out_images = []
+ 
+         alpha = 1.0 - resize_mask(alpha, image.shape[1:])
+         for i in range(batch_size):
+             out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
+ 
+         result = (torch.stack(out_images),)
+         return result
+ 
+ class GrowMask:
+     def expand_mask(self, mask, expand, tapered_corners):
+         c = 0 if tapered_corners else 1
+         kernel = np.array([[c, 1, c],
+                            [1, 1, 1],
+                            [c, 1, c]])
+         mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
+         out = []
+         for m in mask:
+             output = m.numpy()
+             for _ in range(abs(expand)):
+                 if expand < 0:
+                     output = scipy.ndimage.grey_erosion(output, footprint=kernel)
+                 else:
+                     output = scipy.ndimage.grey_dilation(output, footprint=kernel)
+             output = torch.from_numpy(output)
+             out.append(output)
+         return (torch.stack(out, dim=0),)
+ 
+ class InvertMask:
+     def invert(self, mask):
+         out = 1.0 - mask
+         return (out,)
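For context, a minimal sketch of how these node-style helpers (which mirror ComfyUI nodes and therefore return tuples) can be wired together on dummy tensors; the tensor shapes, the kernel/sigma values, and the import from utils_node are illustrative assumptions:

    # Illustrative sketch only; images are [B, H, W, C] floats in [0, 1], masks are [B, H, W].
    import torch
    from utils_node import GrowMask, InvertMask, BlendInpaint

    original = torch.rand(1, 512, 512, 3)      # source image batch
    inpaint = torch.rand(1, 512, 512, 3)       # freshly generated inpaint result
    mask = torch.zeros(1, 512, 512)
    mask[:, 128:384, 128:384] = 1.0            # region that was edited

    (grown,) = GrowMask().expand_mask(mask, expand=8, tapered_corners=True)   # dilate the mask by ~8 px
    (background,) = InvertMask().invert(grown)                                # 1.0 outside the edited region
    blended, blurred = BlendInpaint().blend_inpaint(inpaint, original, grown, kernel=15, sigma=5)
    print(blended.shape, blurred.shape)        # torch.Size([1, 512, 512, 3]) torch.Size([1, 1, 512, 512])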