harsh99 committed on
Commit 569254a
1 Parent(s): 870796d
Files changed (5)
  1. VITON_Dataset.py +1 -1
  2. interface.py +2 -2
  3. model.py → load_model.py +0 -0
  4. pipeline.py +2 -2
  5. test.ipynb +279 -92
VITON_Dataset.py CHANGED
@@ -32,7 +32,7 @@ class InferenceDataset(Dataset):
 
 class VITONHDTestDataset(InferenceDataset):
     def load_data(self):
-        name= "train" if self.args.is_train else "test"
+        name= "train" if self.args.is_train else "samples"
         assert os.path.exists(pair_txt:=os.path.join(self.args.data_root_path, f'{name}_pairs.txt')), f"File {pair_txt} does not exist."
         with open(pair_txt, 'r') as f:
             lines = f.readlines()
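Note: non-training runs now resolve the pair list from samples_pairs.txt instead of test_pairs.txt. A minimal sketch of the resolution, assuming an argparse-style args namespace and the sample_dataset root that test.ipynb sets further down in this commit:

import os
from types import SimpleNamespace

# Illustrative args only; data_root_path mirrors the value set later in test.ipynb.
args = SimpleNamespace(is_train=False, data_root_path="/kaggle/working/stable-diffusion/sample_dataset")
name = "train" if args.is_train else "samples"
pair_txt = os.path.join(args.data_root_path, f"{name}_pairs.txt")  # .../samples_pairs.txt
assert os.path.exists(pair_txt), f"File {pair_txt} does not exist."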
interface.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
 from transformers import CLIPTokenizer
 
 # Import your existing model and pipeline modules
-import model
+import load_model
 import pipeline
 
 # Device Configuration
@@ -24,7 +24,7 @@ print(f"Using device: {DEVICE}")
 # Load tokenizer and models
 tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
 model_file = "inkpunk-diffusion-v1.ckpt"
-models = model.preload_models_from_standard_weights(model_file, DEVICE)
+models = load_model.preload_models_from_standard_weights(model_file, DEVICE)
 # models=None
 
 def generate_image(
model.py → load_model.py RENAMED
File without changes
pipeline.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from tqdm import tqdm
 from ddpm import DDPMSampler
 from PIL import Image
-import model
+import load_model
 from utils import check_inputs, prepare_image, prepare_mask_image
 
 WIDTH = 512
@@ -293,7 +293,7 @@ if __name__ == "__main__":
     mask = Image.open("agnostic_mask.png").convert("L")
 
     # Load models
-    models=model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
+    models=load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
 
     # Generate image
     generated_image = generate(
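Note: with model.py renamed to load_model.py, every `import model` call site is updated in the same way. A minimal sketch of the renamed entry point, assuming preload_models_from_standard_weights is unchanged apart from the module name:

import load_model

# Same call as before the rename; only the module name differs.
models = load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")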
test.ipynb CHANGED
@@ -2,67 +2,52 @@
  "cells": [
  {
  "cell_type": "code",
- "execution_count": 4,
- "id": "867520bc",
+ "execution_count": null,
+ "id": "6387c9e1",
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Cloning into 'CatVTON'...\n",
- "remote: Enumerating objects: 1348, done.\u001b[K\n",
- "remote: Counting objects: 100% (62/62), done.\u001b[K\n",
- "remote: Compressing objects: 100% (29/29), done.\u001b[K\n",
- "remote: Total 1348 (delta 51), reused 33 (delta 33), pack-reused 1286 (from 3)\u001b[K\n",
- "Receiving objects: 100% (1348/1348), 16.74 MiB | 42.65 MiB/s, done.\n",
- "Resolving deltas: 100% (449/449), done.\n"
- ]
- }
- ],
- "source": [
- "!git clone https://github.com/Zheng-Chong/CatVTON.git"
- ]
+ "outputs": [],
+ "source": []
  },
  {
  "cell_type": "code",
- "execution_count": 5,
- "id": "3d2f98af",
+ "execution_count": null,
+ "id": "ca9233f0",
  "metadata": {},
  "outputs": [
  {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[0m\u001b[01;34mtest\u001b[0m/ test_pairs.txt \u001b[01;34mtrain\u001b[0m/ train_pairs.txt\n"
- ]
+ "data": {
+ "text/plain": [
+ "'/kaggle/working'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
  }
  ],
- "source": [
- "ls /kaggle/input/viton-hd-dataset"
- ]
+ "source": []
  },
  {
  "cell_type": "code",
- "execution_count": 6,
- "id": "ba750da0",
+ "execution_count": 17,
+ "id": "3d2f98af",
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "/kaggle/working/stable-diffusion/CatVTON\n"
+ "\u001b[0m\u001b[01;34mtest\u001b[0m/ \u001b[01;32mtest_pairs.txt\u001b[0m* \u001b[01;34mtrain\u001b[0m/ \u001b[01;32mtrain_pairs.txt\u001b[0m*\n"
  ]
  }
  ],
  "source": [
- "cd CatVTON/"
+ "ls /kaggle/input/viton-hd-dataset"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 18,
  "id": "dc0f36f4",
  "metadata": {},
  "outputs": [
@@ -71,12 +56,12 @@
  "output_type": "stream",
  "text": [
  "Cloning into 'stable-diffusion'...\n",
- "remote: Enumerating objects: 56, done.\u001b[K\n",
- "remote: Counting objects: 100% (56/56), done.\u001b[K\n",
- "remote: Compressing objects: 100% (44/44), done.\u001b[K\n",
- "remote: Total 56 (delta 17), reused 50 (delta 12), pack-reused 0 (from 0)\u001b[K\n",
- "Receiving objects: 100% (56/56), 4.68 MiB | 36.31 MiB/s, done.\n",
- "Resolving deltas: 100% (17/17), done.\n"
+ "remote: Enumerating objects: 150, done.\u001b[K\n",
+ "remote: Counting objects: 100% (150/150), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (124/124), done.\u001b[K\n",
+ "remote: Total 150 (delta 36), reused 139 (delta 26), pack-reused 0 (from 0)\u001b[K\n",
+ "Receiving objects: 100% (150/150), 9.11 MiB | 20.74 MiB/s, done.\n",
+ "Resolving deltas: 100% (36/36), done.\n"
  ]
  }
  ],
@@ -86,7 +71,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 19,
  "id": "a0bf01ab",
  "metadata": {},
  "outputs": [
@@ -94,7 +79,7 @@
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "/kaggle/working/stable-diffusion/CatVTON/stable-diffusion\n"
+ "/kaggle/working/stable-diffusion\n"
  ]
  }
  ],
@@ -104,7 +89,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 9,
+ "execution_count": 20,
  "id": "1401cd56",
  "metadata": {},
  "outputs": [
@@ -112,25 +97,25 @@
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "--2025-06-11 10:33:00-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
- "Resolving huggingface.co (huggingface.co)... 3.163.189.114, 3.163.189.74, 3.163.189.90, ...\n",
- "Connecting to huggingface.co (huggingface.co)|3.163.189.114|:443... connected.\n",
+ "--2025-06-13 07:07:34-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
+ "Resolving huggingface.co (huggingface.co)... 18.67.93.22, 18.67.93.63, 18.67.93.58, ...\n",
+ "Connecting to huggingface.co (huggingface.co)|18.67.93.22|:443... connected.\n",
  "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
  "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
- "--2025-06-11 10:33:01-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
+ "--2025-06-13 07:07:34-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
  "Reusing existing connection to huggingface.co:443.\n",
  "HTTP request sent, awaiting response... 302 Found\n",
- "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749640621&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MDYyMX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=sdomKXQPt3COUrAxFqHQCR83b0Sgw0kHwStFv%7EqgSrCfwOddw9sNRX3qela0jgons998TT3Oqk0TA0c-PTLyPpAO-iqM9aGvLsRLixtxMNgdvDdWdk87Ywdgvg24T6GkVaL9I8ErFlF918m%7EYMtHICZ8hcoq1GST-DdDigp4vA-w9lHnRfOGteBzViPKyqgQaYiYRd10FVmSYYpFUJrZ%7ECFAGO5MwVA-OTlMVLOYKKPs0s3duoP4KIz9-SUoUIXbgUmiuExLqdVulk-tJRCSAk-u7WvbUhPUsraiP1YGa-QvUYoygX5xlluuFIt%7EG54t5TrCzIWP0tu0ZGaqr3%7E%7EEA__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
- "--2025-06-11 10:33:01-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749640621&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MDYyMX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=sdomKXQPt3COUrAxFqHQCR83b0Sgw0kHwStFv%7EqgSrCfwOddw9sNRX3qela0jgons998TT3Oqk0TA0c-PTLyPpAO-iqM9aGvLsRLixtxMNgdvDdWdk87Ywdgvg24T6GkVaL9I8ErFlF918m%7EYMtHICZ8hcoq1GST-DdDigp4vA-w9lHnRfOGteBzViPKyqgQaYiYRd10FVmSYYpFUJrZ%7ECFAGO5MwVA-OTlMVLOYKKPs0s3duoP4KIz9-SUoUIXbgUmiuExLqdVulk-tJRCSAk-u7WvbUhPUsraiP1YGa-QvUYoygX5xlluuFIt%7EG54t5TrCzIWP0tu0ZGaqr3%7E%7EEA__&Key-Pair-Id=K3RPWS32NSSJCE\n",
- "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.172.170.108, 18.172.170.21, 18.172.170.5, ...\n",
- "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.172.170.108|:443... connected.\n",
+ "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749802055&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTgwMjA1NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=m4Xzc4SaPX28SXT9wK8qPXBWIr7uKmVt6iq2D3qMALrJWCfm1a4FHoshqkXLDrchchDIkAImr7l-yDlAv9x15JsX09FidLsSVU8UXS4a%7Em69hgWMTgloTObR3HlTwY9EQ7t%7ErneASRUS5r%7E2szyfyrlN-n4-U9QWCmyOikaumCc0PbAHE6lRNcy7FSCTxQGM48h%7EQBZ37iQArWW2JC%7E-apwm1knzGt422ywPlQws2qREoUeCPoXFWKl-iX1%7EqDimjSepdm2ZGt-COfekmJddQWXuCQAj7uY5YKcE3qEt7IBcaj96MNbF8b2qxTNbLrzgXioIzl0SIw8Ws-YUOu5I3A__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
+ "--2025-06-13 07:07:35-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1749802055&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTgwMjA1NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=m4Xzc4SaPX28SXT9wK8qPXBWIr7uKmVt6iq2D3qMALrJWCfm1a4FHoshqkXLDrchchDIkAImr7l-yDlAv9x15JsX09FidLsSVU8UXS4a%7Em69hgWMTgloTObR3HlTwY9EQ7t%7ErneASRUS5r%7E2szyfyrlN-n4-U9QWCmyOikaumCc0PbAHE6lRNcy7FSCTxQGM48h%7EQBZ37iQArWW2JC%7E-apwm1knzGt422ywPlQws2qREoUeCPoXFWKl-iX1%7EqDimjSepdm2ZGt-COfekmJddQWXuCQAj7uY5YKcE3qEt7IBcaj96MNbF8b2qxTNbLrzgXioIzl0SIw8Ws-YUOu5I3A__&Key-Pair-Id=K3RPWS32NSSJCE\n",
+ "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 108.158.20.116, 108.158.20.30, 108.158.20.84, ...\n",
+ "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|108.158.20.116|:443... connected.\n",
  "HTTP request sent, awaiting response... 200 OK\n",
  "Length: 4265437280 (4.0G) [binary/octet-stream]\n",
  "Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
  "\n",
- "sd-v1-5-inpainting. 100%[===================>] 3.97G 299MB/s in 12s \n",
+ "sd-v1-5-inpainting. 100%[===================>] 3.97G 366MB/s in 12s \n",
  "\n",
- "2025-06-11 10:33:13 (341 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
+ "2025-06-13 07:07:46 (353 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
  "\n"
  ]
  }
@@ -141,7 +126,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
  "id": "f7450c55",
  "metadata": {},
  "outputs": [
@@ -169,7 +154,7 @@
  }
  ],
  "source": [
- "!wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true "
+ "# !wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true "
  ]
  },
  {
@@ -333,22 +318,19 @@
  },
  {
  "cell_type": "code",
- "execution_count": 247,
+ "execution_count": 4,
  "id": "91ef7a4e",
  "metadata": {},
  "outputs": [
  {
- "ename": "KeyError",
- "evalue": "'_oh'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_71/1017109895.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Release unused GPU memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Run Python garbage collector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mKeyError\u001b[0m: '_oh'"
- ]
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
  }
  ],
  "source": [
@@ -361,7 +343,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 249,
+ "execution_count": 9,
  "id": "08f29055",
  "metadata": {},
  "outputs": [
@@ -369,7 +351,7 @@
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "GPU memory used: 8.12 MB / 16269.25 MB\n"
+ "GPU memory used: 0.00 MB / 16269.25 MB\n"
  ]
  }
  ],
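Note: a report in the form above can be produced from PyTorch's allocator statistics. A sketch, assuming a single CUDA device (the cell's exact source is not part of this hunk):

import torch

used_mb = torch.cuda.memory_allocated() / 1024**2
total_mb = torch.cuda.get_device_properties(0).total_memory / 1024**2
print(f"GPU memory used: {used_mb:.2f} MB / {total_mb:.2f} MB")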
@@ -396,7 +378,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
  "id": "37335c1e",
  "metadata": {},
  "outputs": [],
@@ -416,7 +398,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
  "id": "35d98b83",
  "metadata": {},
  "outputs": [],
@@ -450,14 +432,231 @@
  "id": "d7ff094a",
  "metadata": {},
  "outputs": [],
- "source": []
+ "source": [
+ "from torch.nn import functional as F\n",
+ "import torch\n",
+ "# from flash_attn import flash_attn_func\n",
+ "\n",
+ "class SkipAttnProcessor(torch.nn.Module):\n",
+ " def __init__(self, *args, **kwargs) -> None:\n",
+ " super().__init__()\n",
+ "\n",
+ " def __call__(\n",
+ " self,\n",
+ " attn,\n",
+ " hidden_states,\n",
+ " encoder_hidden_states=None,\n",
+ " attention_mask=None,\n",
+ " temb=None,\n",
+ " ):\n",
+ " return hidden_states\n",
+ "\n",
+ "class AttnProcessor2_0(torch.nn.Module):\n",
+ " r\"\"\"\n",
+ " Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " hidden_size=None,\n",
+ " cross_attention_dim=None,\n",
+ " **kwargs\n",
+ " ):\n",
+ " super().__init__()\n",
+ " if not hasattr(F, \"scaled_dot_product_attention\"):\n",
+ " raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n",
+ "\n",
+ " def __call__(\n",
+ " self,\n",
+ " attn,\n",
+ " hidden_states,\n",
+ " encoder_hidden_states=None,\n",
+ " attention_mask=None,\n",
+ " temb=None,\n",
+ " *args,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " residual = hidden_states\n",
+ "\n",
+ " if attn.spatial_norm is not None:\n",
+ " hidden_states = attn.spatial_norm(hidden_states, temb)\n",
+ "\n",
+ " input_ndim = hidden_states.ndim\n",
+ "\n",
+ " if input_ndim == 4:\n",
+ " batch_size, channel, height, width = hidden_states.shape\n",
+ " hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
+ "\n",
+ " batch_size, sequence_length, _ = (\n",
+ " hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
+ " )\n",
+ "\n",
+ " if attention_mask is not None:\n",
+ " attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n",
+ " # scaled_dot_product_attention expects attention_mask shape to be\n",
+ " # (batch, heads, source_length, target_length)\n",
+ " attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n",
+ "\n",
+ " if attn.group_norm is not None:\n",
+ " hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
+ "\n",
+ " query = attn.to_q(hidden_states)\n",
+ "\n",
+ " if encoder_hidden_states is None:\n",
+ " encoder_hidden_states = hidden_states\n",
+ " elif attn.norm_cross:\n",
+ " encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n",
+ "\n",
+ " key = attn.to_k(encoder_hidden_states)\n",
+ " value = attn.to_v(encoder_hidden_states)\n",
+ "\n",
+ " inner_dim = key.shape[-1]\n",
+ " head_dim = inner_dim // attn.heads\n",
+ "\n",
+ " query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+ "\n",
+ " key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+ " value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
+ "\n",
+ " # the output of sdp = (batch, num_heads, seq_len, head_dim)\n",
+ " # TODO: add support for attn.scale when we move to Torch 2.1\n",
+ " \n",
+ " hidden_states = F.scaled_dot_product_attention(\n",
+ " query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n",
+ " )\n",
+ " # hidden_states = flash_attn_func(\n",
+ " # query, key, value, dropout_p=0.0, causal=False\n",
+ " # )\n",
+ "\n",
+ " hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n",
+ " hidden_states = hidden_states.to(query.dtype)\n",
+ "\n",
+ " # linear proj\n",
+ " hidden_states = attn.to_out[0](hidden_states)\n",
+ " # dropout\n",
+ " hidden_states = attn.to_out[1](hidden_states)\n",
+ "\n",
+ " if input_ndim == 4:\n",
+ " hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
+ "\n",
+ " if attn.residual_connection:\n",
+ " hidden_states = hidden_states + residual\n",
+ "\n",
+ " hidden_states = hidden_states / attn.rescale_output_factor\n",
+ "\n",
+ " return hidden_states\n",
+ " "
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "6028381d",
+ "id": "84a7fa87",
  "metadata": {},
  "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import torch\n",
+ "\n",
+ "def init_adapter(unet, \n",
+ " cross_attn_cls=SkipAttnProcessor,\n",
+ " self_attn_cls=None,\n",
+ " cross_attn_dim=None, \n",
+ " **kwargs):\n",
+ " if cross_attn_dim is None:\n",
+ " cross_attn_dim = unet.config.cross_attention_dim\n",
+ " attn_procs = {}\n",
+ " for name in unet.attn_processors.keys():\n",
+ " cross_attention_dim = None if name.endswith(\"attn1.processor\") else cross_attn_dim\n",
+ " if name.startswith(\"mid_block\"):\n",
+ " hidden_size = unet.config.block_out_channels[-1]\n",
+ " elif name.startswith(\"up_blocks\"):\n",
+ " block_id = int(name[len(\"up_blocks.\")])\n",
+ " hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n",
+ " elif name.startswith(\"down_blocks\"):\n",
+ " block_id = int(name[len(\"down_blocks.\")])\n",
+ " hidden_size = unet.config.block_out_channels[block_id]\n",
+ " if cross_attention_dim is None:\n",
+ " if self_attn_cls is not None:\n",
+ " attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+ " else:\n",
+ " # retain the original attn processor\n",
+ " attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+ " else:\n",
+ " attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
+ " \n",
+ " unet.set_attn_processor(attn_procs)\n",
+ " adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n",
+ " return adapter_modules\n",
+ "\n",
+ "def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):\n",
+ " from diffusers import AutoencoderKL\n",
+ " from transformers import CLIPTextModel, CLIPTokenizer\n",
+ "\n",
+ " text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder=\"text_encoder\")\n",
+ " vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder=\"vae\")\n",
+ " tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder=\"tokenizer\")\n",
+ " try:\n",
+ " unet_folder = os.path.join(diffusion_model_name_or_path, \"unet\")\n",
+ " unet_configs = json.load(open(os.path.join(unet_folder, \"config.json\"), \"r\"))\n",
+ " unet = unet_class(**unet_configs)\n",
+ " unet.load_state_dict(torch.load(os.path.join(unet_folder, \"diffusion_pytorch_model.bin\"), map_location=\"cpu\"), strict=True)\n",
+ " except:\n",
+ " unet = None\n",
+ " return text_encoder, vae, tokenizer, unet\n",
+ "\n",
+ "def attn_of_unet(unet):\n",
+ " attn_blocks = torch.nn.ModuleList()\n",
+ " for name, param in unet.named_modules():\n",
+ " if \"attn1\" in name:\n",
+ " attn_blocks.append(param)\n",
+ " return attn_blocks\n",
+ "\n",
+ "def get_trainable_module(unet, trainable_module_name):\n",
+ " if trainable_module_name == \"unet\":\n",
+ " return unet\n",
+ " elif trainable_module_name == \"transformer\":\n",
+ " trainable_modules = torch.nn.ModuleList()\n",
+ " for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:\n",
+ " if hasattr(blocks, \"attentions\"):\n",
+ " trainable_modules.append(blocks.attentions)\n",
+ " else:\n",
+ " for block in blocks:\n",
+ " if hasattr(block, \"attentions\"):\n",
+ " trainable_modules.append(block.attentions)\n",
+ " return trainable_modules\n",
+ " elif trainable_module_name == \"attention\":\n",
+ " attn_blocks = torch.nn.ModuleList()\n",
+ " for name, param in unet.named_modules():\n",
+ " if \"attn1\" in name:\n",
+ " attn_blocks.append(param)\n",
+ " return attn_blocks\n",
+ " else:\n",
+ " raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")\n",
+ "\n",
+ " \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6028381d",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'model'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_662/1349749640.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtransformers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCLIPImageProcessor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn_processor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSkipAttnProcessor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_trainable_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_adapter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'model'"
+ ]
+ }
+ ],
  "source": [
  "import inspect\n",
  "import os\n",
@@ -475,8 +674,6 @@
  "from huggingface_hub import snapshot_download\n",
  "from transformers import CLIPImageProcessor\n",
  "\n",
- "from model.attn_processor import SkipAttnProcessor\n",
- "from model.utils import get_trainable_module, init_adapter\n",
  "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
  " prepare_mask_image, resize_and_crop, resize_and_padding)\n",
  "from ddpm import DDPMSampler\n",
@@ -755,7 +952,7 @@
  "from diffusers.image_processor import VaeImageProcessor\n",
  "from tqdm import tqdm\n",
  "from PIL import Image, ImageFilter\n",
- "import model\n",
+ "import load_model\n",
  "\n",
  "from utils import repaint, to_pil_image\n",
  " \n",
@@ -921,7 +1118,8 @@
  " \"base_model_path\": \"booksforcharlie/stable-diffusion-inpainting\",\n",
  " \"resume_path\": \"zhengchong/CatVTON\",\n",
  " \"dataset_name\": \"vitonhd\",\n",
- " \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
+ " # \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
+ " \"data_root_path\": \"/kaggle/working/stable-diffusion/sample_dataset\",\n",
  " \"output_dir\": \"./output\",\n",
  " \"seed\": 555,\n",
  " \"batch_size\": 2,\n",
@@ -936,10 +1134,11 @@
  " \"dataloader_num_workers\": 4,\n",
  " \"mixed_precision\": 'no',\n",
  " \"concat_axis\": 'y',\n",
- " \"enable_condition_noise\": True\n",
+ " \"enable_condition_noise\": True,\n",
+ " \"is_train\": False\n",
  " }\n",
  "\n",
- " models=model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weight_path=\"model.safetensors\")\n",
+ " models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"/kaggle/working/stable-diffusion/checkpoints/checkpoint_epoch_10.pth\")\n",
  "\n",
  " # Pipeline\n",
  " pipeline = CatVTONPipeline(\n",
@@ -1795,18 +1994,6 @@
  "display_name": "Python 3 (ipykernel)",
  "language": "python",
  "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.11"
  }
  },
  "nbformat": 4,
 