{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d2f98af",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls /kaggle/input/viton-hd-dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc0f36f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0bf01ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd stable-diffusion/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1401cd56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stable Diffusion v1.5 inpainting checkpoint (~4 GB).\n",
    "# -nc (no-clobber) skips the download when the file already exists, so re-running this cell\n",
    "# neither fetches it again nor leaves numbered copies such as sd-v1-5-inpainting.ckpt.1.\n",
    "!wget -nc https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7450c55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: CatVTON attention weights (vitonhd-16k-512).\n",
    "# The URL is quoted so the shell does not treat '?' as a glob, and -O saves the file directly\n",
    "# as model.safetensors, so no separate rename step is needed.\n",
    "# !wget -O model.safetensors \"https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0a1287",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f11470e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd .."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb794cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6af145b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# Flatten the cloned stable-diffusion repo into the working directory so that its modules\n",
    "# (e.g. ddpm.py, encoder.py, decoder.py) can be imported from here.\n",
    "src_dir = \"./stable-diffusion\"\n",
    "dst_dir = \".\"\n",
    "\n",
    "for filename in os.listdir(src_dir):\n",
    "    src_path = os.path.join(src_dir, filename)\n",
    "    dst_path = os.path.join(dst_dir, filename)\n",
    "    if not (os.path.isfile(src_path) or os.path.isdir(src_path)):\n",
    "        continue  # only regular files and directories are moved (e.g. broken symlinks are skipped)\n",
    "    if os.path.isdir(src_path) and os.path.exists(dst_path):\n",
    "        # shutil.move would move the directory *inside* an existing one of the same name.\n",
    "        print(f\"Skipping {src_path}: {dst_path} already exists\")\n",
    "        continue\n",
    "    # Same-named files (e.g. README.md) replace the existing copy in the working directory.\n",
    "    shutil.move(src_path, dst_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "192a649c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import sys\n",
    "import types\n",
    "\n",
    "import torch\n",
    "\n",
    "# Drop user-defined variables so the tensors/models they reference can be freed.\n",
    "# Names starting with '_' and IPython's own globals are kept: deleting `get_ipython` breaks every\n",
    "# later `!` / `%` line, because IPython rewrites those lines into get_ipython() calls.\n",
    "# Module bindings are kept too; deleting them would not unload the module (it stays in sys.modules).\n",
    "_KEEP = {\"In\", \"Out\", \"get_ipython\", \"exit\", \"quit\"}\n",
    "for _name in list(globals()):\n",
    "    if _name.startswith(\"_\") or _name in _KEEP:\n",
    "        continue\n",
    "    if isinstance(globals()[_name], types.ModuleType):\n",
    "        continue\n",
    "    del globals()[_name]\n",
    "\n",
    "# Collect first so objects held only by reference cycles are really freed, then hand the\n",
    "# now-unused cached blocks back to the driver.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Only touch TensorFlow if something already imported it; importing it here just to clear\n",
    "# its session would load the whole library for nothing.\n",
    "if \"tensorflow\" in sys.modules:\n",
    "    sys.modules[\"tensorflow\"].keras.backend.clear_session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91ef7a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "import torch\n",
    "\n",
    "# Lighter variant: keep all variables, only free unreachable objects and release cached GPU blocks.\n",
    "gc.collect()  # first, so cyclic garbage holding tensors is gone before the cache is emptied\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f29055",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device_index = torch.cuda.current_device()  # report one device consistently\n",
    "    mib = 1024 ** 2\n",
    "    used = torch.cuda.memory_allocated(device_index) / mib  # held by live tensors\n",
    "    reserved = torch.cuda.memory_reserved(device_index) / mib  # held by PyTorch's caching allocator\n",
    "    total = torch.cuda.get_device_properties(device_index).total_memory / mib\n",
    "    print(f\"GPU memory used: {used:.2f} MB (reserved: {reserved:.2f} MB) / {total:.2f} MB\")\n",
    "else:\n",
    "    print(\"CUDA is not available.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37335c1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_vae_encodings(image_tensor, encoder, device=None, generator=None, latent_channels=4):\n",
    "    \"\"\"Encode an image batch into latents with the project's VAE encoder.\n",
    "\n",
    "    The encoder is called as ``encoder(image_tensor, noise)`` with freshly sampled Gaussian\n",
    "    noise of shape (batch, latent_channels, H // 8, W // 8).\n",
    "\n",
    "    Args:\n",
    "        image_tensor: image batch indexed as (batch, channels, H, W).\n",
    "        encoder: callable taking ``(image, noise)`` and returning the latent.\n",
    "        device: device for the noise; defaults to ``image_tensor.device``.\n",
    "        generator: optional ``torch.Generator`` for reproducible noise.\n",
    "        latent_channels: latent channel count the encoder expects (4 for SD v1.x).\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``encoder`` returns for the image and noise.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = image_tensor.device\n",
    "    encoder_noise = torch.randn(\n",
    "        (image_tensor.shape[0], latent_channels, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n",
    "        generator=generator,\n",
    "        device=device,\n",
    "        # match the image dtype so fp16/bf16 pipelines do not feed float32 noise to the encoder\n",
    "        dtype=image_tensor.dtype,\n",
    "    )\n",
    "    return encoder(image_tensor, encoder_noise)"
   ]
  },
  {
   "cell_type":
"code",
   "execution_count": null,
   "id": "35d98b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trainable_module_custom_unet(unet, trainable_module_name):\n",
    "    \"\"\"Select the trainable sub-modules of a UNet with ``encoders``/``bottleneck``/``decoders``.\n",
    "\n",
    "    Presumably targets the UNet of the cloned stable-diffusion repo (diffusion.py) -- confirm.\n",
    "    It has its own name because the diffusers variant defined in a later cell is called\n",
    "    ``get_trainable_module``; with a shared name this definition was silently overwritten.\n",
    "\n",
    "    Args:\n",
    "        unet: model to inspect.\n",
    "        trainable_module_name: \"unet\" (whole model), \"transformer\" (every block's ``attentions``)\n",
    "            or \"attention\" (sub-modules whose qualified name contains ``attention_1``).\n",
    "\n",
    "    Returns:\n",
    "        ``unet`` itself or a ``torch.nn.ModuleList`` of the selected modules.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other ``trainable_module_name``.\n",
    "    \"\"\"\n",
    "    if trainable_module_name == \"unet\":\n",
    "        return unet\n",
    "    if trainable_module_name == \"transformer\":\n",
    "        trainable_modules = torch.nn.ModuleList()\n",
    "        for blocks in [unet.encoders, unet.bottleneck, unet.decoders]:\n",
    "            if hasattr(blocks, \"attentions\"):\n",
    "                trainable_modules.append(blocks.attentions)\n",
    "            else:\n",
    "                for block in blocks:\n",
    "                    if hasattr(block, \"attentions\"):\n",
    "                        trainable_modules.append(block.attentions)\n",
    "        return trainable_modules\n",
    "    if trainable_module_name == \"attention\":\n",
    "        return torch.nn.ModuleList(\n",
    "            module for name, module in unet.named_modules() if \"attention_1\" in name\n",
    "        )\n",
    "    raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ff094a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import functional as F\n",
    "import torch\n",
    "# from flash_attn import flash_attn_func\n",
    "\n",
    "class SkipAttnProcessor(torch.nn.Module):\n",
    "    \"\"\"Attention processor that returns ``hidden_states`` unchanged, i.e. disables the layer.\n",
    "\n",
    "    ``init_adapter`` installs it on cross-attention layers by default; the pipeline passes\n",
    "    ``encoder_hidden_states=None`` (no text conditioning).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "    def __call__(\n",
    "        self,\n",
    "        attn,\n",
    "        hidden_states,\n",
    "        encoder_hidden_states=None,\n",
    "        attention_mask=None,\n",
    "        temb=None,\n",
    "    ):\n",
    "        return hidden_states\n",
    "\n",
    "class AttnProcessor2_0(torch.nn.Module):\n",
    "    r\"\"\"\n",
    "    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        hidden_size=None,\n",
    "        cross_attention_dim=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__()\n",
    "        if not hasattr(F, \"scaled_dot_product_attention\"):\n",
    "            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n",
"\n",
    "    def __call__(\n",
    "        self,\n",
    "        attn,\n",
    "        hidden_states,\n",
    "        encoder_hidden_states=None,\n",
    "        attention_mask=None,\n",
    "        temb=None,\n",
    "        *args,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        \"\"\"Apply attention through ``F.scaled_dot_product_attention``.\n",
    "\n",
    "        Self-attention when ``encoder_hidden_states`` is None, cross-attention otherwise.\n",
    "        4-D ``(batch, channel, height, width)`` inputs are flattened to a token sequence and\n",
    "        reshaped back on output.\n",
    "        \"\"\"\n",
    "        residual = hidden_states\n",
    "\n",
    "        if attn.spatial_norm is not None:\n",
    "            hidden_states = attn.spatial_norm(hidden_states, temb)\n",
    "\n",
    "        input_ndim = hidden_states.ndim\n",
    "\n",
    "        if input_ndim == 4:\n",
    "            batch_size, channel, height, width = hidden_states.shape\n",
    "            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
    "\n",
    "        batch_size, sequence_length, _ = (\n",
    "            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
    "        )\n",
    "\n",
    "        if attention_mask is not None:\n",
    "            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n",
    "            # scaled_dot_product_attention expects attention_mask shape to be\n",
    "            # (batch, heads, source_length, target_length)\n",
    "            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n",
    "\n",
    "        if attn.group_norm is not None:\n",
    "            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
    "\n",
    "        query = attn.to_q(hidden_states)\n",
    "\n",
    "        if encoder_hidden_states is None:\n",
    "            encoder_hidden_states = hidden_states\n",
    "        elif attn.norm_cross:\n",
    "            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n",
    "\n",
    "        key = attn.to_k(encoder_hidden_states)\n",
    "        value = attn.to_v(encoder_hidden_states)\n",
    "\n",
    "        inner_dim = key.shape[-1]\n",
    "        head_dim = inner_dim // attn.heads\n",
    "\n",
    "        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
    "        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
    "        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
    "\n",
    "        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n",
    "        # TODO: add support for attn.scale when we move to Torch 2.1\n",
    "        hidden_states = F.scaled_dot_product_attention(\n",
    "            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n",
    "        )\n",
    "\n",
    "        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n",
    "        hidden_states = hidden_states.to(query.dtype)\n",
    "\n",
    "        # linear proj\n",
    "        hidden_states = attn.to_out[0](hidden_states)\n",
    "        # dropout\n",
    "        hidden_states = attn.to_out[1](hidden_states)\n",
    "\n",
    "        if input_ndim == 4:\n",
    "            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
    "\n",
    "        if attn.residual_connection:\n",
    "            hidden_states = hidden_states + residual\n",
    "\n",
    "        hidden_states = hidden_states / attn.rescale_output_factor\n",
    "\n",
    "        return hidden_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84a7fa87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def init_adapter(unet,\n",
    "                 cross_attn_cls=SkipAttnProcessor,\n",
    "                 self_attn_cls=None,\n",
    "                 cross_attn_dim=None,\n",
    "                 **kwargs):\n",
    "    \"\"\"Install attention processors on every attention layer of a diffusers-style UNet.\n",
    "\n",
    "    Layers whose processor name ends in ``attn1.processor`` are treated as self-attention and get\n",
    "    ``self_attn_cls`` (``AttnProcessor2_0`` when None); all others get ``cross_attn_cls``.\n",
    "    Every processor is built as ``cls(hidden_size=..., cross_attention_dim=..., **kwargs)``.\n",
    "\n",
    "    Args:\n",
    "        unet: model exposing ``attn_processors``, ``set_attn_processor`` and ``config``.\n",
    "        cross_attn_cls: processor class for cross-attention layers.\n",
    "        self_attn_cls: processor class for self-attention layers.\n",
    "        cross_attn_dim: cross-attention width; defaults to ``unet.config.cross_attention_dim``.\n",
    "\n",
    "    Returns:\n",
    "        ``torch.nn.ModuleList`` of the processors now installed on ``unet``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a processor name is not under ``down_blocks``, ``mid_block`` or ``up_blocks``.\n",
    "    \"\"\"\n",
    "    if cross_attn_dim is None:\n",
    "        cross_attn_dim = unet.config.cross_attention_dim\n",
    "    block_out_channels = unet.config.block_out_channels\n",
    "    attn_procs = {}\n",
    "    for name in unet.attn_processors.keys():\n",
    "        cross_attention_dim = None if name.endswith(\"attn1.processor\") else cross_attn_dim\n",
    "        if name.startswith(\"mid_block\"):\n",
    "            hidden_size = block_out_channels[-1]\n",
    "        elif name.startswith(\"up_blocks\"):\n",
    "            # split on '.' so multi-digit block indices (e.g. up_blocks.10) parse correctly\n",
    "            block_id = int(name.split(\".\")[1])\n",
    "            hidden_size = list(reversed(block_out_channels))[block_id]\n",
    "        elif name.startswith(\"down_blocks\"):\n",
    "            block_id = int(name.split(\".\")[1])\n",
    "            hidden_size = block_out_channels[block_id]\n",
    "        else:\n",
    "            # otherwise hidden_size would be unbound, or silently reuse the previous layer's value\n",
    "            raise ValueError(f\"Cannot infer hidden size for attention processor {name!r}\")\n",
    "        if cross_attention_dim is None:\n",
    "            proc_cls = self_attn_cls if self_attn_cls is not None else AttnProcessor2_0\n",
    "        else:\n",
    "            proc_cls = cross_attn_cls\n",
    "        attn_procs[name] = proc_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
    "\n",
    "    unet.set_attn_processor(attn_procs)\n",
    "    return torch.nn.ModuleList(unet.attn_processors.values())\n",
    "\n",
    "\n",
    "def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):\n",
    "    \"\"\"Load the CLIP text encoder, VAE and tokenizer, plus a UNet from a local folder when possible.\n",
    "\n",
    "    The UNet is built only when ``unet_class`` is given and ``<path>/unet/config.json`` and\n",
    "    ``<path>/unet/diffusion_pytorch_model.bin`` load successfully; otherwise it is returned as\n",
    "    None (with a warning explaining the failure).\n",
    "\n",
    "    Returns:\n",
    "        ``(text_encoder, vae, tokenizer, unet)``.\n",
    "    \"\"\"\n",
    "    from diffusers import AutoencoderKL\n",
    "    from transformers import CLIPTextModel, CLIPTokenizer\n",
    "\n",
    "    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder=\"text_encoder\")\n",
    "    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder=\"vae\")\n",
    "    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder=\"tokenizer\")\n",
    "\n",
    "    unet = None\n",
    "    if unet_class is not None:\n",
    "        unet_folder = os.path.join(diffusion_model_name_or_path, \"unet\")\n",
    "        try:\n",
    "            with open(os.path.join(unet_folder, \"config.json\"), \"r\") as config_file:\n",
    "                unet_configs = json.load(config_file)\n",
    "            unet = unet_class(**unet_configs)\n",
    "            # weights_only=True: a plain state dict needs no arbitrary unpickling\n",
    "            state_dict = torch.load(os.path.join(unet_folder, \"diffusion_pytorch_model.bin\"),\n",
    "                                    map_location=\"cpu\", weights_only=True)\n",
    "            unet.load_state_dict(state_dict, strict=True)\n",
    "        except Exception as exc:  # best effort, e.g. a Hub id with no local unet folder\n",
    "            warnings.warn(f\"Could not load UNet from {unet_folder}: {exc}\")\n",
    "            unet = None\n",
    "    return text_encoder, vae, tokenizer, unet\n",
    "\n",
    "\n",
    "def attn_of_unet(unet):\n",
    "    \"\"\"Return a ModuleList of every sub-module whose qualified name contains 'attn1'.\n",
    "\n",
    "    The match is on the substring, so the attention modules' children are included as well.\n",
    "    \"\"\"\n",
    "    return torch.nn.ModuleList(\n",
    "        module for name, module in unet.named_modules() if \"attn1\" in name\n",
    "    )\n",
    "\n",
    "\n",
    "def get_trainable_module(unet, trainable_module_name):\n",
    "    \"\"\"Select the trainable part of a diffusers ``UNet2DConditionModel``.\n",
    "\n",
    "    Args:\n",
    "        unet: UNet exposing ``down_blocks``, ``mid_block`` and ``up_blocks``.\n",
    "        trainable_module_name: \"unet\" (whole model), \"transformer\" (every block's ``attentions``)\n",
    "            or \"attention\" (see ``attn_of_unet``).\n",
    "\n",
    "    Returns:\n",
    "        ``unet`` itself or a ``torch.nn.ModuleList`` of the selected modules.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other ``trainable_module_name``.\n",
    "    \"\"\"\n",
    "    if trainable_module_name == \"unet\":\n",
    "        return unet\n",
    "    if trainable_module_name == \"transformer\":\n",
    "        trainable_modules = torch.nn.ModuleList()\n",
    "        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:\n",
    "            if hasattr(blocks, \"attentions\"):\n",
    "                trainable_modules.append(blocks.attentions)\n",
    "            else:\n",
    "                for block in blocks:\n",
    "                    if hasattr(block, \"attentions\"):\n",
    "                        trainable_modules.append(block.attentions)\n",
    "        return trainable_modules\n",
    "    if trainable_module_name == \"attention\":\n",
    "        return attn_of_unet(unet)\n",
    "    raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6028381d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import os\n",
    "from typing import Union\n",
    "\n",
    "import PIL.Image  # binds `PIL` and guarantees PIL.Image is loaded for the type hints below\n",
    "import numpy as np\n",
    "import torch\n",
    "import tqdm\n",
    "from accelerate import load_checkpoint_in_model\n",
    "from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel\n",
    "from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\n",
    "from diffusers.utils.torch_utils import randn_tensor\n",
    "from huggingface_hub import snapshot_download\n",
    "from transformers import CLIPImageProcessor\n",
    "\n",
    "from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
    "                   prepare_mask_image, resize_and_crop, resize_and_padding)\n",
    "from ddpm import DDPMSampler\n",
    "\n",
    "\n",
    "class CatVTONPipeline:\n",
    "    \"\"\"CatVTON try-on inference built on a diffusers UNet and the project's own VAE and sampler.\n",
    "\n",
    "    The masked person latent and the garment latent are concatenated along the height axis\n",
    "    (``concat_dim = -2``) and denoised together. The UNet's cross-attention is replaced by\n",
    "    ``SkipAttnProcessor`` and its ``attn1`` modules are loaded from the CatVTON attention checkpoint.\n",
    "    VAE encoding/decoding goes through ``models['encoder']`` / ``models['decoder']``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        base_ckpt,\n",
    "        attn_ckpt,\n",
    "        attn_ckpt_version=\"mix\",\n",
    "        weight_dtype=torch.float32,\n",
    "        device='cuda',\n",
    "        compile=False,\n",
    "        skip_safety_check=True,\n",
    "        use_tf32=True,\n",
    "        models=None,\n",
    "    ):\n",
    "        \"\"\"Load the UNet (and optional safety checker) and install the CatVTON attention weights.\n",
    "\n",
    "        Args:\n",
    "            base_ckpt: diffusers model id or path with a ``unet`` subfolder (plus\n",
    "                ``feature_extractor`` / ``safety_checker`` when the safety check is enabled).\n",
    "            attn_ckpt: local directory or Hugging Face repo id holding the CatVTON attention weights.\n",
    "            attn_ckpt_version: \"mix\", \"vitonhd\" or \"dresscode\"; selects the checkpoint sub-folder.\n",
    "            weight_dtype: dtype of the UNet, the safety checker and the prepared inputs.\n",
    "            device: torch device string.\n",
    "            compile: currently unused (torch.compile is not applied).\n",
    "            skip_safety_check: if True, the NSFW safety checker is neither loaded nor run.\n",
    "            use_tf32: allow TF32 matmuls.\n",
    "            models: dict holding the project's VAE modules under 'encoder' and 'decoder'.\n",
    "        \"\"\"\n",
    "        if models is None:\n",
    "            models = {}  # a `{}` default argument would be shared by every instance\n",
    "        self.device = device\n",
    "        self.weight_dtype = weight_dtype\n",
    "        self.skip_safety_check = skip_safety_check\n",
    "        self.models = models\n",
    "\n",
    "        self.generator = torch.Generator(device=device)\n",
    "        self.noise_scheduler = DDPMSampler(generator=self.generator)\n",
    "        self.encoder = models.get('encoder', None)\n",
    "        self.decoder = models.get('decoder', None)\n",
    "        # Always defined so run_safety_checker() sees None instead of raising AttributeError.\n",
    "        self.feature_extractor = None\n",
    "        self.safety_checker = None\n",
    "        if not skip_safety_check:\n",
    "            self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder=\"feature_extractor\")\n",
    "            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder=\"safety_checker\").to(device, dtype=weight_dtype)\n",
    "        self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder=\"unet\").to(device, dtype=weight_dtype)\n",
    "        # self.unet=models.get('diffusion', None)\n",
    "        init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor)  # Skip Cross-Attention\n",
    "        self.attn_modules = get_trainable_module(self.unet, \"attention\")\n",
    "        self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)\n",
    "\n",
    "        # Enable TF32 for faster matmuls on Ampere (and newer) GPUs.\n",
    "        if use_tf32:\n",
    "            torch.set_float32_matmul_precision(\"high\")\n",
    "            torch.backends.cuda.matmul.allow_tf32 = True\n",
    "\n",
    "    def auto_attn_ckpt_load(self, attn_ckpt, version):\n",
    "        \"\"\"Load attention weights from ``<attn_ckpt>/<sub_folder>/attention``, downloading from the Hub if needed.\"\"\"\n",
    "        sub_folder = {\n",
    "            \"mix\": \"mix-48k-1024\",\n",
    "            \"vitonhd\": \"vitonhd-16k-512\",\n",
    "            \"dresscode\": \"dresscode-16k-512\",\n",
    "        }[version]\n",
    "        if os.path.exists(attn_ckpt):\n",
    "            load_checkpoint_in_model(self.attn_modules, os.path.join(attn_ckpt, sub_folder, 'attention'))\n",
    "        else:\n",
    "            repo_path = snapshot_download(repo_id=attn_ckpt)\n",
    "            print(f\"Downloaded {attn_ckpt} to {repo_path}\")\n",
    "            load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention'))\n",
    "\n",
    "    def run_safety_checker(self, image):\n",
    "        \"\"\"Return ``(image, has_nsfw_concept)``; ``has_nsfw_concept`` is None when no checker is loaded.\"\"\"\n",
    "        if self.safety_checker is None:\n",
    "            has_nsfw_concept = None\n",
    "        else:\n",
    "            safety_checker_input = self.feature_extractor(image, return_tensors=\"pt\").to(self.device)\n",
    "            image, has_nsfw_concept = self.safety_checker(\n",
    "                images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)\n",
    "            )\n",
    "        return image, has_nsfw_concept\n",
    "\n",
    "    def prepare_extra_step_kwargs(self, generator, eta):\n",
    "        \"\"\"Build the optional ``eta`` / ``generator`` kwargs that ``noise_scheduler.step`` accepts.\"\"\"\n",
    "        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n",
    "        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n",
    "        # and should be between [0, 1]\n",
    "        step_params = set(inspect.signature(self.noise_scheduler.step).parameters.keys())\n",
    "        extra_step_kwargs = {}\n",
    "        if \"eta\" in step_params:\n",
    "            extra_step_kwargs[\"eta\"] = eta\n",
    "        if \"generator\" in step_params:\n",
    "            extra_step_kwargs[\"generator\"] = generator\n",
    "        return extra_step_kwargs\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def __call__(\n",
    "        self,\n",
    "        image: Union[PIL.Image.Image, torch.Tensor],\n",
    "        condition_image: Union[PIL.Image.Image, torch.Tensor],\n",
    "        mask: Union[PIL.Image.Image, torch.Tensor],\n",
    "        num_inference_steps: int = 50,\n",
    "        guidance_scale: float = 2.5,\n",
    "        height: int = 1024,\n",
    "        width: int = 768,\n",
    "        generator=None,\n",
    "        eta=1.0,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"Generate try-on images and return them as a list of PIL images.\n",
    "\n",
    "        Args:\n",
    "            image: person image.\n",
    "            condition_image: garment image.\n",
    "            mask: inpainting mask; where it is >= 0.5 the person image is blanked out and regenerated.\n",
    "            num_inference_steps: number of sampler steps.\n",
    "            guidance_scale: classifier-free guidance weight; values <= 1.0 disable guidance.\n",
    "            height, width: target size handed to ``check_inputs``.\n",
    "            generator: optional ``torch.Generator`` for the initial noise.\n",
    "            eta: currently unused (``prepare_extra_step_kwargs`` is not called).\n",
    "        \"\"\"\n",
    "        concat_dim = -2  # FIXME: y axis concat\n",
    "        # Prepare inputs to Tensor\n",
    "        image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)\n",
    "        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n",
    "        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n",
    "        mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n",
    "        # Mask image\n",
    "        masked_image = image * (mask < 0.5)\n",
    "        # VAE encoding\n",
    "        masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
    "        condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
    "        mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
    "        del image, mask, condition_image\n",
    "        # Concatenate latents\n",
    "        masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
    "        mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
    "        # Prepare noise\n",
    "        latents = randn_tensor(\n",
    "            masked_latent_concat.shape,\n",
    "            generator=generator,\n",
    "            device=masked_latent_concat.device,\n",
    "            dtype=self.weight_dtype,\n",
    "        )\n",
    "        # Prepare timesteps\n",
    "        self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n",
    "        timesteps = self.noise_scheduler.timesteps\n",
    "        latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n",
    "        # Classifier-free guidance: the unconditional half sees a zeroed garment latent\n",
    "        if do_classifier_free_guidance := (guidance_scale > 1.0):\n",
    "            masked_latent_concat = torch.cat(\n",
    "                [\n",
    "                    torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n",
    "                    masked_latent_concat,\n",
    "                ]\n",
    "            )\n",
    "            mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n",
    "\n",
    "        # Denoising loop (plain DDPM: no warm-up steps, so the bar advances on every step).\n",
    "        # `tqdm` is the module imported above, so the progress-bar class is tqdm.tqdm.\n",
    "        with tqdm.tqdm(total=len(timesteps)) as progress_bar:\n",
    "            for t in timesteps:\n",
    "                # expand the latents if we are doing classifier free guidance\n",
    "                non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n",
    "                # prepare the input for the inpainting model\n",
    "                inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n",
    "                # predict the noise residual\n",
    "                noise_pred = self.unet(\n",
    "                    inpainting_latent_model_input,\n",
    "                    t.to(self.device),\n",
    "                    encoder_hidden_states=None,  # FIXME\n",
    "                    return_dict=False,\n",
    "                )[0]\n",
    "                # perform guidance\n",
    "                if do_classifier_free_guidance:\n",
    "                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
    "                    noise_pred = noise_pred_uncond + guidance_scale * (\n",
    "                        noise_pred_text - noise_pred_uncond\n",
    "                    )\n",
    "                # compute the previous noisy sample x_t -> x_t-1\n",
    "                latents = self.noise_scheduler.step(\n",
    "                    t, latents, noise_pred\n",
    "                )\n",
    "                progress_bar.update()\n",
    "\n",
    "        # Decode the final latents (keep the person half of the height-wise concatenation)\n",
    "        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n",
    "        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n",
    "        image = (image / 2 + 0.5).clamp(0, 1)\n",
    "        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n",
    "        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n",
    "        image = numpy_to_pil(image)\n",
    "\n",
    "        # Safety Check\n",
    "        if not self.skip_safety_check:\n",
    "            # __file__ is not defined inside a notebook kernel; fall back to the working directory,\n",
    "            # assumed to be the CatVTON repo root that contains resource/img/NSFW.jpg.\n",
    "            script_file = globals().get(\"__file__\")\n",
    "            if script_file is not None:\n",
    "                repo_root = os.path.dirname(os.path.dirname(os.path.realpath(script_file)))\n",
    "            else:\n",
    "                repo_root = os.getcwd()\n",
    "            nsfw_image = os.path.join(repo_root, 'resource', 'img', 'NSFW.jpg')\n",
    "            nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)\n",
    "            image_np = np.array(image)\n",
    "            _, has_nsfw_concept = self.run_safety_checker(image=image_np)\n",
    "            for i, not_safe in enumerate(has_nsfw_concept):\n",
    "                if not_safe:\n",
    "                    image[i] = nsfw_image\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e19198",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An error occurred while trying to fetch booksforcharlie/stable-diffusion-inpainting: 
booksforcharlie/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.\n", "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "912125b29fef4b31aff0e4433b03b876", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 12 files: 0%| | 0/12 [00:00\u001b[0;34m()\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipykernel_72/4184774867.py\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 303\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mask'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 305\u001b[0;31m results = pipeline(\n\u001b[0m\u001b[1;32m 306\u001b[0m \u001b[0mperson_images\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0mcloth_images\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/tmp/ipykernel_72/4282996458.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0minpainting_latent_model_input\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# time_embedding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0mencoder_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# FIXME\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import os\n", "import numpy as np\n", "import torch\n", "import argparse\n", "from torch.utils.data import Dataset, DataLoader\n", "from VITON_Dataset import VITONHDTestDataset\n", "from diffusers.image_processor import VaeImageProcessor\n", "from tqdm import tqdm\n", "from PIL import Image, ImageFilter\n", "import load_model\n", "\n", "from utils import repaint, to_pil_image\n", " \n", "def parse_args():\n", " parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n", " parser.add_argument(\n", " \"--base_model_path\",\n", " type=str,\n", " default=\"booksforcharlie/stable-diffusion-inpainting\", # Change to a copy repo as runawayml delete original repo\n", " help=(\n", " \"The path to the base model to use for evaluation. 
This can be a local path or a model identifier from the Model Hub.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--resume_path\",\n", " type=str,\n", " default=\"zhengchong/CatVTON\",\n", " help=(\n", " \"The Path to the checkpoint of trained tryon model.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--dataset_name\",\n", " type=str,\n", " required=True,\n", " help=\"The datasets to use for evaluation.\",\n", " )\n", " parser.add_argument(\n", " \"--data_root_path\", \n", " type=str, \n", " required=True,\n", " help=\"Path to the dataset to evaluate.\"\n", " )\n", " parser.add_argument(\n", " \"--output_dir\",\n", " type=str,\n", " default=\"output\",\n", " help=\"The output directory where the model predictions will be written.\",\n", " )\n", "\n", " parser.add_argument(\n", " \"--seed\", type=int, default=555, help=\"A seed for reproducible evaluation.\"\n", " )\n", " parser.add_argument(\n", " \"--batch_size\", type=int, default=8, help=\"The batch size for evaluation.\"\n", " )\n", " \n", " parser.add_argument(\n", " \"--num_inference_steps\",\n", " type=int,\n", " default=50,\n", " help=\"Number of inference steps to perform.\",\n", " )\n", " parser.add_argument(\n", " \"--guidance_scale\",\n", " type=float,\n", " default=2.5,\n", " help=\"The scale of classifier-free guidance for inference.\",\n", " )\n", "\n", " parser.add_argument(\n", " \"--width\",\n", " type=int,\n", " default=384,\n", " help=(\n", " \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n", " \" resolution\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--height\",\n", " type=int,\n", " default=512,\n", " help=(\n", " \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n", " \" resolution\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--repaint\", \n", " action=\"store_true\", \n", " help=\"Whether to repaint the result image with the original 
background.\"\n", " )\n", " parser.add_argument(\n", " \"--eval_pair\",\n", " action=\"store_true\",\n", " help=\"Whether or not to evaluate the pair.\",\n", " )\n", " parser.add_argument(\n", " \"--concat_eval_results\",\n", " action=\"store_true\",\n", " help=\"Whether or not to concatenate the all conditions into one image.\",\n", " )\n", " parser.add_argument(\n", " \"--allow_tf32\",\n", " action=\"store_true\",\n", " default=True,\n", " help=(\n", " \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n", " \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--dataloader_num_workers\",\n", " type=int,\n", " default=8,\n", " help=(\n", " \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n", " ),\n", " )\n", " parser.add_argument(\n", " \"--mixed_precision\",\n", " type=str,\n", " default=\"bf16\",\n", " choices=[\"no\", \"fp16\", \"bf16\"],\n", " help=(\n", " \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n", " \" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n", " \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n", " ),\n", " )\n", "\n", " parser.add_argument(\n", " \"--concat_axis\",\n", " type=str,\n", " choices=[\"x\", \"y\", 'random'],\n", " default=\"y\",\n", " help=\"The axis to concat the cloth feature, select from ['x', 'y', 'random'].\",\n", " )\n", " parser.add_argument(\n", " \"--enable_condition_noise\",\n", " action=\"store_true\",\n", " default=True,\n", " help=\"Whether or not to enable condition noise.\",\n", " )\n", " \n", " args = parser.parse_args()\n", " env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n", " if env_local_rank != -1 and env_local_rank != args.local_rank:\n", " args.local_rank = env_local_rank\n", "\n", " return args\n", "\n", "@torch.no_grad()\n", "def main():\n", " # args = parse_args()\n", "\n", " # Replace with your actual data root and output directory paths\n", " # !CUDA_VISIBLE_DEVICES=0 python inference.py \\\n", " # --dataset vitonhd \\\n", " # --data_root_path /kaggle/input/viton-hd-dataset \\\n", " # --output_dir ./output \\\n", " # --dataloader_num_workers 8 \\\n", " # --batch_size 8 \\\n", " # --seed 555 \\\n", " # --mixed_precision no \\\n", " # --allow_tf32 \\\n", " # --repaint \\\n", " # --eval_pair\n", " \n", " args=argparse.Namespace()\n", " args.__dict__= {\n", " \"base_model_path\": \"booksforcharlie/stable-diffusion-inpainting\",\n", " \"resume_path\": \"zhengchong/CatVTON\",\n", " \"dataset_name\": \"vitonhd\",\n", " # \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n", " \"data_root_path\": \"/kaggle/working/stable-diffusion/sample_dataset\",\n", " \"output_dir\": \"./output\",\n", " \"seed\": 555,\n", " \"batch_size\": 2,\n", " \"num_inference_steps\": 50,\n", " \"guidance_scale\": 2.5,\n", " \"width\": 384,\n", " \"height\": 512,\n", " \"repaint\": True,\n", " \"eval_pair\": False,\n", " \"concat_eval_results\": True,\n", " \"allow_tf32\": True,\n", " \"dataloader_num_workers\": 4,\n", " \"mixed_precision\": 'no',\n", " \"concat_axis\": 
'y',\n", " \"enable_condition_noise\": True,\n", " \"is_train\": False\n", " }\n", "\n", " models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"/kaggle/working/stable-diffusion/checkpoints/checkpoint_epoch_10.pth\")\n", "\n", " # Pipeline\n", " pipeline = CatVTONPipeline(\n", " attn_ckpt_version=args.dataset_name,\n", " attn_ckpt=args.resume_path,\n", " base_ckpt=args.base_model_path,\n", " weight_dtype={\n", " \"no\": torch.float32,\n", " \"fp16\": torch.float16,\n", " \"bf16\": torch.bfloat16,\n", " }[args.mixed_precision],\n", " device=\"cuda\",\n", " skip_safety_check=True,\n", " models=models,\n", " )\n", " # Dataset\n", " if args.dataset_name == \"vitonhd\":\n", " dataset = VITONHDTestDataset(args)\n", " else:\n", " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n", " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n", " dataloader = DataLoader(\n", " dataset,\n", " batch_size=args.batch_size,\n", " shuffle=False,\n", " num_workers=args.dataloader_num_workers\n", " )\n", " # Inference\n", " generator = torch.Generator(device='cuda').manual_seed(args.seed)\n", " args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n", " if not os.path.exists(args.output_dir):\n", " os.makedirs(args.output_dir)\n", " \n", " for batch in tqdm(dataloader):\n", " person_images = batch['person']\n", " cloth_images = batch['cloth']\n", " masks = batch['mask']\n", "\n", " results = pipeline(\n", " person_images,\n", " cloth_images,\n", " masks,\n", " num_inference_steps=args.num_inference_steps,\n", " guidance_scale=args.guidance_scale,\n", " height=args.height,\n", " width=args.width,\n", " generator=generator,\n", " )\n", " \n", " if args.concat_eval_results or args.repaint:\n", " person_images = to_pil_image(person_images)\n", " cloth_images = to_pil_image(cloth_images)\n", " 
masks = to_pil_image(masks)\n", " for i, result in enumerate(results):\n", " person_name = batch['person_name'][i]\n", " output_path = os.path.join(args.output_dir, person_name)\n", " if not os.path.exists(os.path.dirname(output_path)):\n", " os.makedirs(os.path.dirname(output_path))\n", " if args.repaint:\n", " person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n", " person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n", " mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n", " result = repaint(person_image, mask, result)\n", " if args.concat_eval_results:\n", " w, h = result.size\n", " concated_result = Image.new('RGB', (w*3, h))\n", " concated_result.paste(person_images[i], (0, 0))\n", " concated_result.paste(cloth_images[i], (w, 0)) \n", " concated_result.paste(result, (w*2, 0))\n", " result = concated_result\n", " result.save(output_path)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "code", "execution_count": null, "id": "5c2d9f98", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "143d0ef9", "metadata": {}, "outputs": [], "source": [ "# rm -rf output" ] }, { "cell_type": "code", "execution_count": 37, "id": "77c56140", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "05006_00.jpg 11001_00.jpg\n" ] } ], "source": [ "import sys\n", "f='/kaggle/input/viton-hd-dataset/test_pairs.txt'\n", "with open(f, 'r') as file:\n", " lines = file.readlines()\n", "person_img, cloth_img = lines[0].strip().split(\" \")\n", "mask_img = person_img\n", "\n", "print(person_img, cloth_img)" ] }, { "cell_type": "code", "execution_count": 38, "id": "0fdf30ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "app_flux.py eval.py preprocess_agnostic_mask.py\n", "app_p2p.py index.html \u001b[0m\u001b[01;34m__pycache__\u001b[0m/\n", "app.py 
inference.py README.md\n", "attention.py interface.py requirements.txt\n", "clip.py LICENSE \u001b[01;34mresource\u001b[0m/\n", "ddpm.py merges.txt \u001b[01;34msample_dataset\u001b[0m/\n", "decoder.py \u001b[01;34mmodel\u001b[0m/ sd-v1-5-inpainting.ckpt\n", "\u001b[01;34mdensepose\u001b[0m/ model_converter.py \u001b[01;34mstable-diffusion\u001b[0m/\n", "\u001b[01;34mdetectron2\u001b[0m/ model.safetensors test.ipynb\n", "diffusion.py \u001b[01;34moutput\u001b[0m/ utils.py\n", "encoder.py pipeline.py vocab.json\n" ] } ], "source": [ "ls" ] }, { "cell_type": "code", "execution_count": null, "id": "d4063d0b", "metadata": {}, "outputs": [], "source": [ "# rm -rf output" ] }, { "cell_type": "code", "execution_count": 97, "id": "52e3cd56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "02532_00.jpg 03921_00.jpg 08088_00.jpg 12419_00.jpg\n", "03191_00.jpg 05006_00.jpg 08650_00.jpg 12562_00.jpg\n" ] } ], "source": [ "ls ./output/vitonhd-512/unpaired" ] }, { "cell_type": "code", "execution_count": 98, "id": "ac7340f8", "metadata": {}, "outputs": [ { "ename": "KeyError", "evalue": "'_oh'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_71/1176057974.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyError\u001b[0m: '_oh'" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from PIL import Image\n", "img=Image.open('./output/vitonhd-512/unpaired/12562_00.jpg')\n", "\n", "import matplotlib.pyplot as plt\n", "plt.imshow(img)" ] }, { "cell_type": "code", "execution_count": 25, "id": "86b70586", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-0.5, 767.5, 1023.5, -0.5)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from PIL import Image\n", "\n", "person_image = Image.open(f\"/kaggle/input/viton-hd-dataset/test/image/{person_img}\").convert(\"RGB\")\n", "cloth_image = Image.open(f\"/kaggle/input/viton-hd-dataset/test/cloth/{cloth_img}\").convert(\"RGB\")\n", "mask_image = Image.open(f\"/kaggle/input/viton-hd-dataset/test/agnostic-mask/{mask_img.replace('.jpg', '_mask.png')}\").convert(\"L\")\n", "\n", "import matplotlib.pyplot as plt\n", "plt.figure(figsize=(12, 4))\n", "plt.subplot(1, 3, 1)\n", "plt.imshow(person_image)\n", "plt.title(\"Person Image\")\n", "plt.axis('off')\n", "\n", "plt.subplot(1, 3, 2)\n", "plt.imshow(cloth_image)\n", "plt.title(\"Cloth Image\")\n", "plt.axis('off')\n", "plt.subplot(1, 3, 3)\n", "plt.imshow(mask_image, cmap='gray')\n", "plt.title(\"Mask Image\")\n", "plt.axis('off')" ] }, { "cell_type": "code", "execution_count": null, "id": "826427b6", "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-06-11 03:13:29.903172: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1749611610.125212 127 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1749611610.187151 127 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "scheduler_config.json: 100%|███████████████████| 313/313 [00:00<00:00, 1.89MB/s]\n", "config.json: 100%|█████████████████████████████| 547/547 [00:00<00:00, 4.50MB/s]\n", "diffusion_pytorch_model.safetensors: 100%|████| 335M/335M [00:01<00:00, 250MB/s]\n", 
"config.json: 100%|█████████████████████████████| 748/748 [00:00<00:00, 4.77MB/s]\n", "An error occurred while trying to fetch booksforcharlie/stable-diffusion-inpainting: booksforcharlie/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.\n", "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n", "diffusion_pytorch_model.bin: 100%|██████████| 3.44G/3.44G [00:13<00:00, 251MB/s]\n", "Fetching 12 files: 0%| | 0/12 [00:00\n", " main()\n", " File \"/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py\", line 116, in decorate_context\n", " return func(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/kaggle/working/CatVTON/inference.py\", line 269, in main\n", " dataset = VITONHDTestDataset(args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/kaggle/working/CatVTON/inference.py\", line 18, in __init__\n", " self.data = self.load_data()\n", " ^^^^^^^^^^^^^^^^\n", " File \"/kaggle/working/CatVTON/inference.py\", line 39, in load_data\n", " assert os.path.exists(pair_txt:=os.path.join(self.args.data_root_path, 'test_pairs_unpaired.txt')), f\"File {pair_txt} does not exist.\"\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "AssertionError: File /kaggle/input/viton-hd-dataset/test_pairs_unpaired.txt does not exist.\n" ] } ], "source": [ "# # Replace with your actual data root and output directory paths\n", "# !CUDA_VISIBLE_DEVICES=0 python inference.py \\\n", "# --dataset vitonhd \\\n", "# --data_root_path /kaggle/input/viton-hd-dataset \\\n", "# --output_dir ./output \\\n", "# --dataloader_num_workers 8 \\\n", "# --batch_size 8 \\\n", "# --seed 555 \\\n", "# --mixed_precision no \\\n", "# --allow_tf32 \\\n", "# --repaint \\\n", "# --eval_pair" ] }, { "cell_type": "code", "execution_count": null, "id": "e417edb7", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": 
"1c86c58d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }