bug fixes
- VITON_Dataset.py +1 -1
- interface.py +2 -2
- model.py → load_model.py +0 -0
- pipeline.py +2 -2
- test.ipynb +279 -92
VITON_Dataset.py
CHANGED
@@ -32,7 +32,7 @@ class InferenceDataset(Dataset):
 
 class VITONHDTestDataset(InferenceDataset):
     def load_data(self):
-        name= "train" if self.args.is_train else "
+        name= "train" if self.args.is_train else "samples"
         assert os.path.exists(pair_txt:=os.path.join(self.args.data_root_path, f'{name}_pairs.txt')), f"File {pair_txt} does not exist."
         with open(pair_txt, 'r') as f:
             lines = f.readlines()
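The only functional change above is which pairs file inference reads. As a minimal standalone sketch of that selection logic, assuming the usual VITON-HD pairs format of one `person_image cloth_image` pair per line (the helper name `load_pairs` and the return shape are illustrative, not from the repo):

```python
import os

def load_pairs(data_root_path: str, is_train: bool):
    # After this commit, inference reads samples_pairs.txt;
    # training still reads train_pairs.txt.
    name = "train" if is_train else "samples"
    pair_txt = os.path.join(data_root_path, f"{name}_pairs.txt")
    assert os.path.exists(pair_txt), f"File {pair_txt} does not exist."
    with open(pair_txt, "r") as f:
        lines = f.readlines()
    # Assumed format per line: "<person_image> <cloth_image>".
    return [tuple(line.split()) for line in lines if line.strip()]
```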
interface.py
CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
 from transformers import CLIPTokenizer
 
 # Import your existing model and pipeline modules
-import 
+import load_model
 import pipeline
 
 # Device Configuration
@@ -24,7 +24,7 @@ print(f"Using device: {DEVICE}")
 # Load tokenizer and models
 tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
 model_file = "inkpunk-diffusion-v1.ckpt"
-models = 
+models = load_model.preload_models_from_standard_weights(model_file, DEVICE)
 # models=None
 
 def generate_image(
model.py → load_model.py
RENAMED
File without changes
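With model.py renamed to load_model.py, both interface.py and pipeline.py resolve the loader as `load_model.preload_models_from_standard_weights`. A hedged usage sketch; that the loader returns a dict of sub-models (keys like "clip", "encoder", "decoder", "diffusion") is an assumption based on similar from-scratch Stable Diffusion ports, not something this diff confirms:

```python
import load_model

DEVICE = "cuda"  # interface.py picks this in its device-configuration block

# Parse the checkpoint once at startup, as interface.py now does.
models = load_model.preload_models_from_standard_weights("inkpunk-diffusion-v1.ckpt", DEVICE)

# Assumed structure; inspect what the loader actually returns.
print(models.keys() if isinstance(models, dict) else type(models))
```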
pipeline.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from tqdm import tqdm
 from ddpm import DDPMSampler
 from PIL import Image
-import 
+import load_model
 from utils import check_inputs, prepare_image, prepare_mask_image
 
 WIDTH = 512
@@ -293,7 +293,7 @@ if __name__ == "__main__":
     mask = Image.open("agnostic_mask.png").convert("L")
 
     # Load models
-    models=
+    models=load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
 
     # Generate image
     generated_image = generate(
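The `__main__` block in pipeline.py now loads the inpainting checkpoint through the renamed module before calling `generate`. For context, a minimal sketch of the input preparation around that call, assuming the 512x512 resolution implied by `WIDTH`; the person-image filename is hypothetical, since the diff only shows the mask being opened:

```python
from PIL import Image

WIDTH = HEIGHT = 512  # matches the constants in pipeline.py

# "person.jpg" is a placeholder name; the diff only shows the mask input.
image = Image.open("person.jpg").convert("RGB").resize((WIDTH, HEIGHT))
# Grayscale mask; by the usual inpainting convention, white marks the
# region to repaint.
mask = Image.open("agnostic_mask.png").convert("L").resize((WIDTH, HEIGHT))
```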
test.ipynb
CHANGED
@@ -2,67 +2,52 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
-   "id": "
+   "execution_count": null,
+   "id": "6387c9e1",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Cloning into 'CatVTON'...\n",
-      "remote: Enumerating objects: 1348, done.\u001b[K\n",
-      "remote: Counting objects: 100% (62/62), done.\u001b[K\n",
-      "remote: Compressing objects: 100% (29/29), done.\u001b[K\n",
-      "remote: Total 1348 (delta 51), reused 33 (delta 33), pack-reused 1286 (from 3)\u001b[K\n",
-      "Receiving objects: 100% (1348/1348), 16.74 MiB | 42.65 MiB/s, done.\n",
-      "Resolving deltas: 100% (449/449), done.\n"
-     ]
-    }
-   ],
-   "source": [
-    "!git clone https://github.com/Zheng-Chong/CatVTON.git"
-   ]
+   "outputs": [],
+   "source": []
   },
   {
    "cell_type": "code",
-   "execution_count":
-   "id": "
+   "execution_count": null,
+   "id": "ca9233f0",
    "metadata": {},
    "outputs": [
     {
-     "
+     "data": {
+      "text/plain": [
+       "'/kaggle/working'"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
-   "source": [
-    "ls /kaggle/input/viton-hd-dataset"
-   ]
+   "source": []
  },
  {
   "cell_type": "code",
-   "execution_count":
-   "id": "
+   "execution_count": 17,
+   "id": "3d2f98af",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "/
+      "\u001b[0m\u001b[01;34mtest\u001b[0m/ \u001b[01;32mtest_pairs.txt\u001b[0m* \u001b[01;34mtrain\u001b[0m/ \u001b[01;32mtrain_pairs.txt\u001b[0m*\n"
      ]
     }
    ],
    "source": [
-    "
+    "ls /kaggle/input/viton-hd-dataset"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 18,
    "id": "dc0f36f4",
    "metadata": {},
    "outputs": [
@@ -71,12 +56,12 @@
      "output_type": "stream",
      "text": [
       "Cloning into 'stable-diffusion'...\n",
-      "remote: Enumerating objects: 
-      "remote: Counting objects: 100% (
-      "remote: Compressing objects: 100% (
-      "remote: Total 
-      "Receiving objects: 100% (
-      "Resolving deltas: 100% (
+      "remote: Enumerating objects: 150, done.\u001b[K\n",
+      "remote: Counting objects: 100% (150/150), done.\u001b[K\n",
+      "remote: Compressing objects: 100% (124/124), done.\u001b[K\n",
+      "remote: Total 150 (delta 36), reused 139 (delta 26), pack-reused 0 (from 0)\u001b[K\n",
+      "Receiving objects: 100% (150/150), 9.11 MiB | 20.74 MiB/s, done.\n",
+      "Resolving deltas: 100% (36/36), done.\n"
      ]
     }
    ],
@@ -86,7 +71,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 19,
    "id": "a0bf01ab",
    "metadata": {},
    "outputs": [
@@ -94,7 +79,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "/kaggle/working/stable-diffusion
+      "/kaggle/working/stable-diffusion\n"
      ]
     }
    ],
@@ -104,7 +89,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 20,
    "id": "1401cd56",
    "metadata": {},
    "outputs": [
@@ -112,25 +97,25 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "--2025-06-
-      "Resolving huggingface.co (huggingface.co)... 
-      "Connecting to huggingface.co (huggingface.co)|
+      "--2025-06-13 07:07:34-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
+      "Resolving huggingface.co (huggingface.co)... 18.67.93.22, 18.67.93.63, 18.67.93.58, ...\n",
+      "Connecting to huggingface.co (huggingface.co)|18.67.93.22|:443... connected.\n",
       "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
       "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
-      "--2025-06-
+      "--2025-06-13 07:07:34-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
       "Reusing existing connection to huggingface.co:443.\n",
       "HTTP request sent, awaiting response... 302 Found\n",
-      "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=
-      "--2025-06-
-      "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 
-      "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|
+      "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749802055&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTgwMjA1NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=m4Xzc4SaPX28SXT9wK8qPXBWIr7uKmVt6iq2D3qMALrJWCfm1a4FHoshqkXLDrchchDIkAImr7l-yDlAv9x15JsX09FidLsSVU8UXS4a%7Em69hgWMTgloTObR3HlTwY9EQ7t%7ErneASRUS5r%7E2szyfyrlN-n4-U9QWCmyOikaumCc0PbAHE6lRNcy7FSCTxQGM48h%7EQBZ37iQArWW2JC%7E-apwm1knzGt422ywPlQws2qREoUeCPoXFWKl-iX1%7EqDimjSepdm2ZGt-COfekmJddQWXuCQAj7uY5YKcE3qEt7IBcaj96MNbF8b2qxTNbLrzgXioIzl0SIw8Ws-YUOu5I3A__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
+      "--2025-06-13 07:07:35-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749802055&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTgwMjA1NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=m4Xzc4SaPX28SXT9wK8qPXBWIr7uKmVt6iq2D3qMALrJWCfm1a4FHoshqkXLDrchchDIkAImr7l-yDlAv9x15JsX09FidLsSVU8UXS4a%7Em69hgWMTgloTObR3HlTwY9EQ7t%7ErneASRUS5r%7E2szyfyrlN-n4-U9QWCmyOikaumCc0PbAHE6lRNcy7FSCTxQGM48h%7EQBZ37iQArWW2JC%7E-apwm1knzGt422ywPlQws2qREoUeCPoXFWKl-iX1%7EqDimjSepdm2ZGt-COfekmJddQWXuCQAj7uY5YKcE3qEt7IBcaj96MNbF8b2qxTNbLrzgXioIzl0SIw8Ws-YUOu5I3A__&Key-Pair-Id=K3RPWS32NSSJCE\n",
+      "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 108.158.20.116, 108.158.20.30, 108.158.20.84, ...\n",
+      "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|108.158.20.116|:443... connected.\n",
       "HTTP request sent, awaiting response... 200 OK\n",
       "Length: 4265437280 (4.0G) [binary/octet-stream]\n",
       "Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
       "\n",
-      "sd-v1-5-inpainting. 100%[===================>] 3.97G
+      "sd-v1-5-inpainting. 100%[===================>] 3.97G 366MB/s in 12s \n",
       "\n",
-      "2025-06-
+      "2025-06-13 07:07:46 (353 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
       "\n"
      ]
     }
@@ -141,7 +126,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "id": "f7450c55",
    "metadata": {},
    "outputs": [
@@ -169,7 +154,7 @@
    }
   ],
   "source": [
-   "!wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true "
+   "# !wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true "
   ]
  },
  {
@@ -333,22 +318,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 4,
    "id": "91ef7a4e",
    "metadata": {},
    "outputs": [
     {
-     "
-     "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-     "\u001b[0;31mKeyError\u001b[0m: '_oh'"
-     ]
+     "data": {
+      "text/plain": [
+       "0"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
@@ -361,7 +343,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 9,
    "id": "08f29055",
    "metadata": {},
    "outputs": [
@@ -369,7 +351,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "GPU memory used: 
+      "GPU memory used: 0.00 MB / 16269.25 MB\n"
      ]
     }
    ],
@@ -396,7 +378,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "id": "37335c1e",
    "metadata": {},
    "outputs": [],
@@ -416,7 +398,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "id": "35d98b83",
    "metadata": {},
    "outputs": [],
@@ -450,14 +432,231 @@
    "id": "d7ff094a",
    "metadata": {},
    "outputs": [],
-   "source": [
+   "source": [
+    "from torch.nn import functional as F\n",
+    "import torch\n",
+    "# from flash_attn import flash_attn_func\n",
+    "\n",
+    "class SkipAttnProcessor(torch.nn.Module):\n",
+    "    def __init__(self, *args, **kwargs) -> None:\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        attn,\n",
+    "        hidden_states,\n",
+    "        encoder_hidden_states=None,\n",
+    "        attention_mask=None,\n",
+    "        temb=None,\n",
+    "    ):\n",
+    "        return hidden_states\n",
+    "\n",
+    "class AttnProcessor2_0(torch.nn.Module):\n",
+    "    r\"\"\"\n",
+    "    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        hidden_size=None,\n",
+    "        cross_attention_dim=None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super().__init__()\n",
+    "        if not hasattr(F, \"scaled_dot_product_attention\"):\n",
+    "            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        attn,\n",
+    "        hidden_states,\n",
+    "        encoder_hidden_states=None,\n",
+    "        attention_mask=None,\n",
+    "        temb=None,\n",
+    "        *args,\n",
+    "        **kwargs,\n",
+    "    ):\n",
+    "        residual = hidden_states\n",
+    "\n",
+    "        if attn.spatial_norm is not None:\n",
+    "            hidden_states = attn.spatial_norm(hidden_states, temb)\n",
+    "\n",
+    "        input_ndim = hidden_states.ndim\n",
+    "\n",
+    "        if input_ndim == 4:\n",
+    "            batch_size, channel, height, width = hidden_states.shape\n",
+    "            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
+    "\n",
+    "        batch_size, sequence_length, _ = (\n",
+    "            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
+    "        )\n",
+    "\n",
+    "        if attention_mask is not None:\n",
+    "            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n",
+    "            # scaled_dot_product_attention expects attention_mask shape to be\n",
+    "            # (batch, heads, source_length, target_length)\n",
+    "            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n",
+    "\n",
+    "        if attn.group_norm is not None:\n",
+    "            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
+    "\n",
+    "        query = attn.to_q(hidden_states)\n",
+    "\n",
+    "        if encoder_hidden_states is None:\n",
+    "            encoder_hidden_states = hidden_states\n",
+    "        elif attn.norm_cross:\n",
+    "            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n",
+    "\n",
+    "        key = attn.to_k(encoder_hidden_states)\n",
+    "        value = attn.to_v(encoder_hidden_states)\n",
+    "\n",
+    "        inner_dim = key.shape[-1]\n",
+    "        head_dim = inner_dim // attn.heads\n",
+    "\n",
+    "        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+    "\n",
+    "        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+    "        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+    "\n",
+    "        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n",
+    "        # TODO: add support for attn.scale when we move to Torch 2.1\n",
+    "        \n",
+    "        hidden_states = F.scaled_dot_product_attention(\n",
+    "            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n",
+    "        )\n",
+    "        # hidden_states = flash_attn_func(\n",
+    "        #     query, key, value, dropout_p=0.0, causal=False\n",
+    "        # )\n",
+    "\n",
+    "        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n",
+    "        hidden_states = hidden_states.to(query.dtype)\n",
+    "\n",
+    "        # linear proj\n",
+    "        hidden_states = attn.to_out[0](hidden_states)\n",
+    "        # dropout\n",
+    "        hidden_states = attn.to_out[1](hidden_states)\n",
+    "\n",
+    "        if input_ndim == 4:\n",
+    "            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
+    "\n",
+    "        if attn.residual_connection:\n",
+    "            hidden_states = hidden_states + residual\n",
+    "\n",
+    "        hidden_states = hidden_states / attn.rescale_output_factor\n",
+    "\n",
+    "        return hidden_states\n",
+    "    "
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "
+   "id": "84a7fa87",
    "metadata": {},
    "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import torch\n",
+    "\n",
+    "def init_adapter(unet, \n",
+    "                 cross_attn_cls=SkipAttnProcessor,\n",
+    "                 self_attn_cls=None,\n",
+    "                 cross_attn_dim=None, \n",
+    "                 **kwargs):\n",
+    "    if cross_attn_dim is None:\n",
+    "        cross_attn_dim = unet.config.cross_attention_dim\n",
+    "    attn_procs = {}\n",
+    "    for name in unet.attn_processors.keys():\n",
+    "        cross_attention_dim = None if name.endswith(\"attn1.processor\") else cross_attn_dim\n",
+    "        if name.startswith(\"mid_block\"):\n",
+    "            hidden_size = unet.config.block_out_channels[-1]\n",
+    "        elif name.startswith(\"up_blocks\"):\n",
+    "            block_id = int(name[len(\"up_blocks.\")])\n",
+    "            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n",
+    "        elif name.startswith(\"down_blocks\"):\n",
+    "            block_id = int(name[len(\"down_blocks.\")])\n",
+    "            hidden_size = unet.config.block_out_channels[block_id]\n",
+    "        if cross_attention_dim is None:\n",
+    "            if self_attn_cls is not None:\n",
+    "                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+    "            else:\n",
+    "                # retain the original attn processor\n",
+    "                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+    "        else:\n",
+    "            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+    "    \n",
+    "    unet.set_attn_processor(attn_procs)\n",
+    "    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n",
+    "    return adapter_modules\n",
+    "\n",
+    "def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):\n",
+    "    from diffusers import AutoencoderKL\n",
+    "    from transformers import CLIPTextModel, CLIPTokenizer\n",
+    "\n",
+    "    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder=\"text_encoder\")\n",
+    "    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder=\"vae\")\n",
+    "    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder=\"tokenizer\")\n",
+    "    try:\n",
+    "        unet_folder = os.path.join(diffusion_model_name_or_path, \"unet\")\n",
+    "        unet_configs = json.load(open(os.path.join(unet_folder, \"config.json\"), \"r\"))\n",
+    "        unet = unet_class(**unet_configs)\n",
+    "        unet.load_state_dict(torch.load(os.path.join(unet_folder, \"diffusion_pytorch_model.bin\"), map_location=\"cpu\"), strict=True)\n",
+    "    except:\n",
+    "        unet = None\n",
+    "    return text_encoder, vae, tokenizer, unet\n",
+    "\n",
+    "def attn_of_unet(unet):\n",
+    "    attn_blocks = torch.nn.ModuleList()\n",
+    "    for name, param in unet.named_modules():\n",
+    "        if \"attn1\" in name:\n",
+    "            attn_blocks.append(param)\n",
+    "    return attn_blocks\n",
+    "\n",
+    "def get_trainable_module(unet, trainable_module_name):\n",
+    "    if trainable_module_name == \"unet\":\n",
+    "        return unet\n",
+    "    elif trainable_module_name == \"transformer\":\n",
+    "        trainable_modules = torch.nn.ModuleList()\n",
+    "        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:\n",
+    "            if hasattr(blocks, \"attentions\"):\n",
+    "                trainable_modules.append(blocks.attentions)\n",
+    "            else:\n",
+    "                for block in blocks:\n",
+    "                    if hasattr(block, \"attentions\"):\n",
+    "                        trainable_modules.append(block.attentions)\n",
+    "        return trainable_modules\n",
+    "    elif trainable_module_name == \"attention\":\n",
+    "        attn_blocks = torch.nn.ModuleList()\n",
+    "        for name, param in unet.named_modules():\n",
+    "            if \"attn1\" in name:\n",
+    "                attn_blocks.append(param)\n",
+    "        return attn_blocks\n",
+    "    else:\n",
+    "        raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")\n",
+    "\n",
+    "    \n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6028381d",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'model'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_662/1349749640.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtransformers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCLIPImageProcessor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_processor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSkipAttnProcessor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_trainable_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_adapter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'model'"
+     ]
+    }
+   ],
    "source": [
     "import inspect\n",
     "import os\n",
@@ -475,8 +674,6 @@
    "from huggingface_hub import snapshot_download\n",
    "from transformers import CLIPImageProcessor\n",
    "\n",
-   "from model.attn_processor import SkipAttnProcessor\n",
-   "from model.utils import get_trainable_module, init_adapter\n",
    "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
    "                   prepare_mask_image, resize_and_crop, resize_and_padding)\n",
    "from ddpm import DDPMSampler\n",
@@ -755,7 +952,7 @@
    "from diffusers.image_processor import VaeImageProcessor\n",
    "from tqdm import tqdm\n",
    "from PIL import Image, ImageFilter\n",
-   "import 
+   "import load_model\n",
    "\n",
    "from utils import repaint, to_pil_image\n",
    " \n",
@@ -921,7 +1118,8 @@
    "    \"base_model_path\": \"booksforcharlie/stable-diffusion-inpainting\",\n",
    "    \"resume_path\": \"zhengchong/CatVTON\",\n",
    "    \"dataset_name\": \"vitonhd\",\n",
-   "    \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
+   "    # \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
+   "    \"data_root_path\": \"/kaggle/working/stable-diffusion/sample_dataset\",\n",
    "    \"output_dir\": \"./output\",\n",
    "    \"seed\": 555,\n",
    "    \"batch_size\": 2,\n",
@@ -936,10 +1134,11 @@
    "    \"dataloader_num_workers\": 4,\n",
    "    \"mixed_precision\": 'no',\n",
    "    \"concat_axis\": 'y',\n",
-   "    \"enable_condition_noise\": True
+   "    \"enable_condition_noise\": True,\n",
+   "    \"is_train\": False\n",
    "    }\n",
    "\n",
-   "    models=
+   "    models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"/kaggle/working/stable-diffusion/checkpoints/checkpoint_epoch_10.pth\")\n",
    "\n",
    "    # Pipeline\n",
    "    pipeline = CatVTONPipeline(\n",
@@ -1795,18 +1994,6 @@
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.11"
   }
  },
 "nbformat": 4,
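The notebook changes trace one failure chain: the old cells imported SkipAttnProcessor and init_adapter from a `model` package (CatVTON's repo layout), which fails here with the ModuleNotFoundError captured in cell 6028381d, so this commit inlines those helpers as notebook cells instead. A hedged sketch of how the inlined helpers fit together, assuming a diffusers `UNet2DConditionModel`; loading the UNet this way is illustrative, with the model id simply mirroring the notebook's `base_model_path`:

```python
from diffusers import UNet2DConditionModel

# Illustrative load; the notebook's args use this repo as base_model_path.
unet = UNet2DConditionModel.from_pretrained(
    "booksforcharlie/stable-diffusion-inpainting", subfolder="unet"
)

# Replace every cross-attention processor with the identity SkipAttnProcessor
# and keep self-attention on PyTorch 2.0 SDPA (AttnProcessor2_0), as the
# inlined init_adapter cell does.
adapter_modules = init_adapter(unet, cross_attn_cls=SkipAttnProcessor)

# CatVTON fine-tunes only the self-attention layers.
trainable = get_trainable_module(unet, "attention")
```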