Teatime666 committed on
Commit 823e49a · verified · 1 Parent(s): 7cf1a79

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. LICENSE +202 -0
  2. README.md +119 -0
  3. dataset2json.py +63 -0
  4. extract_frame.py +38 -0
  5. model_structures.log +1497 -0
  6. myoutput.log +2 -0
  7. nohup.out +0 -0
  8. output.log +2 -0
  9. output/20241207/1929--seed_42-384x512/upper1_00057_00_512x384_3_1929.mp4 +0 -0
  10. output/20241207/2241--seed_42-384x512/3_s_1110342_in_xl_512x384_3_2241.mp4 +0 -0
  11. output/20241207/2241--seed_42-384x512/7_s_1110342_in_xl_512x384_3_2241.mp4 +0 -0
  12. output/20241207/2241--seed_42-384x512/8_s_1009794_in_xl_512x384_3_2241.mp4 +0 -0
  13. output/20241207/2241--seed_42-384x512/8_s_1110342_in_xl_512x384_3_2241.mp4 +0 -0
  14. read.py +39 -0
  15. requirements.txt +29 -0
  16. scripts.sh +7 -0
  17. stage1_nohup.out +0 -0
  18. train_stage_1.py +781 -0
  19. train_stage_2.py +842 -0
  20. vivid.py +229 -0
  21. vividfuxian_motion/20241211/1715/803128_detail_1060638_in_xl.mp4 +0 -0
  22. vividfuxian_motion/20241212/1437/000004-803128_detail_1060638_in_xl.mp4 +0 -0
  23. vividfuxian_motion/20241212/1506/000200-803128_detail_1060638_in_xl.mp4 +0 -0
  24. vividfuxian_motion/20241212/1629/000600-803128_detail_1060638_in_xl.mp4 +0 -0
  25. vividfuxian_valid/stage1/000010-803137_in_xl_812294_in_xl.jpg +0 -0
  26. vividfuxian_valid/stage1/000200-803137_in_xl_812294_in_xl.jpg +0 -0
  27. vividfuxian_valid/stage1/000400-803137_in_xl_812294_in_xl.jpg +0 -0
  28. vividfuxian_valid/stage1/000600-803137_in_xl_812294_in_xl.jpg +0 -0
  29. vividfuxian_valid/stage1/000800-803137_in_xl_812294_in_xl.jpg +0 -0
  30. vividfuxian_valid/stage1/001000-803137_in_xl_812294_in_xl.jpg +0 -0
  31. vividfuxian_valid/stage1/001200-803137_in_xl_812294_in_xl.jpg +0 -0
  32. vividfuxian_valid/stage1/001600-803137_in_xl_812294_in_xl.jpg +0 -0
  33. vividfuxian_valid/stage1/001800-803137_in_xl_812294_in_xl.jpg +0 -0
  34. vividfuxian_valid/stage1/002000-803137_in_xl_812294_in_xl.jpg +0 -0
  35. vividfuxian_valid/stage1/002200-803137_in_xl_812294_in_xl.jpg +0 -0
  36. vividfuxian_valid/stage1/002400-803137_in_xl_812294_in_xl.jpg +0 -0
  37. vividfuxian_valid/stage1/002600-803137_in_xl_812294_in_xl.jpg +0 -0
  38. vividfuxian_valid/stage1/002800-803137_in_xl_812294_in_xl.jpg +0 -0
  39. vividfuxian_valid/stage1/003000-803137_in_xl_812294_in_xl.jpg +0 -0
  40. vividfuxian_valid/stage1/003400-803137_in_xl_812294_in_xl.jpg +0 -0
  41. vividfuxian_valid/stage1/003600-803137_in_xl_812294_in_xl.jpg +0 -0
  42. vividfuxian_valid/stage1/003800-803137_in_xl_812294_in_xl.jpg +0 -0
  43. vividfuxian_valid/stage1/004200-803137_in_xl_812294_in_xl.jpg +0 -0
  44. vividfuxian_valid/stage1/004400-803137_in_xl_812294_in_xl.jpg +0 -0
  45. vividfuxian_valid/stage1/004600-803137_in_xl_812294_in_xl.jpg +0 -0
  46. vividfuxian_valid/stage1/004800-803137_in_xl_812294_in_xl.jpg +0 -0
  47. vividfuxian_valid/stage1/005200-803137_in_xl_812294_in_xl.jpg +0 -0
  48. vividfuxian_valid/stage1/005400-803137_in_xl_812294_in_xl.jpg +0 -0
  49. vividfuxian_valid/stage1/005600-803137_in_xl_812294_in_xl.jpg +0 -0
  50. vividfuxian_valid/stage1/005800-803137_in_xl_812294_in_xl.jpg +0 -0
LICENSE ADDED
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,119 @@
+ # ViViD
+ ViViD: Video Virtual Try-on using Diffusion Models
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2405.11794-b31b1b.svg)](https://arxiv.org/abs/2405.11794)
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://alibaba-yuanjing-aigclab.github.io/ViViD)
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/alibaba-yuanjing-aigclab/ViViD)
+
+ ## Dataset
+ Dataset released: [ViViD](https://huggingface.co/datasets/alibaba-yuanjing-aigclab/ViViD)
+
+ ## Installation
+
+ ```
+ git clone https://github.com/alibaba-yuanjing-aigclab/ViViD
+ cd ViViD
+ ```
+
+ ### Environment
+ ```
+ conda create -n vivid python=3.10
+ conda activate vivid
+ conda activate /mnt/pfs-mc0p4k/ssai/cvg/team/envs/vivid
+ pip install -r requirements.txt
+ ```
+
+ ### Weights
+ You can place the weights anywhere you like, for example ```./ckpts```. If you put them somewhere else, you just need to update the paths in ```./configs/prompts/*.yaml```.
+
+ #### Stable Diffusion Image Variations
+ ```
+ cd ckpts
+
+ git lfs install
+ git clone https://huggingface.co/lambdalabs/sd-image-variations-diffusers
+ ```
+ #### SD-VAE-ft-mse
+ ```
+ git lfs install
+ git clone https://huggingface.co/stabilityai/sd-vae-ft-mse
+ ```
+ #### Motion Module
+ Download [mm_sd_v15_v2](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt)
+
+ #### ViViD
+ ```
+ git lfs install
+ git clone https://huggingface.co/alibaba-yuanjing-aigclab/ViViD
+ ```
+ ## Inference
+ We provide two demos in ```./configs/prompts/```; run the following commands to try them out 😼.
+
+ ```
+ python vivid.py --config ./configs/prompts/upper1.yaml
+
+ python vivid.py --config ./configs/prompts/lower1.yaml
+ ```
+
+ ## Data
+ As illustrated in ```./data```, the following data should be provided.
+ ```text
+ ./data/
+ |-- agnostic
+ |   |-- video1.mp4
+ |   |-- video2.mp4
+ |   ...
+ |-- agnostic_mask
+ |   |-- video1.mp4
+ |   |-- video2.mp4
+ |   ...
+ |-- cloth
+ |   |-- cloth1.jpg
+ |   |-- cloth2.jpg
+ |   ...
+ |-- cloth_mask
+ |   |-- cloth1.jpg
+ |   |-- cloth2.jpg
+ |   ...
+ |-- densepose
+ |   |-- video1.mp4
+ |   |-- video2.mp4
+ |   ...
+ |-- videos
+ |   |-- video1.mp4
+ |   |-- video2.mp4
+ |   ...
+ ```
+
+ ### Agnostic and agnostic_mask video
+ This part is a bit more involved; you can obtain them in any of the following three ways:
+ 1. Follow [OOTDiffusion](https://github.com/levihsu/OOTDiffusion) to extract them frame by frame (recommended).
+ 2. Use [SAM](https://github.com/facebookresearch/segment-anything) + Gaussian blur (see ```./tools/sam_agnostic.py``` for an example).
+ 3. Use a mask editor tool.
+
+ Note that the shape and size of the agnostic area may affect the try-on results.
+
+ ### Densepose video
+ See [vid2densepose](https://github.com/Flode-Labs/vid2densepose) (thanks to its authors).
+
+ ### Cloth mask
+ Any segmentation tool, such as [SAM](https://github.com/facebookresearch/segment-anything), is fine for obtaining the mask.
+
+ ## BibTeX
+ ```text
+ @misc{fang2024vivid,
+     title={ViViD: Video Virtual Try-on using Diffusion Models},
+     author={Zixun Fang and Wei Zhai and Aimin Su and Hongliang Song and Kai Zhu and Mao Wang and Yu Chen and Zhiheng Liu and Yang Cao and Zheng-Jun Zha},
+     year={2024},
+     eprint={2405.11794},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV}
+ }
+ ```
+
+ ## Contact Us
+ **Zixun Fang**: [[email protected]](mailto:[email protected])
+ **Yu Chen**: [[email protected]](mailto:[email protected])
+
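
The ```./tools/sam_agnostic.py``` script referenced in the README's "Agnostic and agnostic_mask video" section is not part of this commit. As a rough, non-authoritative sketch of the "SAM + Gaussian blur" route it describes, the snippet below blurs the garment region of each frame, given a per-frame binary mask video (which you would obtain from SAM or any other segmenter). The file names, the kernel size, and the blur-based compositing are illustrative assumptions, not the repository's actual preprocessing.

```python
import cv2
import numpy as np

def make_agnostic_video(video_path, mask_path, output_path, ksize=51):
    """Blur the garment region of every frame, given a per-frame binary
    mask video, and save the result as the 'agnostic' video.
    All paths and the kernel size are illustrative placeholders."""
    cap_video = cv2.VideoCapture(video_path)
    cap_mask = cv2.VideoCapture(mask_path)
    fps = cap_video.get(cv2.CAP_PROP_FPS)
    width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(output_path,
                             cv2.VideoWriter_fourcc(*"mp4v"),
                             fps, (width, height))
    while True:
        ok_v, frame = cap_video.read()
        ok_m, mask = cap_mask.read()
        if not (ok_v and ok_m):
            break
        # Binarize the mask and heavily blur the masked region so the
        # original garment is no longer recognizable.
        gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        blurred = cv2.GaussianBlur(frame, (ksize, ksize), 0)
        binary_3c = cv2.merge([binary, binary, binary])
        agnostic = np.where(binary_3c > 0, blurred, frame)
        writer.write(agnostic)
    cap_video.release()
    cap_mask.release()
    writer.release()

if __name__ == "__main__":
    # Hypothetical file names following the ./data layout in the README.
    make_agnostic_video("data/videos/video1.mp4",
                        "data/agnostic_mask/video1.mp4",
                        "data/agnostic/video1.mp4")
```

As the README notes, the shape and size of the agnostic area affect the try-on results, so the OOTDiffusion-style frame-by-frame extraction remains the recommended route; this sketch only illustrates the masking idea.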
dataset2json.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import json
+
+ def collect_files(data_dir):
+     """
+     Walk the subfolders of the data directory and group file paths into a dict keyed by the first 7 characters of each file name.
+     """
+     file_dict = {}
+
+     # Subfolders to scan
+     subfolders = ['densepose', 'videos', 'cloth', 'cloth_mask', 'agnostic_mask', 'agnostic']
+
+     for subfolder in subfolders:
+         subfolder_path = os.path.join(data_dir, subfolder)
+
+         if not os.path.exists(subfolder_path):
+             print(f"Warning: path {subfolder_path} does not exist")
+             continue
+
+         # Iterate over the files in this subfolder
+         for file_name in os.listdir(subfolder_path):
+             # Use only the first 7 characters of the file name for matching
+             key = file_name[:7]
+             if key not in file_dict:
+                 # Initialize the entry for this 7-character key
+                 file_dict[key] = {}
+
+             # Store the file path under the subfolder's name
+             file_dict[key][subfolder] = os.path.join(subfolder_path, file_name)
+
+     return file_dict
+
+ def generate_json(data_dir, output_file):
+     """
+     Write the file-matching results to a JSON file.
+     """
+     files = collect_files(data_dir)
+     result = []
+
+     # Build a list of records in the expected format
+     for key, paths in files.items():
+         result.append({
+             "densepose": paths.get("densepose", ""),  # missing fields fall back to an empty string
+             "videos": paths.get("videos", ""),
+             "cloth": paths.get("cloth", ""),
+             "cloth_mask": paths.get("cloth_mask", ""),
+             "agnostic_mask": paths.get("agnostic_mask", ""),
+             "agnostic": paths.get("agnostic", "")
+         })
+
+     # Write the result to the given JSON path
+     with open(output_file, "w", encoding="utf-8") as f:
+         json.dump(result, f, indent=4, ensure_ascii=False)
+
+     print(f"JSON file written to: {output_file}")
+
+ if __name__ == "__main__":
+     # Data directory to scan
+     data_dir = "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/data"
+     # Output JSON file path
+     output_file = "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/data/vividfuxian_stage1.json"
+
+     generate_json(data_dir, output_file)
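
A note on dataset2json.py: the matching relies on the first 7 characters of each file name being identical across the six subfolders, and any modality that is not found is silently written out as an empty string. A minimal sanity check of the generated JSON, using the output path hard-coded above, might look like this (a sketch, not part of the commit):

```python
import json

# Path matches the output_file hard-coded in dataset2json.py above.
json_path = "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/data/vividfuxian_stage1.json"

with open(json_path, encoding="utf-8") as f:
    entries = json.load(f)

# Report entries where one of the six modalities was not found, since
# collect_files() fills missing fields with an empty string.
for i, entry in enumerate(entries):
    missing = [key for key, value in entry.items() if value == ""]
    if missing:
        print(f"entry {i}: missing {missing}")

print(f"{len(entries)} entries checked")
```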
extract_frame.py ADDED
@@ -0,0 +1,38 @@
+ import cv2
+ import os
+
+ def extract_frame(video_path, frame_number, output_path):
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print(f"Could not open video file: {video_path}")
+         return
+
+     # Seek the capture to the requested frame
+     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+
+     # Read the requested frame
+     success, frame = cap.read()
+
+     if success:
+         # Save the frame to the given output path
+         cv2.imwrite(output_path, frame)
+         print(f"Extracted frame {frame_number} and saved it to {output_path}")
+     else:
+         print(f"Failed to read frame {frame_number}. Check whether the frame number is out of range.")
+
+     # Release the capture
+     cap.release()
+
+ if __name__ == "__main__":
+     video_file = "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/densepose/803137_detail.mp4"  # replace with your MP4 file path
+     frame_to_extract = 24  # frame number to extract
+     output_file = "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/valid/densepose_images/803137_in_xl.jpg"  # replace with the path you want to save to
+
+     # Create the output directory if it does not exist
+     output_dir = os.path.dirname(output_file)
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     extract_frame(video_file, frame_to_extract, output_file)
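
If the same frame index needs to be pulled from every densepose video (for example, to build validation images like those under vividfuxian_valid/), the function above can be reused in a small loop. This is only a sketch: the directory names are placeholders, and it assumes extract_frame.py is importable from the working directory.

```python
import os
from extract_frame import extract_frame  # the function defined above

# Hypothetical directories; adjust to your own layout.
densepose_dir = "dataset/ViViD/dresses/densepose"
output_dir = "configs/valid/densepose_images"
frame_number = 24  # same frame index as the example in __main__

os.makedirs(output_dir, exist_ok=True)

for name in sorted(os.listdir(densepose_dir)):
    if not name.endswith(".mp4"):
        continue
    video_path = os.path.join(densepose_dir, name)
    output_path = os.path.join(output_dir, os.path.splitext(name)[0] + ".jpg")
    extract_frame(video_path, frame_number, output_path)
```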
model_structures.log ADDED
@@ -0,0 +1,1497 @@
1
+ Denoising UNet structure:
2
+ UNet3DConditionModel(
3
+ (conv_in): InflatedConv3d(9, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4
+ (time_proj): Timesteps()
5
+ (time_embedding): TimestepEmbedding(
6
+ (linear_1): LoRACompatibleLinear(in_features=320, out_features=1280, bias=True)
7
+ (act): SiLU()
8
+ (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
9
+ )
10
+ (down_blocks): ModuleList(
11
+ (0): CrossAttnDownBlock3D(
12
+ (attentions): ModuleList(
13
+ (0-1): 2 x Transformer3DModel(
14
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
15
+ (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
16
+ (transformer_blocks): ModuleList(
17
+ (0): TemporalBasicTransformerBlock(
18
+ (attn1): Attention(
19
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
20
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
21
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
22
+ (to_out): ModuleList(
23
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
24
+ (1): Dropout(p=0.0, inplace=False)
25
+ )
26
+ )
27
+ (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
28
+ (attn2): Attention(
29
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
30
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
31
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
32
+ (to_out): ModuleList(
33
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
34
+ (1): Dropout(p=0.0, inplace=False)
35
+ )
36
+ )
37
+ (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
38
+ (ff): FeedForward(
39
+ (net): ModuleList(
40
+ (0): GEGLU(
41
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
42
+ )
43
+ (1): Dropout(p=0.0, inplace=False)
44
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
45
+ )
46
+ )
47
+ (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
48
+ )
49
+ )
50
+ (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
51
+ )
52
+ )
53
+ (resnets): ModuleList(
54
+ (0-1): 2 x ResnetBlock3D(
55
+ (norm1): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
56
+ (conv1): InflatedConv3d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
57
+ (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
58
+ (norm2): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
59
+ (dropout): Dropout(p=0.0, inplace=False)
60
+ (conv2): InflatedConv3d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
61
+ (nonlinearity): SiLU()
62
+ )
63
+ )
64
+ (motion_modules): ModuleList(
65
+ (0-1): 2 x VanillaTemporalModule(
66
+ (temporal_transformer): TemporalTransformer3DModel(
67
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
68
+ (proj_in): Linear(in_features=320, out_features=320, bias=True)
69
+ (transformer_blocks): ModuleList(
70
+ (0): TemporalTransformerBlock(
71
+ (attention_blocks): ModuleList(
72
+ (0-1): 2 x VersatileAttention(
73
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
74
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
75
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
76
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
77
+ (to_out): ModuleList(
78
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
79
+ (1): Dropout(p=0.0, inplace=False)
80
+ )
81
+ (pos_encoder): PositionalEncoding(
82
+ (dropout): Dropout(p=0.0, inplace=False)
83
+ )
84
+ )
85
+ )
86
+ (norms): ModuleList(
87
+ (0-1): 2 x LayerNorm((320,), eps=1e-05, elementwise_affine=True)
88
+ )
89
+ (ff): FeedForward(
90
+ (net): ModuleList(
91
+ (0): GEGLU(
92
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
93
+ )
94
+ (1): Dropout(p=0.0, inplace=False)
95
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
96
+ )
97
+ )
98
+ (ff_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
99
+ )
100
+ )
101
+ (proj_out): Linear(in_features=320, out_features=320, bias=True)
102
+ )
103
+ )
104
+ )
105
+ (downsamplers): ModuleList(
106
+ (0): Downsample3D(
107
+ (conv): InflatedConv3d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
108
+ )
109
+ )
110
+ )
111
+ (1): CrossAttnDownBlock3D(
112
+ (attentions): ModuleList(
113
+ (0-1): 2 x Transformer3DModel(
114
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
115
+ (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
116
+ (transformer_blocks): ModuleList(
117
+ (0): TemporalBasicTransformerBlock(
118
+ (attn1): Attention(
119
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
120
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
121
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
122
+ (to_out): ModuleList(
123
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
124
+ (1): Dropout(p=0.0, inplace=False)
125
+ )
126
+ )
127
+ (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
128
+ (attn2): Attention(
129
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
130
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
131
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
132
+ (to_out): ModuleList(
133
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
134
+ (1): Dropout(p=0.0, inplace=False)
135
+ )
136
+ )
137
+ (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
138
+ (ff): FeedForward(
139
+ (net): ModuleList(
140
+ (0): GEGLU(
141
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
142
+ )
143
+ (1): Dropout(p=0.0, inplace=False)
144
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
145
+ )
146
+ )
147
+ (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
148
+ )
149
+ )
150
+ (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
151
+ )
152
+ )
153
+ (resnets): ModuleList(
154
+ (0): ResnetBlock3D(
155
+ (norm1): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
156
+ (conv1): InflatedConv3d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
157
+ (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
158
+ (norm2): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
159
+ (dropout): Dropout(p=0.0, inplace=False)
160
+ (conv2): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
161
+ (nonlinearity): SiLU()
162
+ (conv_shortcut): InflatedConv3d(320, 640, kernel_size=(1, 1), stride=(1, 1))
163
+ )
164
+ (1): ResnetBlock3D(
165
+ (norm1): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
166
+ (conv1): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
167
+ (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
168
+ (norm2): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
169
+ (dropout): Dropout(p=0.0, inplace=False)
170
+ (conv2): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
171
+ (nonlinearity): SiLU()
172
+ )
173
+ )
174
+ (motion_modules): ModuleList(
175
+ (0-1): 2 x VanillaTemporalModule(
176
+ (temporal_transformer): TemporalTransformer3DModel(
177
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
178
+ (proj_in): Linear(in_features=640, out_features=640, bias=True)
179
+ (transformer_blocks): ModuleList(
180
+ (0): TemporalTransformerBlock(
181
+ (attention_blocks): ModuleList(
182
+ (0-1): 2 x VersatileAttention(
183
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
184
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
185
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
186
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
187
+ (to_out): ModuleList(
188
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
189
+ (1): Dropout(p=0.0, inplace=False)
190
+ )
191
+ (pos_encoder): PositionalEncoding(
192
+ (dropout): Dropout(p=0.0, inplace=False)
193
+ )
194
+ )
195
+ )
196
+ (norms): ModuleList(
197
+ (0-1): 2 x LayerNorm((640,), eps=1e-05, elementwise_affine=True)
198
+ )
199
+ (ff): FeedForward(
200
+ (net): ModuleList(
201
+ (0): GEGLU(
202
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
203
+ )
204
+ (1): Dropout(p=0.0, inplace=False)
205
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
206
+ )
207
+ )
208
+ (ff_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
209
+ )
210
+ )
211
+ (proj_out): Linear(in_features=640, out_features=640, bias=True)
212
+ )
213
+ )
214
+ )
215
+ (downsamplers): ModuleList(
216
+ (0): Downsample3D(
217
+ (conv): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
218
+ )
219
+ )
220
+ )
221
+ (2): CrossAttnDownBlock3D(
222
+ (attentions): ModuleList(
223
+ (0-1): 2 x Transformer3DModel(
224
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
225
+ (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
226
+ (transformer_blocks): ModuleList(
227
+ (0): TemporalBasicTransformerBlock(
228
+ (attn1): Attention(
229
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
230
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
231
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
232
+ (to_out): ModuleList(
233
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
234
+ (1): Dropout(p=0.0, inplace=False)
235
+ )
236
+ )
237
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
238
+ (attn2): Attention(
239
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
240
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
241
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
242
+ (to_out): ModuleList(
243
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
244
+ (1): Dropout(p=0.0, inplace=False)
245
+ )
246
+ )
247
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
248
+ (ff): FeedForward(
249
+ (net): ModuleList(
250
+ (0): GEGLU(
251
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
252
+ )
253
+ (1): Dropout(p=0.0, inplace=False)
254
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
255
+ )
256
+ )
257
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
258
+ )
259
+ )
260
+ (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
261
+ )
262
+ )
263
+ (resnets): ModuleList(
264
+ (0): ResnetBlock3D(
265
+ (norm1): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
266
+ (conv1): InflatedConv3d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
267
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
268
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
269
+ (dropout): Dropout(p=0.0, inplace=False)
270
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
271
+ (nonlinearity): SiLU()
272
+ (conv_shortcut): InflatedConv3d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
273
+ )
274
+ (1): ResnetBlock3D(
275
+ (norm1): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
276
+ (conv1): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
277
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
278
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
279
+ (dropout): Dropout(p=0.0, inplace=False)
280
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
281
+ (nonlinearity): SiLU()
282
+ )
283
+ )
284
+ (motion_modules): ModuleList(
285
+ (0-1): 2 x VanillaTemporalModule(
286
+ (temporal_transformer): TemporalTransformer3DModel(
287
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
288
+ (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
289
+ (transformer_blocks): ModuleList(
290
+ (0): TemporalTransformerBlock(
291
+ (attention_blocks): ModuleList(
292
+ (0-1): 2 x VersatileAttention(
293
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
294
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
295
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
296
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
297
+ (to_out): ModuleList(
298
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
299
+ (1): Dropout(p=0.0, inplace=False)
300
+ )
301
+ (pos_encoder): PositionalEncoding(
302
+ (dropout): Dropout(p=0.0, inplace=False)
303
+ )
304
+ )
305
+ )
306
+ (norms): ModuleList(
307
+ (0-1): 2 x LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
308
+ )
309
+ (ff): FeedForward(
310
+ (net): ModuleList(
311
+ (0): GEGLU(
312
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
313
+ )
314
+ (1): Dropout(p=0.0, inplace=False)
315
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
316
+ )
317
+ )
318
+ (ff_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
319
+ )
320
+ )
321
+ (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
322
+ )
323
+ )
324
+ )
325
+ (downsamplers): ModuleList(
326
+ (0): Downsample3D(
327
+ (conv): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
328
+ )
329
+ )
330
+ )
331
+ (3): DownBlock3D(
332
+ (resnets): ModuleList(
333
+ (0-1): 2 x ResnetBlock3D(
334
+ (norm1): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
335
+ (conv1): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
336
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
337
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
338
+ (dropout): Dropout(p=0.0, inplace=False)
339
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
340
+ (nonlinearity): SiLU()
341
+ )
342
+ )
343
+ (motion_modules): ModuleList(
344
+ (0-1): 2 x VanillaTemporalModule(
345
+ (temporal_transformer): TemporalTransformer3DModel(
346
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
347
+ (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
348
+ (transformer_blocks): ModuleList(
349
+ (0): TemporalTransformerBlock(
350
+ (attention_blocks): ModuleList(
351
+ (0-1): 2 x VersatileAttention(
352
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
353
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
354
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
355
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
356
+ (to_out): ModuleList(
357
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
358
+ (1): Dropout(p=0.0, inplace=False)
359
+ )
360
+ (pos_encoder): PositionalEncoding(
361
+ (dropout): Dropout(p=0.0, inplace=False)
362
+ )
363
+ )
364
+ )
365
+ (norms): ModuleList(
366
+ (0-1): 2 x LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
367
+ )
368
+ (ff): FeedForward(
369
+ (net): ModuleList(
370
+ (0): GEGLU(
371
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
372
+ )
373
+ (1): Dropout(p=0.0, inplace=False)
374
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
375
+ )
376
+ )
377
+ (ff_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
378
+ )
379
+ )
380
+ (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
381
+ )
382
+ )
383
+ )
384
+ )
385
+ )
386
+ (up_blocks): ModuleList(
387
+ (0): UpBlock3D(
388
+ (resnets): ModuleList(
389
+ (0-2): 3 x ResnetBlock3D(
390
+ (norm1): InflatedGroupNorm(32, 2560, eps=1e-05, affine=True)
391
+ (conv1): InflatedConv3d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
392
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
393
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
394
+ (dropout): Dropout(p=0.0, inplace=False)
395
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
396
+ (nonlinearity): SiLU()
397
+ (conv_shortcut): InflatedConv3d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
398
+ )
399
+ )
400
+ (motion_modules): ModuleList(
401
+ (0-2): 3 x VanillaTemporalModule(
402
+ (temporal_transformer): TemporalTransformer3DModel(
403
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
404
+ (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
405
+ (transformer_blocks): ModuleList(
406
+ (0): TemporalTransformerBlock(
407
+ (attention_blocks): ModuleList(
408
+ (0-1): 2 x VersatileAttention(
409
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
410
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
411
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
412
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
413
+ (to_out): ModuleList(
414
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
415
+ (1): Dropout(p=0.0, inplace=False)
416
+ )
417
+ (pos_encoder): PositionalEncoding(
418
+ (dropout): Dropout(p=0.0, inplace=False)
419
+ )
420
+ )
421
+ )
422
+ (norms): ModuleList(
423
+ (0-1): 2 x LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
424
+ )
425
+ (ff): FeedForward(
426
+ (net): ModuleList(
427
+ (0): GEGLU(
428
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
429
+ )
430
+ (1): Dropout(p=0.0, inplace=False)
431
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
432
+ )
433
+ )
434
+ (ff_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
435
+ )
436
+ )
437
+ (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
438
+ )
439
+ )
440
+ )
441
+ (upsamplers): ModuleList(
442
+ (0): Upsample3D(
443
+ (conv): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
444
+ )
445
+ )
446
+ )
447
+ (1): CrossAttnUpBlock3D(
448
+ (attentions): ModuleList(
449
+ (0-2): 3 x Transformer3DModel(
450
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
451
+ (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
452
+ (transformer_blocks): ModuleList(
453
+ (0): TemporalBasicTransformerBlock(
454
+ (attn1): Attention(
455
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
456
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
457
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
458
+ (to_out): ModuleList(
459
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
460
+ (1): Dropout(p=0.0, inplace=False)
461
+ )
462
+ )
463
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
464
+ (attn2): Attention(
465
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
466
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
467
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
468
+ (to_out): ModuleList(
469
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
470
+ (1): Dropout(p=0.0, inplace=False)
471
+ )
472
+ )
473
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
474
+ (ff): FeedForward(
475
+ (net): ModuleList(
476
+ (0): GEGLU(
477
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
478
+ )
479
+ (1): Dropout(p=0.0, inplace=False)
480
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
481
+ )
482
+ )
483
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
484
+ )
485
+ )
486
+ (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
487
+ )
488
+ )
489
+ (resnets): ModuleList(
490
+ (0-1): 2 x ResnetBlock3D(
491
+ (norm1): InflatedGroupNorm(32, 2560, eps=1e-05, affine=True)
492
+ (conv1): InflatedConv3d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
493
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
494
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
495
+ (dropout): Dropout(p=0.0, inplace=False)
496
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
497
+ (nonlinearity): SiLU()
498
+ (conv_shortcut): InflatedConv3d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
499
+ )
500
+ (2): ResnetBlock3D(
501
+ (norm1): InflatedGroupNorm(32, 1920, eps=1e-05, affine=True)
502
+ (conv1): InflatedConv3d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
503
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
504
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
505
+ (dropout): Dropout(p=0.0, inplace=False)
506
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
507
+ (nonlinearity): SiLU()
508
+ (conv_shortcut): InflatedConv3d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
509
+ )
510
+ )
511
+ (motion_modules): ModuleList(
512
+ (0-2): 3 x VanillaTemporalModule(
513
+ (temporal_transformer): TemporalTransformer3DModel(
514
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
515
+ (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
516
+ (transformer_blocks): ModuleList(
517
+ (0): TemporalTransformerBlock(
518
+ (attention_blocks): ModuleList(
519
+ (0-1): 2 x VersatileAttention(
520
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
521
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
522
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
523
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
524
+ (to_out): ModuleList(
525
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
526
+ (1): Dropout(p=0.0, inplace=False)
527
+ )
528
+ (pos_encoder): PositionalEncoding(
529
+ (dropout): Dropout(p=0.0, inplace=False)
530
+ )
531
+ )
532
+ )
533
+ (norms): ModuleList(
534
+ (0-1): 2 x LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
535
+ )
536
+ (ff): FeedForward(
537
+ (net): ModuleList(
538
+ (0): GEGLU(
539
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
540
+ )
541
+ (1): Dropout(p=0.0, inplace=False)
542
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
543
+ )
544
+ )
545
+ (ff_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
546
+ )
547
+ )
548
+ (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
549
+ )
550
+ )
551
+ )
552
+ (upsamplers): ModuleList(
553
+ (0): Upsample3D(
554
+ (conv): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
555
+ )
556
+ )
557
+ )
558
+ (2): CrossAttnUpBlock3D(
559
+ (attentions): ModuleList(
560
+ (0-2): 3 x Transformer3DModel(
561
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
562
+ (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
563
+ (transformer_blocks): ModuleList(
564
+ (0): TemporalBasicTransformerBlock(
565
+ (attn1): Attention(
566
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
567
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
568
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
569
+ (to_out): ModuleList(
570
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
571
+ (1): Dropout(p=0.0, inplace=False)
572
+ )
573
+ )
574
+ (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
575
+ (attn2): Attention(
576
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
577
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
578
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
579
+ (to_out): ModuleList(
580
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
581
+ (1): Dropout(p=0.0, inplace=False)
582
+ )
583
+ )
584
+ (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
585
+ (ff): FeedForward(
586
+ (net): ModuleList(
587
+ (0): GEGLU(
588
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
589
+ )
590
+ (1): Dropout(p=0.0, inplace=False)
591
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
592
+ )
593
+ )
594
+ (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
595
+ )
596
+ )
597
+ (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
598
+ )
599
+ )
600
+ (resnets): ModuleList(
601
+ (0): ResnetBlock3D(
602
+ (norm1): InflatedGroupNorm(32, 1920, eps=1e-05, affine=True)
603
+ (conv1): InflatedConv3d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
604
+ (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
605
+ (norm2): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
606
+ (dropout): Dropout(p=0.0, inplace=False)
607
+ (conv2): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
608
+ (nonlinearity): SiLU()
609
+ (conv_shortcut): InflatedConv3d(1920, 640, kernel_size=(1, 1), stride=(1, 1))
610
+ )
611
+ (1): ResnetBlock3D(
612
+ (norm1): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
613
+ (conv1): InflatedConv3d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
614
+ (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
615
+ (norm2): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
616
+ (dropout): Dropout(p=0.0, inplace=False)
617
+ (conv2): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
618
+ (nonlinearity): SiLU()
619
+ (conv_shortcut): InflatedConv3d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
620
+ )
621
+ (2): ResnetBlock3D(
622
+ (norm1): InflatedGroupNorm(32, 960, eps=1e-05, affine=True)
623
+ (conv1): InflatedConv3d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
624
+ (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
625
+ (norm2): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
626
+ (dropout): Dropout(p=0.0, inplace=False)
627
+ (conv2): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
628
+ (nonlinearity): SiLU()
629
+ (conv_shortcut): InflatedConv3d(960, 640, kernel_size=(1, 1), stride=(1, 1))
630
+ )
631
+ )
632
+ (motion_modules): ModuleList(
633
+ (0-2): 3 x VanillaTemporalModule(
634
+ (temporal_transformer): TemporalTransformer3DModel(
635
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
636
+ (proj_in): Linear(in_features=640, out_features=640, bias=True)
637
+ (transformer_blocks): ModuleList(
638
+ (0): TemporalTransformerBlock(
639
+ (attention_blocks): ModuleList(
640
+ (0-1): 2 x VersatileAttention(
641
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
642
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
643
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
644
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
645
+ (to_out): ModuleList(
646
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
647
+ (1): Dropout(p=0.0, inplace=False)
648
+ )
649
+ (pos_encoder): PositionalEncoding(
650
+ (dropout): Dropout(p=0.0, inplace=False)
651
+ )
652
+ )
653
+ )
654
+ (norms): ModuleList(
655
+ (0-1): 2 x LayerNorm((640,), eps=1e-05, elementwise_affine=True)
656
+ )
657
+ (ff): FeedForward(
658
+ (net): ModuleList(
659
+ (0): GEGLU(
660
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
661
+ )
662
+ (1): Dropout(p=0.0, inplace=False)
663
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
664
+ )
665
+ )
666
+ (ff_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
667
+ )
668
+ )
669
+ (proj_out): Linear(in_features=640, out_features=640, bias=True)
670
+ )
671
+ )
672
+ )
673
+ (upsamplers): ModuleList(
674
+ (0): Upsample3D(
675
+ (conv): InflatedConv3d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
676
+ )
677
+ )
678
+ )
679
+ (3): CrossAttnUpBlock3D(
680
+ (attentions): ModuleList(
681
+ (0-2): 3 x Transformer3DModel(
682
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
683
+ (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
684
+ (transformer_blocks): ModuleList(
685
+ (0): TemporalBasicTransformerBlock(
686
+ (attn1): Attention(
687
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
688
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
689
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
690
+ (to_out): ModuleList(
691
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
692
+ (1): Dropout(p=0.0, inplace=False)
693
+ )
694
+ )
695
+ (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
696
+ (attn2): Attention(
697
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
698
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
699
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
700
+ (to_out): ModuleList(
701
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
702
+ (1): Dropout(p=0.0, inplace=False)
703
+ )
704
+ )
705
+ (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
706
+ (ff): FeedForward(
707
+ (net): ModuleList(
708
+ (0): GEGLU(
709
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
710
+ )
711
+ (1): Dropout(p=0.0, inplace=False)
712
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
713
+ )
714
+ )
715
+ (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
716
+ )
717
+ )
718
+ (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
719
+ )
720
+ )
721
+ (resnets): ModuleList(
722
+ (0): ResnetBlock3D(
723
+ (norm1): InflatedGroupNorm(32, 960, eps=1e-05, affine=True)
724
+ (conv1): InflatedConv3d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
725
+ (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
726
+ (norm2): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
727
+ (dropout): Dropout(p=0.0, inplace=False)
728
+ (conv2): InflatedConv3d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
729
+ (nonlinearity): SiLU()
730
+ (conv_shortcut): InflatedConv3d(960, 320, kernel_size=(1, 1), stride=(1, 1))
731
+ )
732
+ (1-2): 2 x ResnetBlock3D(
733
+ (norm1): InflatedGroupNorm(32, 640, eps=1e-05, affine=True)
734
+ (conv1): InflatedConv3d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
735
+ (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
736
+ (norm2): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
737
+ (dropout): Dropout(p=0.0, inplace=False)
738
+ (conv2): InflatedConv3d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
739
+ (nonlinearity): SiLU()
740
+ (conv_shortcut): InflatedConv3d(640, 320, kernel_size=(1, 1), stride=(1, 1))
741
+ )
742
+ )
743
+ (motion_modules): ModuleList(
744
+ (0-2): 3 x VanillaTemporalModule(
745
+ (temporal_transformer): TemporalTransformer3DModel(
746
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
747
+ (proj_in): Linear(in_features=320, out_features=320, bias=True)
748
+ (transformer_blocks): ModuleList(
749
+ (0): TemporalTransformerBlock(
750
+ (attention_blocks): ModuleList(
751
+ (0-1): 2 x VersatileAttention(
752
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
753
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
754
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
755
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
756
+ (to_out): ModuleList(
757
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
758
+ (1): Dropout(p=0.0, inplace=False)
759
+ )
760
+ (pos_encoder): PositionalEncoding(
761
+ (dropout): Dropout(p=0.0, inplace=False)
762
+ )
763
+ )
764
+ )
765
+ (norms): ModuleList(
766
+ (0-1): 2 x LayerNorm((320,), eps=1e-05, elementwise_affine=True)
767
+ )
768
+ (ff): FeedForward(
769
+ (net): ModuleList(
770
+ (0): GEGLU(
771
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
772
+ )
773
+ (1): Dropout(p=0.0, inplace=False)
774
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
775
+ )
776
+ )
777
+ (ff_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
778
+ )
779
+ )
780
+ (proj_out): Linear(in_features=320, out_features=320, bias=True)
781
+ )
782
+ )
783
+ )
784
+ )
785
+ )
786
+ (mid_block): UNetMidBlock3DCrossAttn(
787
+ (attentions): ModuleList(
788
+ (0): Transformer3DModel(
789
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
790
+ (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
791
+ (transformer_blocks): ModuleList(
792
+ (0): TemporalBasicTransformerBlock(
793
+ (attn1): Attention(
794
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
795
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
796
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
797
+ (to_out): ModuleList(
798
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
799
+ (1): Dropout(p=0.0, inplace=False)
800
+ )
801
+ )
802
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
803
+ (attn2): Attention(
804
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
805
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
806
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
807
+ (to_out): ModuleList(
808
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
809
+ (1): Dropout(p=0.0, inplace=False)
810
+ )
811
+ )
812
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
813
+ (ff): FeedForward(
814
+ (net): ModuleList(
815
+ (0): GEGLU(
816
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
817
+ )
818
+ (1): Dropout(p=0.0, inplace=False)
819
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
820
+ )
821
+ )
822
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
823
+ )
824
+ )
825
+ (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
826
+ )
827
+ )
828
+ (resnets): ModuleList(
829
+ (0-1): 2 x ResnetBlock3D(
830
+ (norm1): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
831
+ (conv1): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
832
+ (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)
833
+ (norm2): InflatedGroupNorm(32, 1280, eps=1e-05, affine=True)
834
+ (dropout): Dropout(p=0.0, inplace=False)
835
+ (conv2): InflatedConv3d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
836
+ (nonlinearity): SiLU()
837
+ )
838
+ )
839
+ (motion_modules): ModuleList(
840
+ (0): VanillaTemporalModule(
841
+ (temporal_transformer): TemporalTransformer3DModel(
842
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
843
+ (proj_in): Linear(in_features=1280, out_features=1280, bias=True)
844
+ (transformer_blocks): ModuleList(
845
+ (0): TemporalTransformerBlock(
846
+ (attention_blocks): ModuleList(
847
+ (0-1): 2 x VersatileAttention(
848
+ (Module Info) Attention_Mode: Temporal, Is_Cross_Attention: False
849
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
850
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
851
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
852
+ (to_out): ModuleList(
853
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
854
+ (1): Dropout(p=0.0, inplace=False)
855
+ )
856
+ (pos_encoder): PositionalEncoding(
857
+ (dropout): Dropout(p=0.0, inplace=False)
858
+ )
859
+ )
860
+ )
861
+ (norms): ModuleList(
862
+ (0-1): 2 x LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
863
+ )
864
+ (ff): FeedForward(
865
+ (net): ModuleList(
866
+ (0): GEGLU(
867
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
868
+ )
869
+ (1): Dropout(p=0.0, inplace=False)
870
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
871
+ )
872
+ )
873
+ (ff_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
874
+ )
875
+ )
876
+ (proj_out): Linear(in_features=1280, out_features=1280, bias=True)
877
+ )
878
+ )
879
+ )
880
+ )
881
+ (conv_norm_out): InflatedGroupNorm(32, 320, eps=1e-05, affine=True)
882
+ (conv_act): SiLU()
883
+ (conv_out): InflatedConv3d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
884
+ )
885
+ Reference UNet structure:
886
+ UNet2DConditionModel(
887
+ (conv_in): Conv2d(5, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
888
+ (time_proj): Timesteps()
889
+ (time_embedding): TimestepEmbedding(
890
+ (linear_1): LoRACompatibleLinear(in_features=320, out_features=1280, bias=True)
891
+ (act): SiLU()
892
+ (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
893
+ )
894
+ (down_blocks): ModuleList(
895
+ (0): CrossAttnDownBlock2D(
896
+ (attentions): ModuleList(
897
+ (0-1): 2 x Transformer2DModel(
898
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
899
+ (proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
900
+ (transformer_blocks): ModuleList(
901
+ (0): BasicTransformerBlock(
902
+ (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
903
+ (attn1): Attention(
904
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
905
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
906
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
907
+ (to_out): ModuleList(
908
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
909
+ (1): Dropout(p=0.0, inplace=False)
910
+ )
911
+ )
912
+ (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
913
+ (attn2): Attention(
914
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
915
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
916
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
917
+ (to_out): ModuleList(
918
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
919
+ (1): Dropout(p=0.0, inplace=False)
920
+ )
921
+ )
922
+ (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
923
+ (ff): FeedForward(
924
+ (net): ModuleList(
925
+ (0): GEGLU(
926
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
927
+ )
928
+ (1): Dropout(p=0.0, inplace=False)
929
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
930
+ )
931
+ )
932
+ )
933
+ )
934
+ (proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
935
+ )
936
+ )
937
+ (resnets): ModuleList(
938
+ (0-1): 2 x ResnetBlock2D(
939
+ (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
940
+ (conv1): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
941
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
942
+ (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
943
+ (dropout): Dropout(p=0.0, inplace=False)
944
+ (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
945
+ (nonlinearity): SiLU()
946
+ )
947
+ )
948
+ (downsamplers): ModuleList(
949
+ (0): Downsample2D(
950
+ (conv): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
951
+ )
952
+ )
953
+ )
954
+ (1): CrossAttnDownBlock2D(
955
+ (attentions): ModuleList(
956
+ (0-1): 2 x Transformer2DModel(
957
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
958
+ (proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
959
+ (transformer_blocks): ModuleList(
960
+ (0): BasicTransformerBlock(
961
+ (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
962
+ (attn1): Attention(
963
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
964
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
965
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
966
+ (to_out): ModuleList(
967
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
968
+ (1): Dropout(p=0.0, inplace=False)
969
+ )
970
+ )
971
+ (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
972
+ (attn2): Attention(
973
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
974
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
975
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
976
+ (to_out): ModuleList(
977
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
978
+ (1): Dropout(p=0.0, inplace=False)
979
+ )
980
+ )
981
+ (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
982
+ (ff): FeedForward(
983
+ (net): ModuleList(
984
+ (0): GEGLU(
985
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
986
+ )
987
+ (1): Dropout(p=0.0, inplace=False)
988
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
989
+ )
990
+ )
991
+ )
992
+ )
993
+ (proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
994
+ )
995
+ )
996
+ (resnets): ModuleList(
997
+ (0): ResnetBlock2D(
998
+ (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
999
+ (conv1): LoRACompatibleConv(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1000
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
1001
+ (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
1002
+ (dropout): Dropout(p=0.0, inplace=False)
1003
+ (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1004
+ (nonlinearity): SiLU()
1005
+ (conv_shortcut): LoRACompatibleConv(320, 640, kernel_size=(1, 1), stride=(1, 1))
1006
+ )
1007
+ (1): ResnetBlock2D(
1008
+ (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
1009
+ (conv1): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1010
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
1011
+ (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
1012
+ (dropout): Dropout(p=0.0, inplace=False)
1013
+ (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1014
+ (nonlinearity): SiLU()
1015
+ )
1016
+ )
1017
+ (downsamplers): ModuleList(
1018
+ (0): Downsample2D(
1019
+ (conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1020
+ )
1021
+ )
1022
+ )
1023
+ (2): CrossAttnDownBlock2D(
1024
+ (attentions): ModuleList(
1025
+ (0-1): 2 x Transformer2DModel(
1026
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
1027
+ (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1028
+ (transformer_blocks): ModuleList(
1029
+ (0): BasicTransformerBlock(
1030
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1031
+ (attn1): Attention(
1032
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1033
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1034
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1035
+ (to_out): ModuleList(
1036
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1037
+ (1): Dropout(p=0.0, inplace=False)
1038
+ )
1039
+ )
1040
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1041
+ (attn2): Attention(
1042
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1043
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1044
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1045
+ (to_out): ModuleList(
1046
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1047
+ (1): Dropout(p=0.0, inplace=False)
1048
+ )
1049
+ )
1050
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1051
+ (ff): FeedForward(
1052
+ (net): ModuleList(
1053
+ (0): GEGLU(
1054
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
1055
+ )
1056
+ (1): Dropout(p=0.0, inplace=False)
1057
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
1058
+ )
1059
+ )
1060
+ )
1061
+ )
1062
+ (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1063
+ )
1064
+ )
1065
+ (resnets): ModuleList(
1066
+ (0): ResnetBlock2D(
1067
+ (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
1068
+ (conv1): LoRACompatibleConv(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1069
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1070
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1071
+ (dropout): Dropout(p=0.0, inplace=False)
1072
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1073
+ (nonlinearity): SiLU()
1074
+ (conv_shortcut): LoRACompatibleConv(640, 1280, kernel_size=(1, 1), stride=(1, 1))
1075
+ )
1076
+ (1): ResnetBlock2D(
1077
+ (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
1078
+ (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1079
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1080
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1081
+ (dropout): Dropout(p=0.0, inplace=False)
1082
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1083
+ (nonlinearity): SiLU()
1084
+ )
1085
+ )
1086
+ (downsamplers): ModuleList(
1087
+ (0): Downsample2D(
1088
+ (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1089
+ )
1090
+ )
1091
+ )
1092
+ (3): DownBlock2D(
1093
+ (resnets): ModuleList(
1094
+ (0-1): 2 x ResnetBlock2D(
1095
+ (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
1096
+ (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1097
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1098
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1099
+ (dropout): Dropout(p=0.0, inplace=False)
1100
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1101
+ (nonlinearity): SiLU()
1102
+ )
1103
+ )
1104
+ )
1105
+ )
1106
+ (up_blocks): ModuleList(
1107
+ (0): UpBlock2D(
1108
+ (resnets): ModuleList(
1109
+ (0-2): 3 x ResnetBlock2D(
1110
+ (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
1111
+ (conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1112
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1113
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1114
+ (dropout): Dropout(p=0.0, inplace=False)
1115
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1116
+ (nonlinearity): SiLU()
1117
+ (conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
1118
+ )
1119
+ )
1120
+ (upsamplers): ModuleList(
1121
+ (0): Upsample2D(
1122
+ (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1123
+ )
1124
+ )
1125
+ )
1126
+ (1): CrossAttnUpBlock2D(
1127
+ (attentions): ModuleList(
1128
+ (0-2): 3 x Transformer2DModel(
1129
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
1130
+ (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1131
+ (transformer_blocks): ModuleList(
1132
+ (0): BasicTransformerBlock(
1133
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1134
+ (attn1): Attention(
1135
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1136
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1137
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1138
+ (to_out): ModuleList(
1139
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1140
+ (1): Dropout(p=0.0, inplace=False)
1141
+ )
1142
+ )
1143
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1144
+ (attn2): Attention(
1145
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1146
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1147
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1148
+ (to_out): ModuleList(
1149
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1150
+ (1): Dropout(p=0.0, inplace=False)
1151
+ )
1152
+ )
1153
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1154
+ (ff): FeedForward(
1155
+ (net): ModuleList(
1156
+ (0): GEGLU(
1157
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
1158
+ )
1159
+ (1): Dropout(p=0.0, inplace=False)
1160
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
1161
+ )
1162
+ )
1163
+ )
1164
+ )
1165
+ (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1166
+ )
1167
+ )
1168
+ (resnets): ModuleList(
1169
+ (0-1): 2 x ResnetBlock2D(
1170
+ (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)
1171
+ (conv1): LoRACompatibleConv(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1172
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1173
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1174
+ (dropout): Dropout(p=0.0, inplace=False)
1175
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1176
+ (nonlinearity): SiLU()
1177
+ (conv_shortcut): LoRACompatibleConv(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
1178
+ )
1179
+ (2): ResnetBlock2D(
1180
+ (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
1181
+ (conv1): LoRACompatibleConv(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1182
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1183
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1184
+ (dropout): Dropout(p=0.0, inplace=False)
1185
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1186
+ (nonlinearity): SiLU()
1187
+ (conv_shortcut): LoRACompatibleConv(1920, 1280, kernel_size=(1, 1), stride=(1, 1))
1188
+ )
1189
+ )
1190
+ (upsamplers): ModuleList(
1191
+ (0): Upsample2D(
1192
+ (conv): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1193
+ )
1194
+ )
1195
+ )
1196
+ (2): CrossAttnUpBlock2D(
1197
+ (attentions): ModuleList(
1198
+ (0-2): 3 x Transformer2DModel(
1199
+ (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
1200
+ (proj_in): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
1201
+ (transformer_blocks): ModuleList(
1202
+ (0): BasicTransformerBlock(
1203
+ (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
1204
+ (attn1): Attention(
1205
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
1206
+ (to_k): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
1207
+ (to_v): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
1208
+ (to_out): ModuleList(
1209
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
1210
+ (1): Dropout(p=0.0, inplace=False)
1211
+ )
1212
+ )
1213
+ (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
1214
+ (attn2): Attention(
1215
+ (to_q): LoRACompatibleLinear(in_features=640, out_features=640, bias=False)
1216
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
1217
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=640, bias=False)
1218
+ (to_out): ModuleList(
1219
+ (0): LoRACompatibleLinear(in_features=640, out_features=640, bias=True)
1220
+ (1): Dropout(p=0.0, inplace=False)
1221
+ )
1222
+ )
1223
+ (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
1224
+ (ff): FeedForward(
1225
+ (net): ModuleList(
1226
+ (0): GEGLU(
1227
+ (proj): LoRACompatibleLinear(in_features=640, out_features=5120, bias=True)
1228
+ )
1229
+ (1): Dropout(p=0.0, inplace=False)
1230
+ (2): LoRACompatibleLinear(in_features=2560, out_features=640, bias=True)
1231
+ )
1232
+ )
1233
+ )
1234
+ )
1235
+ (proj_out): LoRACompatibleConv(640, 640, kernel_size=(1, 1), stride=(1, 1))
1236
+ )
1237
+ )
1238
+ (resnets): ModuleList(
1239
+ (0): ResnetBlock2D(
1240
+ (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)
1241
+ (conv1): LoRACompatibleConv(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1242
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
1243
+ (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
1244
+ (dropout): Dropout(p=0.0, inplace=False)
1245
+ (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1246
+ (nonlinearity): SiLU()
1247
+ (conv_shortcut): LoRACompatibleConv(1920, 640, kernel_size=(1, 1), stride=(1, 1))
1248
+ )
1249
+ (1): ResnetBlock2D(
1250
+ (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
1251
+ (conv1): LoRACompatibleConv(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1252
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
1253
+ (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
1254
+ (dropout): Dropout(p=0.0, inplace=False)
1255
+ (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1256
+ (nonlinearity): SiLU()
1257
+ (conv_shortcut): LoRACompatibleConv(1280, 640, kernel_size=(1, 1), stride=(1, 1))
1258
+ )
1259
+ (2): ResnetBlock2D(
1260
+ (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
1261
+ (conv1): LoRACompatibleConv(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1262
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=640, bias=True)
1263
+ (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
1264
+ (dropout): Dropout(p=0.0, inplace=False)
1265
+ (conv2): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1266
+ (nonlinearity): SiLU()
1267
+ (conv_shortcut): LoRACompatibleConv(960, 640, kernel_size=(1, 1), stride=(1, 1))
1268
+ )
1269
+ )
1270
+ (upsamplers): ModuleList(
1271
+ (0): Upsample2D(
1272
+ (conv): LoRACompatibleConv(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1273
+ )
1274
+ )
1275
+ )
1276
+ (3): CrossAttnUpBlock2D(
1277
+ (attentions): ModuleList(
1278
+ (0-2): 3 x Transformer2DModel(
1279
+ (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
1280
+ (proj_in): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
1281
+ (transformer_blocks): ModuleList(
1282
+ (0): BasicTransformerBlock(
1283
+ (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
1284
+ (attn1): Attention(
1285
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
1286
+ (to_k): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
1287
+ (to_v): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
1288
+ (to_out): ModuleList(
1289
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
1290
+ (1): Dropout(p=0.0, inplace=False)
1291
+ )
1292
+ )
1293
+ (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
1294
+ (attn2): Attention(
1295
+ (to_q): LoRACompatibleLinear(in_features=320, out_features=320, bias=False)
1296
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
1297
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=320, bias=False)
1298
+ (to_out): ModuleList(
1299
+ (0): LoRACompatibleLinear(in_features=320, out_features=320, bias=True)
1300
+ (1): Dropout(p=0.0, inplace=False)
1301
+ )
1302
+ )
1303
+ (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
1304
+ (ff): FeedForward(
1305
+ (net): ModuleList(
1306
+ (0): GEGLU(
1307
+ (proj): LoRACompatibleLinear(in_features=320, out_features=2560, bias=True)
1308
+ )
1309
+ (1): Dropout(p=0.0, inplace=False)
1310
+ (2): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
1311
+ )
1312
+ )
1313
+ )
1314
+ )
1315
+ (proj_out): LoRACompatibleConv(320, 320, kernel_size=(1, 1), stride=(1, 1))
1316
+ )
1317
+ )
1318
+ (resnets): ModuleList(
1319
+ (0): ResnetBlock2D(
1320
+ (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)
1321
+ (conv1): LoRACompatibleConv(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1322
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
1323
+ (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
1324
+ (dropout): Dropout(p=0.0, inplace=False)
1325
+ (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1326
+ (nonlinearity): SiLU()
1327
+ (conv_shortcut): LoRACompatibleConv(960, 320, kernel_size=(1, 1), stride=(1, 1))
1328
+ )
1329
+ (1-2): 2 x ResnetBlock2D(
1330
+ (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
1331
+ (conv1): LoRACompatibleConv(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1332
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
1333
+ (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
1334
+ (dropout): Dropout(p=0.0, inplace=False)
1335
+ (conv2): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1336
+ (nonlinearity): SiLU()
1337
+ (conv_shortcut): LoRACompatibleConv(640, 320, kernel_size=(1, 1), stride=(1, 1))
1338
+ )
1339
+ )
1340
+ )
1341
+ )
1342
+ (mid_block): UNetMidBlock2DCrossAttn(
1343
+ (attentions): ModuleList(
1344
+ (0): Transformer2DModel(
1345
+ (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
1346
+ (proj_in): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1347
+ (transformer_blocks): ModuleList(
1348
+ (0): BasicTransformerBlock(
1349
+ (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1350
+ (attn1): Attention(
1351
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1352
+ (to_k): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1353
+ (to_v): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1354
+ (to_out): ModuleList(
1355
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1356
+ (1): Dropout(p=0.0, inplace=False)
1357
+ )
1358
+ )
1359
+ (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1360
+ (attn2): Attention(
1361
+ (to_q): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=False)
1362
+ (to_k): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1363
+ (to_v): LoRACompatibleLinear(in_features=768, out_features=1280, bias=False)
1364
+ (to_out): ModuleList(
1365
+ (0): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1366
+ (1): Dropout(p=0.0, inplace=False)
1367
+ )
1368
+ )
1369
+ (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
1370
+ (ff): FeedForward(
1371
+ (net): ModuleList(
1372
+ (0): GEGLU(
1373
+ (proj): LoRACompatibleLinear(in_features=1280, out_features=10240, bias=True)
1374
+ )
1375
+ (1): Dropout(p=0.0, inplace=False)
1376
+ (2): LoRACompatibleLinear(in_features=5120, out_features=1280, bias=True)
1377
+ )
1378
+ )
1379
+ )
1380
+ )
1381
+ (proj_out): LoRACompatibleConv(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
1382
+ )
1383
+ )
1384
+ (resnets): ModuleList(
1385
+ (0-1): 2 x ResnetBlock2D(
1386
+ (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)
1387
+ (conv1): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1388
+ (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
1389
+ (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)
1390
+ (dropout): Dropout(p=0.0, inplace=False)
1391
+ (conv2): LoRACompatibleConv(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1392
+ (nonlinearity): SiLU()
1393
+ )
1394
+ )
1395
+ )
1396
+ (conv_norm_out): None
1397
+ (conv_act): SiLU()
1398
+ )
1399
+ Pose Guider structure:
1400
+ PoseGuider(
1401
+ (conv_in): InflatedConv3d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1402
+ (blocks): ModuleList(
1403
+ (0): InflatedConv3d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1404
+ (1): InflatedConv3d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1405
+ (2): InflatedConv3d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1406
+ (3): InflatedConv3d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1407
+ (4): InflatedConv3d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1408
+ (5): InflatedConv3d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1409
+ )
1410
+ (conv_out): InflatedConv3d(256, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1411
+ )
1412
+ image_enc:
1413
+ CLIPVisionModelWithProjection(
1414
+ (vision_model): CLIPVisionTransformer(
1415
+ (embeddings): CLIPVisionEmbeddings(
1416
+ (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
1417
+ (position_embedding): Embedding(257, 1024)
1418
+ )
1419
+ (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
1420
+ (encoder): CLIPEncoder(
1421
+ (layers): ModuleList(
1422
+ (0-23): 24 x CLIPEncoderLayer(
1423
+ (self_attn): CLIPAttention(
1424
+ (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
1425
+ (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
1426
+ (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
1427
+ (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
1428
+ )
1429
+ (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
1430
+ (mlp): CLIPMLP(
1431
+ (activation_fn): QuickGELUActivation()
1432
+ (fc1): Linear(in_features=1024, out_features=4096, bias=True)
1433
+ (fc2): Linear(in_features=4096, out_features=1024, bias=True)
1434
+ )
1435
+ (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
1436
+ )
1437
+ )
1438
+ )
1439
+ (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
1440
+ )
1441
+ (visual_projection): Linear(in_features=1024, out_features=768, bias=False)
1442
+ )
1443
+ Pose Guider structure:
1444
+ PoseGuider(
1445
+ (conv_in): InflatedConv3d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1446
+ (blocks): ModuleList(
1447
+ (0): InflatedConv3d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1448
+ (1): InflatedConv3d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1449
+ (2): InflatedConv3d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1450
+ (3): InflatedConv3d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1451
+ (4): InflatedConv3d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1452
+ (5): InflatedConv3d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
1453
+ )
1454
+ (conv_out): InflatedConv3d(256, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1455
+ )
1456
+ pipe:
1457
+ Pose2VideoPipeline {
1458
+ "_class_name": "Pose2VideoPipeline",
1459
+ "_diffusers_version": "0.24.0",
1460
+ "denoising_unet": [
1461
+ "src.models.unet_3d",
1462
+ "UNet3DConditionModel"
1463
+ ],
1464
+ "image_encoder": [
1465
+ "transformers",
1466
+ "CLIPVisionModelWithProjection"
1467
+ ],
1468
+ "image_proj_model": [
1469
+ null,
1470
+ null
1471
+ ],
1472
+ "pose_guider": [
1473
+ "src.models.pose_guider",
1474
+ "PoseGuider"
1475
+ ],
1476
+ "reference_unet": [
1477
+ "src.models.unet_2d_condition",
1478
+ "UNet2DConditionModel"
1479
+ ],
1480
+ "scheduler": [
1481
+ "diffusers",
1482
+ "DDIMScheduler"
1483
+ ],
1484
+ "text_encoder": [
1485
+ null,
1486
+ null
1487
+ ],
1488
+ "tokenizer": [
1489
+ null,
1490
+ null
1491
+ ],
1492
+ "vae": [
1493
+ "diffusers",
1494
+ "AutoencoderKL"
1495
+ ]
1496
+ }
1497
+
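Note: the module trees in this log are simply PyTorch's textual representation of the instantiated networks. Below is a minimal sketch of how such a dump can be regenerated; it assumes this repo's src.models package is importable, that the ckpts/sd-image-variations-diffusers checkpoint referenced elsewhere in this commit is available locally, and that the unet_additional_kwargs mirror those used in train_stage_1.py (the dump above was printed with motion modules loaded at inference time, so the exact tree may differ slightly).

    from src.models.unet_2d_condition import UNet2DConditionModel
    from src.models.unet_3d import UNet3DConditionModel

    # Reference UNet: a 2D SD UNet whose conv_in is widened to 5 input channels.
    reference_unet = UNet2DConditionModel.from_pretrained_2d(
        "ckpts/sd-image-variations-diffusers",
        subfolder="unet",
        unet_additional_kwargs={"in_channels": 5},
    )

    # Denoising UNet: the inflated 3D UNet with a 9-channel conv_in (motion modules disabled here).
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        "ckpts/sd-image-variations-diffusers",
        "",  # motion-module weight path, left empty as in stage-1 training
        subfolder="unet",
        unet_additional_kwargs={
            "in_channels": 9,
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
    )

    # Printing any torch.nn.Module yields the nested tree format shown in this log.
    print("Denoising UNet structure:")
    print(denoising_unet)
    print("Reference UNet structure:")
    print(reference_unet)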
myoutput.log ADDED
@@ -0,0 +1,2 @@
1
+ nohup: ignoring input
2
+ nohup: failed to run command 'CUDA_VISIBLE_DEVICES=2': No such file or directory
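Note: the "failed to run command 'CUDA_VISIBLE_DEVICES=2'" error recorded here (and again in output.log below) comes from nohup itself, not from the training or inference code: nohup treats its first argument as the program to execute, so a bare environment-variable assignment cannot be placed directly after it. A working form of the scripts.sh commands would be, for example, nohup env CUDA_VISIBLE_DEVICES=2 python vivid.py --config <config.yml> &, or exporting CUDA_VISIBLE_DEVICES before invoking nohup.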
nohup.out ADDED
The diff for this file is too large to render. See raw diff
 
output.log ADDED
@@ -0,0 +1,2 @@
1
+ nohup: ignoring input
2
+ nohup: failed to run command 'CUDA_VISIBLE_DEVICES=2': No such file or directory
output/20241207/1929--seed_42-384x512/upper1_00057_00_512x384_3_1929.mp4 ADDED
Binary file (233 kB). View file
 
output/20241207/2241--seed_42-384x512/3_s_1110342_in_xl_512x384_3_2241.mp4 ADDED
Binary file (194 kB). View file
 
output/20241207/2241--seed_42-384x512/7_s_1110342_in_xl_512x384_3_2241.mp4 ADDED
Binary file (196 kB). View file
 
output/20241207/2241--seed_42-384x512/8_s_1009794_in_xl_512x384_3_2241.mp4 ADDED
Binary file (201 kB). View file
 
output/20241207/2241--seed_42-384x512/8_s_1110342_in_xl_512x384_3_2241.mp4 ADDED
Binary file (201 kB). View file
 
read.py ADDED
@@ -0,0 +1,39 @@
1
+ import yaml
2
+ import os
3
+ # The video/image correspondence is assumed to be stored in a plain-text pairs file
4
+ file_pairs_file = "./dataset/ViViD/upper_body/test_pairs.txt"
5
+ output_yaml_path = "./configs/prompts/upper_body2.yaml" # path of the generated YAML file
6
+ videos_dir = "./dataset/ViViD/upper_body/videos"
7
+ images_dir = "./dataset/ViViD/upper_body/images"
8
+ # Prepare the data structure to be written to YAML
9
+ yaml_data = {
10
+ "pretrained_base_model_path": "ckpts/sd-image-variations-diffusers",
11
+ "pretrained_vae_path": "ckpts/sd-vae-ft-mse",
12
+ "image_encoder_path": "ckpts/sd-image-variations-diffusers/image_encoder",
13
+ "denoising_unet_path": "ckpts/ViViD/denoising_unet.pth",
14
+ "reference_unet_path": "ckpts/ViViD/reference_unet.pth",
15
+ "pose_guider_path": "ckpts/ViViD/pose_guider.pth",
16
+ "motion_module_path": "ckpts/MotionModule/mm_sd_v15_v2.ckpt",
17
+ "inference_config": "./configs/inference/inference.yaml",
18
+ "weight_dtype": "fp16",
19
+ "model_video_paths": [],
20
+ "cloth_image_paths": []
21
+ }
22
+
23
+
24
+ # Read the pairs file and populate the YAML data structure
25
+ with open(file_pairs_file, 'r') as file:
26
+ for line in file:
27
+ # Each line is expected to be "<video file name> <corresponding image file name>"
28
+ video_file_name, image_file_name = line.strip().split() # assumed to be whitespace-separated
29
+ # Build the full paths
30
+ video_path = os.path.join(videos_dir, video_file_name) # full video file path
31
+ image_path = os.path.join(images_dir, image_file_name) # full image file path
32
+ yaml_data["model_video_paths"].append(video_path) # append the video file path
33
+ yaml_data["cloth_image_paths"].append(image_path) # append the cloth image path
34
+
35
+ # Write the data to the YAML file
36
+ with open(output_yaml_path, 'w') as yaml_file:
37
+ yaml.dump(yaml_data, yaml_file, default_flow_style=False)
38
+
39
+ print(f"YAML file generated: {output_yaml_path}")
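Note: with default_flow_style=False, yaml.dump writes model_video_paths and cloth_image_paths as block-style lists of the full paths assembled above. PyYAML (5.1+) sorts top-level keys alphabetically by default, so pass sort_keys=False to yaml.dump if the original key order of the template dictionary should be preserved; this is standard PyYAML behaviour rather than anything specific to this script.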
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ accelerate==0.21.0
2
+ av==11.0.0
3
+ clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4
+ decord==0.6.0
5
+ diffusers==0.24.0
6
+ einops==0.4.1
7
+ gradio==3.41.2
8
+ gradio_client==0.5.0
9
+ imageio==2.33.0
10
+ imageio-ffmpeg==0.4.9
11
+ numpy==1.23.5
12
+ omegaconf==2.2.3
13
+ onnxruntime-gpu==1.16.3
14
+ open-clip-torch==2.20.0
15
+ opencv-contrib-python==4.8.1.78
16
+ opencv-python==4.8.1.78
17
+ Pillow==9.5.0
18
+ scikit-image==0.21.0
19
+ scikit-learn==1.3.2
20
+ scipy==1.11.4
21
+ torch==2.0.1
22
+ torchdiffeq==0.2.3
23
+ torchmetrics==1.2.1
24
+ torchsde==0.2.5
25
+ torchvision==0.15.2
26
+ tqdm==4.66.1
27
+ transformers==4.30.2
28
+ mlflow==2.9.2
29
+ xformers==0.0.22
scripts.sh ADDED
@@ -0,0 +1,7 @@
1
+ CUDA_VISIBLE_DEVICES=2 python vivid.py --config /mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/test_lm_build/cloth_complex_dress.yml
2
+
3
+ CUDA_VISIBLE_DEVICES=2 python vivid.py --config /mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/test_lm_build/cloth_complex_low.yml
4
+
5
+ CUDA_VISIBLE_DEVICES=2 python vivid.py --config /mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/test_lm_build/cloth_complex_up.yml
6
+
7
+ CUDA_VISIBLE_DEVICES=2 python vivid.py --config /mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/test_lm_build/complex_motion.yml
stage1_nohup.out ADDED
The diff for this file is too large to render. See raw diff
 
train_stage_1.py ADDED
@@ -0,0 +1,781 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import warnings
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from tempfile import TemporaryDirectory
11
+
12
+ import diffusers
13
+ import mlflow
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint
19
+ import transformers
20
+ from accelerate import Accelerator
21
+ from accelerate.logging import get_logger
22
+ from accelerate.utils import DistributedDataParallelKwargs
23
+ from diffusers import AutoencoderKL, DDIMScheduler
24
+ from diffusers.optimization import get_scheduler
25
+ from diffusers.utils import check_min_version
26
+ from diffusers.utils.import_utils import is_xformers_available
27
+ from omegaconf import OmegaConf
28
+ from PIL import Image
29
+ from tqdm.auto import tqdm
30
+ from transformers import CLIPVisionModelWithProjection
31
+
32
+ from src.dataset.dance_image import HumanDanceDataset
33
+ # from src.dwpose import DWposeDetector
34
+ from src.models.mutual_self_attention import ReferenceAttentionControl
35
+ from src.models.pose_guider import PoseGuider
36
+ from src.models.unet_2d_condition import UNet2DConditionModel
37
+ from src.models.unet_3d import UNet3DConditionModel
38
+ from src.pipelines.pipeline_pose2img import Pose2ImagePipeline
39
+ from src.utils.util import delete_additional_ckpt, import_filename, seed_everything
40
+
41
+ warnings.filterwarnings("ignore")
42
+
43
+ # Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
44
+ check_min_version("0.10.0.dev0")
45
+
46
+ logger = get_logger(__name__, log_level="INFO")
47
+
48
+
49
+ class Net(nn.Module):
50
+ def __init__(
51
+ self,
52
+ reference_unet: UNet2DConditionModel,
53
+ denoising_unet: UNet3DConditionModel,
54
+ pose_guider: PoseGuider,
55
+ reference_control_writer,
56
+ reference_control_reader,
57
+ ):
58
+ super().__init__()
59
+ self.reference_unet = reference_unet
60
+ self.denoising_unet = denoising_unet
61
+ self.pose_guider = pose_guider
62
+ self.reference_control_writer = reference_control_writer
63
+ self.reference_control_reader = reference_control_reader
64
+
65
+ def forward(
66
+ self,
67
+ noisy_latents,
68
+ timesteps,
69
+ ref_image_latents,
70
+ clip_image_embeds,
71
+ pose_img,
72
+ uncond_fwd: bool = False,
73
+ ):
74
+ pose_cond_tensor = pose_img.to(device="cuda")
75
+ pose_fea = self.pose_guider(pose_cond_tensor)
76
+
77
+ if not uncond_fwd:
78
+ ref_timesteps = torch.zeros_like(timesteps)
79
+ self.reference_unet(
80
+ ref_image_latents,
81
+ ref_timesteps,
82
+ encoder_hidden_states=clip_image_embeds,
83
+ return_dict=False,
84
+ )
85
+ self.reference_control_reader.update(self.reference_control_writer)
86
+
87
+ model_pred = self.denoising_unet(
88
+ noisy_latents,
89
+ timesteps,
90
+ pose_cond_fea=pose_fea,
91
+ encoder_hidden_states=clip_image_embeds,
92
+ ).sample
93
+
94
+ return model_pred
95
+
96
+ def log_validation(
97
+ vae,
98
+ image_enc,
99
+ net,
100
+ scheduler,
101
+ accelerator,
102
+ width,
103
+ height,
104
+ save_dir,
105
+ global_step,
106
+ ):
107
+ logger.info("Running validation... ")
108
+
109
+ ori_net = accelerator.unwrap_model(net)
110
+ reference_unet = ori_net.reference_unet
111
+ denoising_unet = ori_net.denoising_unet
112
+ pose_guider = ori_net.pose_guider
113
+
114
+ # generator = torch.manual_seed(42)
115
+ generator = torch.Generator().manual_seed(42)
116
+ # cast unet dtype
117
+ vae = vae.to(dtype=torch.float32)
118
+ image_enc = image_enc.to(dtype=torch.float32)
119
+
120
+ # pose_detector = DWposeDetector()
121
+ # pose_detector.to(accelerator.device)
122
+
123
+ pipe = Pose2ImagePipeline(
124
+ vae=vae,
125
+ image_encoder=image_enc,
126
+ reference_unet=reference_unet,
127
+ denoising_unet=denoising_unet,
128
+ pose_guider=pose_guider,
129
+ scheduler=scheduler,
130
+ )
131
+ pipe = pipe.to(accelerator.device)
132
+ video_image_paths=["/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/valid/videos/803137_in_xl.jpg"]
133
+ cloth_paths=["/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/valid/cloth/803128_in_xl.jpg"]
134
+ pil_images = []
135
+ for video_image_path in video_image_paths:
136
+ clip_length=1
137
+ for cloth_image_path in cloth_paths:
138
+ agnostic_path=video_image_path.replace("videos","agnostic_images") # e.g. data/videos/upper1.mp4 -> data/agnostic_images/upper1.mp4
139
+ agn_mask_path=video_image_path.replace("videos","agnostic_mask_images")
140
+ densepose_path=video_image_path.replace("videos","densepose_images")
141
+ cloth_mask_path=cloth_image_path.replace("cloth","cloth_mask")
142
+
143
+ video_name = video_image_path.split("/")[-1].replace(".jpg", "")
144
+ cloth_name = cloth_image_path.split("/")[-1].replace(".jpg", "")
145
+
146
+ video_image_pil = Image.open(video_image_path).convert("RGB")
147
+ cloth_image_pil = Image.open(cloth_image_path).convert("RGB")
148
+ cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")
149
+ agnostic_pil = Image.open(agnostic_path).convert("RGB")
150
+ agn_mask_pil = Image.open(agn_mask_path).convert("RGB")
151
+ densepose_pil = Image.open(densepose_path).convert("RGB")
152
+
153
+ image = pipe(
154
+ agnostic_pil,
155
+ agn_mask_pil,
156
+ cloth_image_pil,
157
+ cloth_mask_pil,
158
+ densepose_pil,
159
+ width,
160
+ height,
161
+ clip_length,
162
+ 20,
163
+ 3.5,
164
+ generator=generator,
165
+ ).images
166
+ image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (3, 512, 512)
167
+ res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
168
+ # Save ref_image, src_image and the generated_image
169
+ w, h = res_image_pil.size
170
+ canvas = Image.new("RGB", (w * 4, h), "white")
171
+
172
+ cloth_image_pil = cloth_image_pil.resize((w, h))
173
+ video_image_pil = video_image_pil.resize((w, h))
174
+ agnostic_pil = agnostic_pil.resize((w, h))
175
+
176
+
177
+ canvas.paste(cloth_image_pil, (0, 0))
178
+ canvas.paste(video_image_pil, (w, 0))
179
+ canvas.paste(agnostic_pil, (w * 2, 0))
180
+ canvas.paste(res_image_pil, (w * 3, 0))
181
+
182
+ out_file = os.path.join(
183
+ save_dir, f"{global_step:06d}-{video_name}_{cloth_name}.jpg"
184
+ )
185
+ canvas.save(out_file)
186
+
187
+ vae = vae.to(dtype=torch.float32)
188
+ image_enc = image_enc.to(dtype=torch.float32)
189
+
190
+ del pipe
191
+ torch.cuda.empty_cache()
192
+
193
+ return pil_images
194
+
195
+ def compute_snr(noise_scheduler, timesteps):
196
+ """
197
+ Computes SNR as per
198
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
199
+ """
200
+ alphas_cumprod = noise_scheduler.alphas_cumprod
201
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
202
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
203
+
204
+ # Expand the tensors.
205
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
206
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
207
+ timesteps
208
+ ].float()
209
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
210
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
211
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
212
+
213
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
214
+ device=timesteps.device
215
+ )[timesteps].float()
216
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
217
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
218
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
219
+
220
+ # Compute SNR.
221
+ snr = (alpha / sigma) ** 2
222
+ return snr
223
+
224
+
225
+ def main(cfg):
226
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
227
+ accelerator = Accelerator(
228
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
229
+ mixed_precision=cfg.solver.mixed_precision,
230
+ log_with="mlflow",
231
+ project_dir="./mlruns",
232
+ kwargs_handlers=[kwargs],
233
+ )
234
+
235
+ # Make one log on every process with the configuration for debugging.
236
+ logging.basicConfig(
237
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
238
+ datefmt="%m/%d/%Y %H:%M:%S",
239
+ level=logging.INFO,
240
+ )
241
+ logger.info(accelerator.state, main_process_only=False)
242
+ if accelerator.is_local_main_process:
243
+ transformers.utils.logging.set_verbosity_warning()
244
+ diffusers.utils.logging.set_verbosity_info()
245
+ else:
246
+ transformers.utils.logging.set_verbosity_error()
247
+ diffusers.utils.logging.set_verbosity_error()
248
+
249
+ # If passed along, set the training seed now.
250
+ if cfg.seed is not None:
251
+ seed_everything(cfg.seed)
252
+
253
+ exp_name = cfg.exp_name
254
+ save_dir = f"{cfg.output_dir}/{exp_name}"
255
+ if accelerator.is_main_process and not os.path.exists(save_dir):
256
+ os.makedirs(save_dir)
257
+ save_valid_dir = f"{cfg.valid_dir}/{exp_name}"
258
+ if accelerator.is_main_process and not os.path.exists(save_valid_dir):
259
+ os.makedirs(save_valid_dir)
260
+ validation_dir = save_valid_dir
261
+ if cfg.weight_dtype == "fp16":
262
+ weight_dtype = torch.float16
263
+ elif cfg.weight_dtype == "bf16":
264
+ weight_dtype = torch.bfloat16
265
+ elif cfg.weight_dtype == "fp32":
266
+ weight_dtype = torch.float32
267
+ else:
268
+ raise ValueError(
269
+ f"Unsupported weight dtype during training: {cfg.weight_dtype}"
270
+ )
271
+
272
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
273
+ if cfg.enable_zero_snr:
274
+ sched_kwargs.update(
275
+ rescale_betas_zero_snr=True,
276
+ timestep_spacing="trailing",
277
+ prediction_type="v_prediction",
278
+ )
279
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
280
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
281
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
282
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
283
+ "cuda", dtype=weight_dtype
284
+ )
285
+
286
+ reference_unet = UNet2DConditionModel.from_pretrained_2d(
287
+ cfg.base_model_path,
288
+ subfolder="unet",
289
+ unet_additional_kwargs={
290
+ "in_channels": 5,
291
+ }
292
+ ).to(dtype=weight_dtype, device="cuda")
293
+
294
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
295
+ cfg.base_model_path,
296
+ "",
297
+ subfolder="unet",
298
+ unet_additional_kwargs={
299
+ "in_channels": 9,
300
+ "use_motion_module": False,
301
+ "unet_use_temporal_attention": False,
302
+ },
303
+ ).to(device="cuda")
304
+
305
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
306
+ cfg.image_encoder_path,
307
+ ).to(dtype=weight_dtype, device="cuda")
308
+
309
+ if cfg.pose_guider_path:
310
+ pose_guider = PoseGuider(
311
+ conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
312
+ ).to(device="cuda")
313
+ # load pretrained controlnet-openpose params for pose_guider
314
+ controlnet_openpose_state_dict = torch.load(cfg.controlnet_openpose_path)
315
+ state_dict_to_load = {}
316
+ for k in controlnet_openpose_state_dict.keys():
317
+ if k.startswith("controlnet_cond_embedding.") and k.find("conv_out") < 0:
318
+ new_k = k.replace("controlnet_cond_embedding.", "")
319
+ state_dict_to_load[new_k] = controlnet_openpose_state_dict[k]
320
+ miss, _ = pose_guider.load_state_dict(state_dict_to_load, strict=False)
321
+ logger.info(f"Missing key for pose guider: {len(miss)}")
322
+ else:
323
+ pose_guider = PoseGuider(
324
+ conditioning_embedding_channels=320,
325
+ ).to(device="cuda")
326
+
327
+ # load pretrained weights
328
+ denoising_unet.load_state_dict(
329
+ torch.load(cfg.denoising_unet_path, map_location="cpu"),
330
+ strict=True,
331
+ )
332
+ reference_unet.load_state_dict(
333
+ torch.load(cfg.reference_unet_path, map_location="cpu"),
334
+ strict=True,
335
+ )
336
+
337
+ pose_guider.load_state_dict(
338
+ torch.load(cfg.pose_guider_path, map_location="cpu"),
339
+ strict=True,
340
+ )
341
+
342
+
343
+ # Freeze
344
+ vae.requires_grad_(False)
345
+ image_enc.requires_grad_(False)
346
+
347
+ # Explicitly declare the models being trained
348
+ denoising_unet.requires_grad_(True)
349
+ # Some top-layer params of reference_unet don't need grad
350
+ for name, param in reference_unet.named_parameters():
351
+ if "up_blocks.3" in name:
352
+ param.requires_grad_(False)
353
+ else:
354
+ param.requires_grad_(True)
355
+
356
+ pose_guider.requires_grad_(True)
357
+
358
+ reference_control_writer = ReferenceAttentionControl(
359
+ reference_unet,
360
+ do_classifier_free_guidance=False,
361
+ mode="write",
362
+ fusion_blocks="full",
363
+ )
364
+ reference_control_reader = ReferenceAttentionControl(
365
+ denoising_unet,
366
+ do_classifier_free_guidance=False,
367
+ mode="read",
368
+ fusion_blocks="full",
369
+ )
370
+
371
+ net = Net(
372
+ reference_unet,
373
+ denoising_unet,
374
+ pose_guider,
375
+ reference_control_writer,
376
+ reference_control_reader,
377
+ )
378
+
379
+ if cfg.solver.enable_xformers_memory_efficient_attention:
380
+ if is_xformers_available():
381
+ reference_unet.enable_xformers_memory_efficient_attention()
382
+ denoising_unet.enable_xformers_memory_efficient_attention()
383
+ else:
384
+ raise ValueError(
385
+ "xformers is not available. Make sure it is installed correctly"
386
+ )
387
+
388
+ if cfg.solver.gradient_checkpointing:
389
+ reference_unet.enable_gradient_checkpointing()
390
+ denoising_unet.enable_gradient_checkpointing()
391
+
392
+ if cfg.solver.scale_lr:
393
+ learning_rate = (
394
+ cfg.solver.learning_rate
395
+ * cfg.solver.gradient_accumulation_steps
396
+ * cfg.data.train_bs
397
+ * accelerator.num_processes
398
+ )
399
+ else:
400
+ learning_rate = cfg.solver.learning_rate
401
+
402
+
403
+ optimizer_cls = torch.optim.AdamW
404
+
405
+ trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
406
+ optimizer = optimizer_cls(
407
+ trainable_params,
408
+ lr=learning_rate,
409
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
410
+ weight_decay=cfg.solver.adam_weight_decay,
411
+ eps=cfg.solver.adam_epsilon,
412
+ )
413
+
414
+ # Scheduler
415
+ lr_scheduler = get_scheduler(
416
+ cfg.solver.lr_scheduler,
417
+ optimizer=optimizer,
418
+ num_warmup_steps=cfg.solver.lr_warmup_steps
419
+ * cfg.solver.gradient_accumulation_steps,
420
+ num_training_steps=cfg.solver.max_train_steps
421
+ * cfg.solver.gradient_accumulation_steps,
422
+ )
423
+
424
+ train_dataset = HumanDanceDataset(
425
+ img_size=(cfg.data.train_width, cfg.data.train_height),
426
+ img_scale=(0.9, 1.0),
427
+ data_meta_paths=cfg.data.meta_paths,
428
+ sample_margin=cfg.data.sample_margin,
429
+ )
430
+ train_dataloader = torch.utils.data.DataLoader(
431
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
432
+ )
433
+
434
+ # Prepare everything with our `accelerator`.
435
+ (
436
+ net,
437
+ optimizer,
438
+ train_dataloader,
439
+ lr_scheduler,
440
+ ) = accelerator.prepare(
441
+ net,
442
+ optimizer,
443
+ train_dataloader,
444
+ lr_scheduler,
445
+ )
446
+
447
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
448
+ num_update_steps_per_epoch = math.ceil(
449
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
450
+ )
451
+ # Afterwards we recalculate our number of training epochs
452
+ num_train_epochs = math.ceil(
453
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
454
+ )
455
+
456
+ # We need to initialize the trackers we use, and also store our configuration.
457
+ # The trackers initialize automatically on the main process.
458
+ if accelerator.is_main_process:
459
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
460
+ accelerator.init_trackers(
461
+ cfg.exp_name,
462
+ init_kwargs={"mlflow": {"run_name": run_time}},
463
+ )
464
+ # dump config file
465
+ mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
466
+
467
+ # Train!
468
+ total_batch_size = (
469
+ cfg.data.train_bs
470
+ * accelerator.num_processes
471
+ * cfg.solver.gradient_accumulation_steps
472
+ )
473
+
474
+ logger.info("***** Running training *****")
475
+ logger.info(f" Num examples = {len(train_dataset)}")
476
+ logger.info(f" Num Epochs = {num_train_epochs}")
477
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
478
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
479
+ logger.info(f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
480
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
481
+ global_step = 0
482
+ first_epoch = 0
483
+
484
+ # Potentially load in the weights and states from a previous save
485
+ if cfg.resume_from_checkpoint:
486
+ if cfg.resume_from_checkpoint != "latest":
487
+ resume_dir = cfg.resume_from_checkpoint
488
+ else:
489
+ resume_dir = save_dir
490
+ # Get the most recent checkpoint
491
+ dirs = os.listdir(resume_dir)
492
+ print(dirs)
493
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
494
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
495
+ path = dirs[-1]
496
+ accelerator.load_state(os.path.join(resume_dir, path))
497
+ accelerator.print(f"Resuming from checkpoint {path}")
498
+ global_step = int(path.split("-")[1])
499
+
500
+ first_epoch = global_step // num_update_steps_per_epoch
501
+ resume_step = global_step % num_update_steps_per_epoch
502
+
503
+ # Only show the progress bar once on each machine.
504
+ progress_bar = tqdm(
505
+ range(global_step, cfg.solver.max_train_steps),
506
+ disable=not accelerator.is_local_main_process,
507
+ )
508
+ progress_bar.set_description("Steps")
509
+
510
+ for epoch in range(first_epoch, num_train_epochs):
511
+ train_loss = 0.0
512
+ for step, batch in enumerate(train_dataloader):
513
+ # print(batch.keys())
514
+ with accelerator.accumulate(net):
515
+ # Convert videos to latent space
516
+ pixel_values = batch["tgt_img"].to(weight_dtype)
517
+ masked_pixel_values = batch["agnostic_img"].to(weight_dtype)
518
+ mask_of_pixel_values = batch["agnostic_mask_img"].to(weight_dtype)[:,0:1,:,:]
519
+ with torch.no_grad():
520
+ # print(pixel_values.dtype)
521
+ latents = vae.encode(pixel_values).latent_dist.sample()
522
+ latents = latents.unsqueeze(2) # (b, c, 1, h, w)
523
+ latents = latents * 0.18215
524
+
525
+ masked_latents = vae.encode(masked_pixel_values).latent_dist.sample().unsqueeze(2) * 0.18215
526
+ mask_of_latents = torch.nn.functional.interpolate(mask_of_pixel_values.unsqueeze(2), size=(1,mask_of_pixel_values.shape[-2] // 8, mask_of_pixel_values.shape[-1] // 8))
527
+
528
+
529
+ noise = torch.randn_like(latents)
530
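+ # Noise offset (if enabled): add a small per-sample, per-channel constant to the sampled noise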
+ if cfg.noise_offset > 0.0:
531
+ noise += cfg.noise_offset * torch.randn(
532
+ (noise.shape[0], noise.shape[1], 1, 1, 1),
533
+ device=noise.device,
534
+ )
535
+
536
+ bsz = latents.shape[0]
537
+ # Sample a random timestep for each video
538
+ timesteps = torch.randint(
539
+ 0,
540
+ train_noise_scheduler.num_train_timesteps,
541
+ (bsz,),
542
+ device=latents.device,
543
+ )
544
+ timesteps = timesteps.long()
545
+
546
+ tgt_pose_img = batch["tgt_pose"]
547
+ tgt_pose_img = tgt_pose_img.unsqueeze(2) # (bs, 3, 1, 512, 512)
548
+
549
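+ # Randomly drop the image condition with probability cfg.uncond_ratio (classifier-free guidance training)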
+ uncond_fwd = random.random() < cfg.uncond_ratio
550
+ clip_image_list = []
551
+ ref_image_list = []
552
+ cloth_mask_list = []
553
+ for batch_idx, (ref_img, cloth_mask, clip_img) in enumerate(
554
+ zip(
555
+ batch["cloth_img"],
556
+ batch["cloth_mask"],
557
+ batch["clip_images"],
558
+ )
559
+ ):
560
+ if uncond_fwd:
561
+ clip_image_list.append(torch.zeros_like(clip_img))
562
+ else:
563
+ clip_image_list.append(clip_img)
564
+ ref_image_list.append(ref_img)
565
+ cloth_mask_list.append(cloth_mask)
566
+
567
+ with torch.no_grad():
568
+ ref_img = torch.stack(ref_image_list, dim=0).to(
569
+ dtype=vae.dtype, device=vae.device
570
+ )
571
+ ref_image_latents = vae.encode(
572
+ ref_img
573
+ ).latent_dist.sample() # (bs, d, 64, 64)
574
+ ref_image_latents = ref_image_latents * 0.18215
575
+
576
+ cloth_mask = torch.stack(cloth_mask_list, dim=0).to(
577
+ dtype=vae.dtype, device=vae.device
578
+ )
579
+ cloth_mask = cloth_mask[:,0:1,:,:]
580
+ cloth_mask = torch.nn.functional.interpolate(cloth_mask, size=(cloth_mask.shape[-2] // 8, cloth_mask.shape[-1] // 8))
581
+
582
+
583
+ clip_img = torch.stack(clip_image_list, dim=0).to(
584
+ dtype=image_enc.dtype, device=image_enc.device
585
+ )
586
+ clip_image_embeds = image_enc(
587
+ clip_img.to("cuda", dtype=weight_dtype)
588
+ ).image_embeds
589
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1) # (bs, 1, d)
590
+
591
+ # add noise
592
+ noisy_latents = train_noise_scheduler.add_noise(
593
+ latents, noise, timesteps
594
+ )
595
+
596
+ # Get the target for loss depending on the prediction type
597
+ if train_noise_scheduler.prediction_type == "epsilon":
598
+ target = noise
599
+ elif train_noise_scheduler.prediction_type == "v_prediction":
600
+ target = train_noise_scheduler.get_velocity(
601
+ latents, noise, timesteps
602
+ )
603
+ else:
604
+ raise ValueError(
605
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
606
+ )
607
+
608
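+ # Denoising input = noisy latents + agnostic latents + agnostic mask (channel concat); reference input = cloth latents + cloth mask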
+ model_pred = net(
609
+ # noisy_latents,
610
+ torch.cat([noisy_latents,masked_latents,mask_of_latents],dim=1),
611
+ timesteps,
612
+ torch.cat([ref_image_latents, cloth_mask],dim=1),
613
+ image_prompt_embeds,
614
+ tgt_pose_img,
615
+ uncond_fwd,
616
+ )
617
+
618
+ if cfg.snr_gamma == 0:
619
+ loss = F.mse_loss(
620
+ model_pred.float(), target.float(), reduction="mean"
621
+ )
622
+ else:
623
+ snr = compute_snr(train_noise_scheduler, timesteps)
624
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
625
+ # Velocity objective requires that we add one to SNR values before we divide by them.
626
+ snr = snr + 1
627
+ mse_loss_weights = (
628
+ torch.stack(
629
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
630
+ ).min(dim=1)[0]
631
+ / snr
632
+ )
633
+ loss = F.mse_loss(
634
+ model_pred.float(), target.float(), reduction="none"
635
+ )
636
+ loss = (
637
+ loss.mean(dim=list(range(1, len(loss.shape))))
638
+ * mse_loss_weights
639
+ )
640
+ loss = loss.mean()
641
+
642
+ # Gather the losses across all processes for logging (if we use distributed training).
643
+ avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
644
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
645
+
646
+ # Backpropagate
647
+ accelerator.backward(loss)
648
+ if accelerator.sync_gradients:
649
+ accelerator.clip_grad_norm_(
650
+ trainable_params,
651
+ cfg.solver.max_grad_norm,
652
+ )
653
+ optimizer.step()
654
+ lr_scheduler.step()
655
+ optimizer.zero_grad()
656
+
657
+ if accelerator.sync_gradients:
658
+ reference_control_reader.clear()
659
+ reference_control_writer.clear()
660
+ progress_bar.update(1)
661
+ global_step += 1
662
+ accelerator.log({"train_loss": train_loss}, step=global_step)
663
+ train_loss = 0.0
664
+
665
+ if global_step % cfg.checkpointing_steps == 0:
666
+ if accelerator.is_main_process:
667
+ save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
668
+ delete_additional_ckpt(save_dir, 1)
669
+ accelerator.save_state(save_path)
670
+
671
+ if global_step % cfg.val.validation_steps == 0:
672
+ if accelerator.is_main_process:
673
+ generator = torch.Generator(device=accelerator.device)
674
+ generator.manual_seed(cfg.seed)
675
+
676
+ log_validation(
677
+ vae=vae,
678
+ image_enc=image_enc,
679
+ net=net,
680
+ scheduler=val_noise_scheduler,
681
+ accelerator=accelerator,
682
+ width=cfg.data.train_width,
683
+ height=cfg.data.train_height,
684
+ save_dir=validation_dir,
685
+ global_step=global_step,
686
+ )
687
+
688
+ # for sample_id, sample_dict in enumerate(sample_dicts):
689
+ # sample_name = sample_dict["name"]
690
+ # img = sample_dict["img"]
691
+ # with TemporaryDirectory() as temp_dir:
692
+ # out_file = Path(
693
+ # f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
694
+ # )
695
+ # img.save(out_file)
696
+ # mlflow.log_artifact(out_file)
697
+
698
+
699
+ logs = {
700
+ "step_loss": loss.detach().item(),
701
+ "lr": lr_scheduler.get_last_lr()[0],
702
+ }
703
+ progress_bar.set_postfix(**logs)
704
+
705
+ if global_step >= cfg.solver.max_train_steps:
706
+ break
707
+
708
+ # save model after each epoch
709
+ if (
710
+ epoch + 1
711
+ ) % cfg.save_model_epoch_interval == 0 and accelerator.is_main_process:
712
+ unwrap_net = accelerator.unwrap_model(net)
713
+ save_checkpoint(
714
+ unwrap_net.reference_unet,
715
+ save_dir,
716
+ "reference_unet",
717
+ global_step,
718
+ total_limit=3,
719
+ )
720
+ save_checkpoint(
721
+ unwrap_net.denoising_unet,
722
+ save_dir,
723
+ "denoising_unet",
724
+ global_step,
725
+ total_limit=3,
726
+ )
727
+ save_checkpoint(
728
+ unwrap_net.pose_guider,
729
+ save_dir,
730
+ "pose_guider",
731
+ global_step,
732
+ total_limit=3,
733
+ )
734
+
735
+ # Create the pipeline using the trained modules and save it.
736
+ accelerator.wait_for_everyone()
737
+ accelerator.end_training()
738
+
739
+
740
+ def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
741
+ save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
742
+
743
+ if total_limit is not None:
744
+ checkpoints = os.listdir(save_dir)
745
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
746
+ checkpoints = sorted(
747
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
748
+ )
749
+
750
+ if len(checkpoints) >= total_limit:
751
+ num_to_remove = len(checkpoints) - total_limit + 1
752
+ removing_checkpoints = checkpoints[0:num_to_remove]
753
+ logger.info(
754
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
755
+ )
756
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
757
+
758
+ for removing_checkpoint in removing_checkpoints:
759
+ removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
760
+ os.remove(removing_checkpoint)
761
+
762
+ state_dict = model.state_dict()
763
+ torch.save(state_dict, save_path)
764
+
765
+
766
+ if __name__ == "__main__":
767
+ parser = argparse.ArgumentParser()
768
+ parser.add_argument("--config", type=str, default="./configs/training/stage1.yaml")
769
+ args = parser.parse_args()
770
+
771
+ if args.config[-5:] == ".yaml":
772
+ config = OmegaConf.load(args.config)
773
+ elif args.config[-3:] == ".py":
774
+ config = import_filename(args.config).cfg
775
+ else:
776
+ raise ValueError("Unsupported config file format")
777
+ main(config)
778
+
779
+
780
+ # accelerate launch train_stage_1.py --config configs/train/stage1.yaml
781
+ # accelerate launch train_stage_2.py --config configs/train/stage2.yaml
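
For reference, a minimal self-contained sketch of the Min-SNR loss weighting used in the snr_gamma branch above; the SNR values, gamma, and per-sample losses below are toy placeholders, not values from this repo:

import torch

snr_gamma = 5.0                                  # stands in for cfg.snr_gamma
snr = torch.tensor([0.2, 1.0, 12.0])             # toy per-sample SNR, as compute_snr() would return
# (for v_prediction the training code first adds 1 to the SNR)
mse_loss_weights = (
    torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
)
per_sample_mse = torch.tensor([0.9, 0.5, 0.1])   # stands in for the per-sample MSE of the model prediction
loss = (per_sample_mse * mse_loss_weights).mean()
print(mse_loss_weights)                          # tensor([1.0000, 1.0000, 0.4167])

Each sample's weight is min(SNR, gamma) / SNR, so low-SNR (hard) timesteps keep full weight while high-SNR timesteps are capped, which is what the branch above computes batch-wise.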
train_stage_2.py ADDED
@@ -0,0 +1,842 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+ import random
8
+ import time
9
+ import warnings
10
+ from collections import OrderedDict
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from tempfile import TemporaryDirectory
14
+ from src.utils.util import get_fps, read_frames, save_videos_grid
15
+
16
+ import diffusers
17
+ import mlflow
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ import transformers
23
+ from accelerate import Accelerator
24
+ from accelerate.logging import get_logger
25
+ from accelerate.utils import DistributedDataParallelKwargs
26
+ from diffusers import AutoencoderKL, DDIMScheduler
27
+ from diffusers.optimization import get_scheduler
28
+ from diffusers.utils import check_min_version
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+ from einops import rearrange
31
+ from omegaconf import OmegaConf
32
+ from PIL import Image
33
+ from torchvision import transforms
34
+ from tqdm.auto import tqdm
35
+ from transformers import CLIPVisionModelWithProjection
36
+
37
+ from src.dataset.dance_video import HumanDanceVideoDataset
38
+ from src.models.mutual_self_attention import ReferenceAttentionControl
39
+ from src.models.pose_guider import PoseGuider
40
+ from src.models.unet_2d_condition import UNet2DConditionModel
41
+ from src.models.unet_3d import UNet3DConditionModel
42
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
43
+ from src.utils.util import (
44
+ delete_additional_ckpt,
45
+ import_filename,
46
+ read_frames,
47
+ save_videos_grid,
48
+ seed_everything,
49
+ )
50
+
51
+ warnings.filterwarnings("ignore")
52
+
53
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
54
+ check_min_version("0.10.0.dev0")
55
+
56
+ logger = get_logger(__name__, log_level="INFO")
57
+
58
+
59
+ class Net(nn.Module):
60
+ def __init__(
61
+ self,
62
+ reference_unet: UNet2DConditionModel,
63
+ denoising_unet: UNet3DConditionModel,
64
+ pose_guider: PoseGuider,
65
+ reference_control_writer,
66
+ reference_control_reader,
67
+ ):
68
+ super().__init__()
69
+ self.reference_unet = reference_unet
70
+ self.denoising_unet = denoising_unet
71
+ self.pose_guider = pose_guider
72
+ self.reference_control_writer = reference_control_writer
73
+ self.reference_control_reader = reference_control_reader
74
+
75
+ def forward(
76
+ self,
77
+ noisy_latents,
78
+ timesteps,
79
+ ref_image_latents,
80
+ clip_image_embeds,
81
+ pose_img,
82
+ uncond_fwd: bool = False,
83
+ ):
84
+ pose_cond_tensor = pose_img.to(device="cuda")
85
+ pose_fea = self.pose_guider(pose_cond_tensor)
86
+
87
+ if not uncond_fwd:
88
+ ref_timesteps = torch.zeros_like(timesteps)
89
+ self.reference_unet(
90
+ ref_image_latents,
91
+ ref_timesteps,
92
+ encoder_hidden_states=clip_image_embeds,
93
+ return_dict=False,
94
+ )
95
+ self.reference_control_reader.update(self.reference_control_writer)
96
+
97
+ model_pred = self.denoising_unet(
98
+ noisy_latents,
99
+ timesteps,
100
+ pose_cond_fea=pose_fea,
101
+ encoder_hidden_states=clip_image_embeds,
102
+ ).sample
103
+
104
+ return model_pred
105
+
106
+
107
+ def compute_snr(noise_scheduler, timesteps):
108
+ """
109
+ Computes SNR as per
110
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
111
+ """
112
+ alphas_cumprod = noise_scheduler.alphas_cumprod
113
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
114
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
115
+
116
+ # Expand the tensors.
117
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
118
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
119
+ timesteps
120
+ ].float()
121
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
122
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
123
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
124
+
125
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
126
+ device=timesteps.device
127
+ )[timesteps].float()
128
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
129
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
130
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
131
+
132
+ # Compute SNR.
133
+ snr = (alpha / sigma) ** 2
134
+ return snr
135
+
136
+
137
+ def log_validation(
138
+ vae,
139
+ image_enc,
140
+ net,
141
+ scheduler,
142
+ accelerator,
143
+ width,
144
+ height,
145
+ global_step,
146
+ clip_length=24,
147
+ generator=None,
148
+
149
+ ):
150
+ logger.info("Running validation... ")
151
+
152
+ ori_net = accelerator.unwrap_model(net)
153
+ reference_unet = ori_net.reference_unet
154
+ denoising_unet = ori_net.denoising_unet
155
+ pose_guider = ori_net.pose_guider
156
+
157
+ if generator is None:
158
+ generator = torch.manual_seed(42)
159
+ tmp_denoising_unet = copy.deepcopy(denoising_unet)
160
+ tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)
161
+
162
+ pipe = Pose2VideoPipeline(
163
+ vae=vae,
164
+ image_encoder=image_enc,
165
+ reference_unet=reference_unet,
166
+ denoising_unet=tmp_denoising_unet,
167
+ pose_guider=pose_guider,
168
+ scheduler=scheduler,
169
+ )
170
+ pipe = pipe.to(accelerator.device)
171
+ date_str = datetime.now().strftime("%Y%m%d")
172
+ time_str = datetime.now().strftime("%H%M")
173
+ save_dir_name = f"{time_str}"
174
+ save_dir = Path(f"vividfuxian_motion/{date_str}/{save_dir_name}")
175
+ save_dir.mkdir(exist_ok=True, parents=True)
176
+
177
+ model_video_paths = ["/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/videos/803128_detail.mp4"]
178
+ cloth_image_paths=["/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/images/1060638_in_xl.jpg"]
179
+ transform = transforms.Compose(
180
+ [transforms.Resize((height, width)), transforms.ToTensor()]
181
+ )
182
+ for model_image_path in model_video_paths:
183
+ src_fps = get_fps(model_image_path)
184
+
185
+ model_name = Path(model_image_path).stem
186
+ agnostic_path=model_image_path.replace("videos","agnostic")
187
+ agn_mask_path=model_image_path.replace("videos","agnostic_mask")
188
+ densepose_path=model_image_path.replace("videos","densepose")
189
+
190
+ video_tensor_list=[]
191
+ video_images=read_frames(model_image_path)
192
+
193
+ for vid_image_pil in video_images[:clip_length]:
194
+ video_tensor_list.append(transform(vid_image_pil))
195
+
196
+ video_tensor = torch.stack(video_tensor_list, dim=0) # (f, c, h, w)
197
+ video_tensor = video_tensor.transpose(0, 1)
198
+
199
+ agnostic_list=[]
200
+ agnostic_images=read_frames(agnostic_path)
201
+ for agnostic_image_pil in agnostic_images[:clip_length]:
202
+ agnostic_list.append(agnostic_image_pil)
203
+
204
+ agn_mask_list=[]
205
+ agn_mask_images=read_frames(agn_mask_path)
206
+ for agn_mask_image_pil in agn_mask_images[:clip_length]:
207
+ agn_mask_list.append(agn_mask_image_pil)
208
+
209
+ pose_list=[]
210
+ pose_images=read_frames(densepose_path)
211
+ for pose_image_pil in pose_images[:clip_length]:
212
+ pose_list.append(pose_image_pil)
213
+
214
+ video_tensor = video_tensor.unsqueeze(0)
215
+
216
+
217
+ for cloth_image_path in cloth_image_paths:
218
+ cloth_name = Path(cloth_image_path).stem
219
+ cloth_image_pil = Image.open(cloth_image_path).convert("RGB")
220
+
221
+ cloth_mask_path=cloth_image_path.replace("cloth","cloth_mask")
222
+ cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")
223
+
224
+ pipeline_output = pipe(
225
+ agnostic_list,
226
+ agn_mask_list,
227
+ cloth_image_pil,
228
+ cloth_mask_pil,
229
+ pose_list,
230
+ width,
231
+ height,
232
+ clip_length,
233
+ 20,
234
+ 3.5,
235
+ generator=generator,
236
+ )
237
+ video = pipeline_output.videos
238
+
239
+ video = torch.cat([video_tensor,video], dim=0)
240
+ save_videos_grid(
241
+ video,
242
+ f"{save_dir}/{global_step:06d}-{model_name}_{cloth_name}.mp4",
243
+ n_rows=2,
244
+ fps=src_fps,
245
+ )
246
+
247
+ del tmp_denoising_unet
248
+ del pipe
249
+ torch.cuda.empty_cache()
250
+
251
+ return video
252
+
253
+
254
+ def main(cfg):
255
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
256
+ accelerator = Accelerator(
257
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
258
+ mixed_precision=cfg.solver.mixed_precision,
259
+ log_with="mlflow",
260
+ project_dir="./mlruns",
261
+ kwargs_handlers=[kwargs],
262
+ )
263
+
264
+ # Make one log on every process with the configuration for debugging.
265
+ logging.basicConfig(
266
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
267
+ datefmt="%m/%d/%Y %H:%M:%S",
268
+ level=logging.INFO,
269
+ )
270
+ logger.info(accelerator.state, main_process_only=False)
271
+ if accelerator.is_local_main_process:
272
+ transformers.utils.logging.set_verbosity_warning()
273
+ diffusers.utils.logging.set_verbosity_info()
274
+ else:
275
+ transformers.utils.logging.set_verbosity_error()
276
+ diffusers.utils.logging.set_verbosity_error()
277
+
278
+ # If passed along, set the training seed now.
279
+ if cfg.seed is not None:
280
+ seed_everything(cfg.seed)
281
+
282
+ exp_name = cfg.exp_name
283
+ save_dir = f"{cfg.output_dir}/{exp_name}"
284
+ if accelerator.is_main_process:
285
+ if not os.path.exists(save_dir):
286
+ os.makedirs(save_dir)
287
+
288
+ # inference_config_path = "./configs/inference/inference_v2.yaml"
289
+ inference_config_path = "./configs/inference/inference.yaml"
290
+ infer_config = OmegaConf.load(inference_config_path)
291
+
292
+ if cfg.weight_dtype == "fp16":
293
+ weight_dtype = torch.float16
294
+ elif cfg.weight_dtype == "bf16":
295
+ weight_dtype = torch.bfloat16
296
+ elif cfg.weight_dtype == "fp32":
297
+ weight_dtype = torch.float32
298
+ else:
299
+ raise ValueError(
300
+ f"Do not support weight dtype: {cfg.weight_dtype} during training"
301
+ )
302
+
303
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
304
+ if cfg.enable_zero_snr:
305
+ sched_kwargs.update(
306
+ rescale_betas_zero_snr=True,
307
+ timestep_spacing="trailing",
308
+ prediction_type="v_prediction",
309
+ )
310
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
311
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
312
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
313
+
314
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
315
+ cfg.image_encoder_path,
316
+ ).to(dtype=weight_dtype, device="cuda")
317
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
318
+ "cuda", dtype=weight_dtype
319
+ )
320
+ reference_unet = UNet2DConditionModel.from_pretrained_2d(
321
+ cfg.base_model_path,
322
+ subfolder="unet",
323
+ unet_additional_kwargs={
324
+ "in_channels": 5,
325
+ }
326
+ ).to(device="cuda", dtype=weight_dtype)
327
+
328
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
329
+ cfg.base_model_path,
330
+ cfg.mm_path,
331
+ subfolder="unet",
332
+ unet_additional_kwargs=OmegaConf.to_container(
333
+ infer_config.unet_additional_kwargs
334
+ ),
335
+ ).to(device="cuda")
336
+
337
+ pose_guider = PoseGuider(
338
+ conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
339
+ ).to(device="cuda", dtype=weight_dtype)
340
+
341
+ stage1_ckpt_dir = cfg.stage1_ckpt_dir
342
+ stage1_ckpt_step = cfg.stage1_ckpt_step
343
+ denoising_unet.load_state_dict(
344
+ torch.load(
345
+ os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
346
+ map_location="cpu",
347
+ ),
348
+ strict=False,
349
+ )
350
+
351
+ reference_unet.load_state_dict(
352
+ torch.load(
353
+ os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
354
+ map_location="cpu",
355
+ ),
356
+ strict=False,
357
+ )
358
+ pose_guider.load_state_dict(
359
+ torch.load(
360
+ os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
361
+ map_location="cpu",
362
+ ),
363
+ strict=False,
364
+ )
365
+
366
+
367
+
368
+ # Freeze
369
+ vae.requires_grad_(False)
370
+ image_enc.requires_grad_(False)
371
+ reference_unet.requires_grad_(False)
372
+ denoising_unet.requires_grad_(False)
373
+ pose_guider.requires_grad_(False)
374
+
375
+ # Set motion module learnable
376
+ for name, module in denoising_unet.named_modules():
377
+ if "motion_modules" in name:
378
+ for params in module.parameters():
379
+ params.requires_grad = True
380
+
381
+ reference_control_writer = ReferenceAttentionControl(
382
+ reference_unet,
383
+ do_classifier_free_guidance=False,
384
+ mode="write",
385
+ fusion_blocks="full",
386
+ )
387
+ reference_control_reader = ReferenceAttentionControl(
388
+ denoising_unet,
389
+ do_classifier_free_guidance=False,
390
+ mode="read",
391
+ fusion_blocks="full",
392
+ )
393
+
394
+ net = Net(
395
+ reference_unet,
396
+ denoising_unet,
397
+ pose_guider,
398
+ reference_control_writer,
399
+ reference_control_reader,
400
+ )
401
+
402
+ if cfg.solver.enable_xformers_memory_efficient_attention:
403
+ if is_xformers_available():
404
+ reference_unet.enable_xformers_memory_efficient_attention()
405
+ denoising_unet.enable_xformers_memory_efficient_attention()
406
+ else:
407
+ raise ValueError(
408
+ "xformers is not available. Make sure it is installed correctly"
409
+ )
410
+
411
+ if cfg.solver.gradient_checkpointing:
412
+ reference_unet.enable_gradient_checkpointing()
413
+ denoising_unet.enable_gradient_checkpointing()
414
+
415
+ if cfg.solver.scale_lr:
416
+ learning_rate = (
417
+ cfg.solver.learning_rate
418
+ * cfg.solver.gradient_accumulation_steps
419
+ * cfg.data.train_bs
420
+ * accelerator.num_processes
421
+ )
422
+ else:
423
+ learning_rate = cfg.solver.learning_rate
424
+
425
+ # Initialize the optimizer
426
+ if cfg.solver.use_8bit_adam:
427
+ try:
428
+ import bitsandbytes as bnb
429
+ except ImportError:
430
+ raise ImportError(
431
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
432
+ )
433
+
434
+ optimizer_cls = bnb.optim.AdamW8bit
435
+ else:
436
+ optimizer_cls = torch.optim.AdamW
437
+
438
+ trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
439
+ logger.info(f"Total trainable params {len(trainable_params)}")
440
+ optimizer = optimizer_cls(
441
+ trainable_params,
442
+ lr=learning_rate,
443
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
444
+ weight_decay=cfg.solver.adam_weight_decay,
445
+ eps=cfg.solver.adam_epsilon,
446
+ )
447
+
448
+ # Scheduler
449
+ lr_scheduler = get_scheduler(
450
+ cfg.solver.lr_scheduler,
451
+ optimizer=optimizer,
452
+ num_warmup_steps=cfg.solver.lr_warmup_steps
453
+ * cfg.solver.gradient_accumulation_steps,
454
+ num_training_steps=cfg.solver.max_train_steps
455
+ * cfg.solver.gradient_accumulation_steps,
456
+ )
457
+
458
+ train_dataset = HumanDanceVideoDataset(
459
+ width=cfg.data.train_width,
460
+ height=cfg.data.train_height,
461
+ n_sample_frames=cfg.data.n_sample_frames,
462
+ sample_rate=cfg.data.sample_rate,
463
+ img_scale=(1.0, 1.0),
464
+ data_meta_paths=cfg.data.meta_paths,
465
+ )
466
+ train_dataloader = torch.utils.data.DataLoader(
467
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
468
+ )
469
+
470
+ # Prepare everything with our `accelerator`.
471
+ (
472
+ net,
473
+ optimizer,
474
+ train_dataloader,
475
+ lr_scheduler,
476
+ ) = accelerator.prepare(
477
+ net,
478
+ optimizer,
479
+ train_dataloader,
480
+ lr_scheduler,
481
+ )
482
+
483
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
484
+ num_update_steps_per_epoch = math.ceil(
485
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
486
+ )
487
+ # Afterwards we recalculate our number of training epochs
488
+ num_train_epochs = math.ceil(
489
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
490
+ )
491
+
492
+ # We need to initialize the trackers we use, and also store our configuration.
493
+ # The trackers initialize automatically on the main process.
494
+ if accelerator.is_main_process:
495
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
496
+ accelerator.init_trackers(
497
+ exp_name,
498
+ init_kwargs={"mlflow": {"run_name": run_time}},
499
+ )
500
+ # dump config file
501
+ mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
502
+
503
+ # Train!
504
+ total_batch_size = (
505
+ cfg.data.train_bs
506
+ * accelerator.num_processes
507
+ * cfg.solver.gradient_accumulation_steps
508
+ )
509
+
510
+ logger.info("***** Running training *****")
511
+ logger.info(f" Num examples = {len(train_dataset)}")
512
+ logger.info(f" Num Epochs = {num_train_epochs}")
513
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
514
+ logger.info(
515
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
516
+ )
517
+ logger.info(
518
+ f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
519
+ )
520
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
521
+ global_step = 0
522
+ first_epoch = 0
523
+
524
+ # Potentially load in the weights and states from a previous save
525
+ if cfg.resume_from_checkpoint:
526
+ if cfg.resume_from_checkpoint != "latest":
527
+ resume_dir = cfg.resume_from_checkpoint
528
+ else:
529
+ resume_dir = save_dir
530
+ # Get the most recent checkpoint
531
+ dirs = os.listdir(resume_dir)
532
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
533
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
534
+ path = dirs[-1]
535
+ accelerator.load_state(os.path.join(resume_dir, path))
536
+ accelerator.print(f"Resuming from checkpoint {path}")
537
+ global_step = int(path.split("-")[1])
538
+
539
+ first_epoch = global_step // num_update_steps_per_epoch
540
+ resume_step = global_step % num_update_steps_per_epoch
541
+
542
+ # Only show the progress bar once on each machine.
543
+ progress_bar = tqdm(
544
+ range(global_step, cfg.solver.max_train_steps),
545
+ disable=not accelerator.is_local_main_process,
546
+ )
547
+ progress_bar.set_description("Steps")
548
+
549
+ for epoch in range(first_epoch, num_train_epochs):
550
+ train_loss = 0.0
551
+ t_data_start = time.time()
552
+ for step, batch in enumerate(train_dataloader):
553
+ t_data = time.time() - t_data_start
554
+ with accelerator.accumulate(net):
555
+ # Convert videos to latent space
556
+ pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
557
+ masked_pixel_values = batch["pixel_values_vid_agnostic"].to(weight_dtype)
558
+ # mask_of_pixel_values = batch["pixel_values_vid_agnostic_mask"].to(weight_dtype)
559
+ mask_of_pixel_values = batch["pixel_values_vid_agnostic_mask"].to(weight_dtype)[:,:,0:1,:,:]
560
+ mask_of_pixel_values = mask_of_pixel_values.transpose(1, 2)  # b f c h w -> b c f h w
561
+ with torch.no_grad():
562
+ video_length = pixel_values_vid.shape[1]
563
+
564
+ pixel_values_vid = rearrange(
565
+ pixel_values_vid, "b f c h w -> (b f) c h w"
566
+ )
567
+ latents = vae.encode(pixel_values_vid).latent_dist.sample()
568
+ latents = rearrange(
569
+ latents, "(b f) c h w -> b c f h w", f=video_length
570
+ )
571
+ latents = latents * 0.18215
572
+
573
+ masked_pixel_values = rearrange(
574
+ masked_pixel_values, "b f c h w -> (b f) c h w"
575
+ )
576
+ masked_latents = vae.encode(masked_pixel_values).latent_dist.sample()
577
+ masked_latents = rearrange(
578
+ masked_latents, "(b f) c h w -> b c f h w", f=video_length
579
+ )
580
+ masked_latents = masked_latents * 0.18215
581
+ mask_of_latents = torch.nn.functional.interpolate(mask_of_pixel_values, size=(video_length, mask_of_pixel_values.shape[-2] // 8, mask_of_pixel_values.shape[-1] // 8))  # keep the temporal size equal to video_length instead of hard-coding 24 frames
582
+
583
+
584
+ noise = torch.randn_like(latents)
585
+ if cfg.noise_offset > 0:
586
+ noise += cfg.noise_offset * torch.randn(
587
+ (latents.shape[0], latents.shape[1], 1, 1, 1),
588
+ device=latents.device,
589
+ )
590
+ bsz = latents.shape[0]
591
+ # Sample a random timestep for each video
592
+ timesteps = torch.randint(
593
+ 0,
594
+ train_noise_scheduler.num_train_timesteps,
595
+ (bsz,),
596
+ device=latents.device,
597
+ )
598
+ timesteps = timesteps.long()
599
+
600
+ pixel_values_pose = batch["pixel_values_pose"] # (bs, f, c, H, W)
601
+ pixel_values_pose = pixel_values_pose.transpose(
602
+ 1, 2
603
+ ) # (bs, c, f, H, W)
604
+
605
+ uncond_fwd = random.random() < cfg.uncond_ratio
606
+ clip_image_list = []
607
+ ref_image_list = []
608
+ cloth_mask_list = []
609
+ for batch_idx, (ref_img, cloth_mask, clip_img) in enumerate(
610
+ zip(
611
+ batch["pixel_cloth"],
612
+ batch["pixel_cloth_mask"],
613
+ batch["clip_ref_img"],
614
+ )
615
+ ):
616
+ if uncond_fwd:
617
+ clip_image_list.append(torch.zeros_like(clip_img))
618
+ else:
619
+ clip_image_list.append(clip_img)
620
+ ref_image_list.append(ref_img)
621
+ cloth_mask_list.append(cloth_mask)
622
+
623
+ with torch.no_grad():
624
+ ref_img = torch.stack(ref_image_list, dim=0).to(
625
+ dtype=vae.dtype, device=vae.device
626
+ )
627
+ ref_image_latents = vae.encode(
628
+ ref_img
629
+ ).latent_dist.sample() # (bs, d, 64, 64)
630
+ ref_image_latents = ref_image_latents * 0.18215
631
+
632
+ cloth_mask = torch.stack(cloth_mask_list, dim=0).to(
633
+ dtype=vae.dtype, device=vae.device
634
+ )
635
+ cloth_mask = cloth_mask[:,0:1,:,:]
636
+ cloth_mask = torch.nn.functional.interpolate(cloth_mask, size=(cloth_mask.shape[-2] // 8, cloth_mask.shape[-1] // 8))
637
+
638
+
639
+ clip_img = torch.stack(clip_image_list, dim=0).to(
640
+ dtype=image_enc.dtype, device=image_enc.device
641
+ )
642
+ clip_img = clip_img.to(device="cuda", dtype=weight_dtype)
643
+ clip_image_embeds = image_enc(
644
+ clip_img.to("cuda", dtype=weight_dtype)
645
+ ).image_embeds
646
+ clip_image_embeds = clip_image_embeds.unsqueeze(1) # (bs, 1, d)
647
+
648
+ # add noise
649
+ noisy_latents = train_noise_scheduler.add_noise(
650
+ latents, noise, timesteps
651
+ )
652
+
653
+ # Get the target for loss depending on the prediction type
654
+ if train_noise_scheduler.prediction_type == "epsilon":
655
+ target = noise
656
+ elif train_noise_scheduler.prediction_type == "v_prediction":
657
+ target = train_noise_scheduler.get_velocity(
658
+ latents, noise, timesteps
659
+ )
660
+ else:
661
+ raise ValueError(
662
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
663
+ )
664
+ # ---- Forward!!! -----
665
+ model_pred = net(
666
+ # noisy_latents,
667
+ torch.cat([noisy_latents,masked_latents,mask_of_latents],dim=1),
668
+ timesteps,
669
+ # ref_image_latents,
670
+ torch.cat([ref_image_latents, cloth_mask],dim=1),
671
+ clip_image_embeds,
672
+ pixel_values_pose,
673
+ uncond_fwd=uncond_fwd,
674
+ )
675
+
676
+ if cfg.snr_gamma == 0:
677
+ loss = F.mse_loss(
678
+ model_pred.float(), target.float(), reduction="mean"
679
+ )
680
+ else:
681
+ snr = compute_snr(train_noise_scheduler, timesteps)
682
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
683
+ # Velocity objective requires that we add one to SNR values before we divide by them.
684
+ snr = snr + 1
685
+ mse_loss_weights = (
686
+ torch.stack(
687
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
688
+ ).min(dim=1)[0]
689
+ / snr
690
+ )
691
+ loss = F.mse_loss(
692
+ model_pred.float(), target.float(), reduction="none"
693
+ )
694
+ loss = (
695
+ loss.mean(dim=list(range(1, len(loss.shape))))
696
+ * mse_loss_weights
697
+ )
698
+ loss = loss.mean()
699
+
700
+ # Gather the losses across all processes for logging (if we use distributed training).
701
+ avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
702
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
703
+
704
+ # Backpropagate
705
+ accelerator.backward(loss)
706
+ if accelerator.sync_gradients:
707
+ accelerator.clip_grad_norm_(
708
+ trainable_params,
709
+ cfg.solver.max_grad_norm,
710
+ )
711
+ optimizer.step()
712
+ lr_scheduler.step()
713
+ optimizer.zero_grad()
714
+
715
+ if accelerator.sync_gradients:
716
+ reference_control_reader.clear()
717
+ reference_control_writer.clear()
718
+ progress_bar.update(1)
719
+ global_step += 1
720
+ accelerator.log({"train_loss": train_loss}, step=global_step)
721
+ train_loss = 0.0
722
+
723
+ if global_step % cfg.val.validation_steps == 0:
724
+ if accelerator.is_main_process:
725
+ generator = torch.Generator(device=accelerator.device)
726
+ generator.manual_seed(cfg.seed)
727
+
728
+ log_validation(
729
+ vae=vae,
730
+ image_enc=image_enc,
731
+ net=net,
732
+ scheduler=val_noise_scheduler,
733
+ accelerator=accelerator,
734
+ width=cfg.data.train_width,
735
+ height=cfg.data.train_height,
736
+ global_step=global_step,
737
+ clip_length=cfg.data.n_sample_frames,
738
+ generator=generator,
739
+
740
+ )
741
+
742
+ # for sample_id, sample_dict in enumerate(sample_dicts):
743
+ # sample_name = sample_dict["name"]
744
+ # vid = sample_dict["vid"]
745
+ # with TemporaryDirectory() as temp_dir:
746
+ # out_file = Path(
747
+ # f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
748
+ # )
749
+ # save_videos_grid(vid, out_file, n_rows=2)
750
+ # mlflow.log_artifact(out_file)
751
+
752
+
753
+ logs = {
754
+ "step_loss": loss.detach().item(),
755
+ "lr": lr_scheduler.get_last_lr()[0],
756
+ "td": f"{t_data:.2f}s",
757
+ }
758
+ t_data_start = time.time()
759
+ progress_bar.set_postfix(**logs)
760
+
761
+ if global_step >= cfg.solver.max_train_steps:
762
+ break
763
+
764
+ # save model after each epoch
765
+ if accelerator.is_main_process:
766
+ save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
767
+ delete_additional_ckpt(save_dir, 1)
768
+ # accelerator.save_state(save_path)
769
+ # save motion module only
770
+ unwrap_net = accelerator.unwrap_model(net)
771
+ save_checkpoint(
772
+ unwrap_net.denoising_unet,
773
+ save_dir,
774
+ "motion_module",
775
+ global_step,
776
+ total_limit=3,
777
+ )
778
+
779
+ # Create the pipeline using the trained modules and save it.
780
+ accelerator.wait_for_everyone()
781
+ accelerator.end_training()
782
+
783
+
784
+ def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
785
+ save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
786
+
787
+ if total_limit is not None:
788
+ checkpoints = os.listdir(save_dir)
789
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
790
+ checkpoints = sorted(
791
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
792
+ )
793
+
794
+ if len(checkpoints) >= total_limit:
795
+ num_to_remove = len(checkpoints) - total_limit + 1
796
+ removing_checkpoints = checkpoints[0:num_to_remove]
797
+ logger.info(
798
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
799
+ )
800
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
801
+
802
+ for removing_checkpoint in removing_checkpoints:
803
+ removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
804
+ os.remove(removing_checkpoint)
805
+
806
+ mm_state_dict = OrderedDict()
807
+ state_dict = model.state_dict()
808
+ for key in state_dict:
809
+ if "motion_module" in key:
810
+ mm_state_dict[key] = state_dict[key]
811
+
812
+ torch.save(mm_state_dict, save_path)
813
+
814
+
815
+ def decode_latents(vae, latents):
816
+ video_length = latents.shape[2]
817
+ latents = 1 / 0.18215 * latents
818
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
819
+ # video = self.vae.decode(latents).sample
820
+ video = []
821
+ for frame_idx in tqdm(range(latents.shape[0])):
822
+ video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
823
+ video = torch.cat(video)
824
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
825
+ video = (video / 2 + 0.5).clamp(0, 1)
826
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
827
+ video = video.cpu().float().numpy()
828
+ return video
829
+
830
+
831
+ if __name__ == "__main__":
832
+ parser = argparse.ArgumentParser()
833
+ parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
834
+ args = parser.parse_args()
835
+
836
+ if args.config[-5:] == ".yaml":
837
+ config = OmegaConf.load(args.config)
838
+ elif args.config[-3:] == ".py":
839
+ config = import_filename(args.config).cfg
840
+ else:
841
+ raise ValueError("Do not support this format config file")
842
+ main(config)
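
For reference, a quick shape sketch of what the two UNets receive in this stage, assuming the usual 4-channel Stable Diffusion VAE latents (all sizes below are placeholders). This is why reference_unet is built with "in_channels": 5, while the denoising UNet's in_channels is taken from infer_config.unet_additional_kwargs and is presumably 9:

import torch

b, f, h8, w8 = 1, 24, 64, 48                       # batch, frames, latent height/width (placeholders)
noisy_latents   = torch.randn(b, 4, f, h8, w8)     # noised video latents
masked_latents  = torch.randn(b, 4, f, h8, w8)     # agnostic (garment-removed) video latents
mask_of_latents = torch.randn(b, 1, f, h8, w8)     # down-sampled agnostic mask
denoising_input = torch.cat([noisy_latents, masked_latents, mask_of_latents], dim=1)
print(denoising_input.shape)                       # torch.Size([1, 9, 24, 64, 48])

ref_image_latents = torch.randn(b, 4, h8, w8)      # cloth image latents
cloth_mask        = torch.randn(b, 1, h8, w8)      # down-sampled cloth mask
reference_input   = torch.cat([ref_image_latents, cloth_mask], dim=1)
print(reference_input.shape)                       # torch.Size([1, 5, 64, 48])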
vivid.py ADDED
@@ -0,0 +1,229 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import sys
5
+ import torch
6
+ import os
7
+ from diffusers import AutoencoderKL, DDIMScheduler
8
+ from omegaconf import OmegaConf
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from transformers import CLIPVisionModelWithProjection
12
+
13
+ from src.models.pose_guider import PoseGuider
14
+ from src.models.unet_2d_condition import UNet2DConditionModel
15
+ from src.models.unet_3d import UNet3DConditionModel
16
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
17
+ from src.utils.util import get_fps, read_frames, save_videos_grid
18
+
19
+ def parse_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--config",type=str,default="/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/valid.yaml")
22
+ parser.add_argument("-W", type=int, default=384)
23
+ parser.add_argument("-H", type=int, default=512)
24
+ parser.add_argument("-L", type=int, default=24)
25
+
26
+ parser.add_argument("--seed", type=int, default=42)
27
+ parser.add_argument("--cfg", type=float, default=3.5)
28
+ parser.add_argument("--steps", type=int, default=20)
29
+ parser.add_argument("--fps", type=int)
30
+ args = parser.parse_args()
31
+
32
+ return args
33
+
34
+
35
+ def main():
36
+ args = parse_args()
37
+
38
+ config = OmegaConf.load(args.config)
39
+
40
+ if config.weight_dtype == "fp16":
41
+ weight_dtype = torch.float16
42
+ else:
43
+ weight_dtype = torch.float32
44
+
45
+ vae = AutoencoderKL.from_pretrained(
46
+ config.pretrained_vae_path,
47
+ ).to("cuda", dtype=weight_dtype)
48
+
49
+ reference_unet = UNet2DConditionModel.from_pretrained_2d(
50
+ config.pretrained_base_model_path,
51
+ subfolder="unet",
52
+ unet_additional_kwargs={
53
+ "in_channels": 5,
54
+ }
55
+ ).to(dtype=weight_dtype, device="cuda")
56
+
57
+ inference_config_path = config.inference_config #'/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/inference/inference.yaml'
58
+ infer_config = OmegaConf.load(inference_config_path)
59
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
60
+ config.pretrained_base_model_path,
61
+ config.motion_module_path,
62
+ subfolder="unet",
63
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
64
+ ).to(dtype=weight_dtype, device="cuda")
65
+
66
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
67
+ dtype=weight_dtype, device="cuda"
68
+ )
69
+
70
+
71
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
72
+ config.image_encoder_path
73
+ ).to(dtype=weight_dtype, device="cuda")
74
+
75
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
76
+ scheduler = DDIMScheduler(**sched_kwargs)
77
+
78
+ seed = config.get("seed",args.seed)
79
+ generator = torch.manual_seed(seed)
80
+
81
+ width, height = args.W, args.H
82
+ clip_length = config.get("L",args.L)
83
+ steps = args.steps
84
+ guidance_scale = args.cfg
85
+
86
+ # load pretrained weights
87
+ denoising_unet.load_state_dict(
88
+ torch.load(config.denoising_unet_path, map_location="cpu"),
89
+ strict=False,
90
+ )
91
+
92
+ reference_unet.load_state_dict(
93
+ torch.load(config.reference_unet_path, map_location="cpu"),
94
+ )
95
+
96
+
97
+ pose_guider.load_state_dict(
98
+ torch.load(config.pose_guider_path, map_location="cpu"),
99
+ )
100
+
101
+
102
+
103
+ pipe = Pose2VideoPipeline(
104
+ vae=vae,
105
+ image_encoder=image_enc,
106
+ reference_unet=reference_unet,
107
+ denoising_unet=denoising_unet,
108
+ pose_guider=pose_guider,
109
+ scheduler=scheduler,
110
+ )
111
+ # Set the log file path
+ # log_file_path = "model_structures.log"
+ # with open(log_file_path, 'w') as log_file:
+ # # Redirect stdout to the log file
+ # orig_stdout = sys.stdout # save the original stdout
+ # sys.stdout = log_file # redirect stdout to the log file
+
+ # # Print the model structures
+ # print("Denoising UNet structure:")
+ # print(denoising_unet) # print the structure of denoising_unet
+
+ # print("Reference UNet structure:")
+ # print(reference_unet) # print the structure of reference_unet
+
+ # print("Pose Guider structure:")
+ # print(pose_guider) # print the structure of pose_guider
+
+ # print("image_enc:")
+ # print(image_enc)
+
+ # print("Pose Guider structure:")
+ # print(pose_guider)
+
+ # print("pipe:")
+ # print(pipe)
+ # # Restore stdout
+ # sys.stdout = orig_stdout # restore the original stdout
+ # print(f"The model structures have been saved to {log_file_path}.")
139
+ pipe = pipe.to("cuda", dtype=weight_dtype)
140
+
141
+ date_str = datetime.now().strftime("%Y%m%d")
142
+ time_str = datetime.now().strftime("%H%M")
143
+ save_dir_name = f"{time_str}--seed_{seed}-{args.W}x{args.H}"
144
+
145
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
146
+ save_dir.mkdir(exist_ok=True, parents=True)
147
+
148
+ model_video_paths = config.model_video_paths
149
+ cloth_image_paths = config.cloth_image_paths
150
+
151
+ transform = transforms.Compose(
152
+ [transforms.Resize((height, width)), transforms.ToTensor()]
153
+ )
154
+
155
+
156
+ for model_image_path in model_video_paths:
157
+ # print("model_image_path", model_image_path)
158
+ src_fps = get_fps(model_image_path)
159
+
160
+ model_name = Path(model_image_path).stem
161
+ agnostic_path=model_image_path.replace("videos","agnostic") # data/videos/upper1.mp4 -> data/agnostic/upper1.mp4
162
+ agn_mask_path=model_image_path.replace("videos","agnostic_mask")
163
+ densepose_path=model_image_path.replace("videos","densepose")
164
+
165
+ video_tensor_list=[]
166
+ video_images=read_frames(model_image_path)
167
+
168
+ clip_length = len(video_images) # set clip_length to the total number of frames of the input video
169
+ # clip_length=48
170
+ for vid_image_pil in video_images[:clip_length]: #clip_length=24
171
+ video_tensor_list.append(transform(vid_image_pil))
172
+
173
+ video_tensor = torch.stack(video_tensor_list, dim=0) # (f, c, h, w)
174
+ video_tensor = video_tensor.transpose(0, 1)
175
+
176
+
177
+ agnostic_list=[]
178
+ agnostic_images=read_frames(agnostic_path)
179
+ for agnostic_image_pil in agnostic_images[:clip_length]:
180
+ agnostic_list.append(agnostic_image_pil)
181
+
182
+ agn_mask_list=[]
183
+ agn_mask_images=read_frames(agn_mask_path)
184
+ # print(" agn_mask_images", agn_mask_images)
185
+ for agn_mask_image_pil in agn_mask_images[:clip_length]:
186
+ agn_mask_list.append(agn_mask_image_pil)
187
+
188
+ pose_list=[]
189
+ pose_images=read_frames(densepose_path)
190
+ for pose_image_pil in pose_images[:clip_length]:
191
+ pose_list.append(pose_image_pil)
192
+
193
+ video_tensor = video_tensor.unsqueeze(0)
194
+
195
+
196
+ for cloth_image_path in cloth_image_paths:
197
+ cloth_name = Path(cloth_image_path).stem
198
+ cloth_image_pil = Image.open(cloth_image_path).convert("RGB")
199
+
200
+ cloth_mask_path=cloth_image_path.replace("cloth","cloth_mask")
201
+ cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")
202
+
203
+ pipeline_output = pipe(
204
+ agnostic_list,
205
+ agn_mask_list,
206
+ cloth_image_pil,
207
+ cloth_mask_pil,
208
+ pose_list,
209
+ width,
210
+ height,
211
+ clip_length,
212
+ steps,
213
+ guidance_scale,
214
+ generator=generator,
215
+ )
216
+ # print("pipeline_output", pipeline_output)
217
+ video = pipeline_output.videos
218
+
219
+ video = torch.cat([video_tensor,video], dim=0)
220
+ save_videos_grid(
221
+ video,
222
+ f"{save_dir}/{model_name}_{cloth_name}_{args.H}x{args.W}_{int(guidance_scale)}_{time_str}.mp4",
223
+ n_rows=2,
224
+ fps=src_fps if args.fps is None else args.fps,
225
+ )
226
+
227
+
228
+ if __name__ == "__main__":
229
+ main()
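
A small sketch of the directory convention that vivid.py (and the stage-2 validation) relies on: each model video is expected to have sibling agnostic, agnostic_mask and densepose renders under folders that differ only in one path component, and each cloth image a matching cloth_mask. The helper and the example path below are illustrative only:

def sibling_inputs(video_path: str) -> dict:
    # Derive the companion inputs for one model video, mirroring the .replace() calls above.
    return {
        "agnostic": video_path.replace("videos", "agnostic"),
        "agnostic_mask": video_path.replace("videos", "agnostic_mask"),
        "densepose": video_path.replace("videos", "densepose"),
    }

print(sibling_inputs("data/videos/upper1.mp4"))
# {'agnostic': 'data/agnostic/upper1.mp4',
#  'agnostic_mask': 'data/agnostic_mask/upper1.mp4',
#  'densepose': 'data/densepose/upper1.mp4'}

Cloth masks follow the same pattern via cloth_image_path.replace("cloth", "cloth_mask"), so cloth images are expected to live under a path containing "cloth".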
vividfuxian_motion/20241211/1715/803128_detail_1060638_in_xl.mp4 ADDED
Binary file (95.1 kB). View file
 
vividfuxian_motion/20241212/1437/000004-803128_detail_1060638_in_xl.mp4 ADDED
Binary file (94 kB). View file
 
vividfuxian_motion/20241212/1506/000200-803128_detail_1060638_in_xl.mp4 ADDED
Binary file (98.1 kB). View file
 
vividfuxian_motion/20241212/1629/000600-803128_detail_1060638_in_xl.mp4 ADDED
Binary file (97.8 kB). View file
 
vividfuxian_valid/stage1/000010-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/000200-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/000400-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/000600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/000800-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/001000-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/001200-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/001600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/001800-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/002000-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/002200-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/002400-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/002600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/002800-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/003000-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/003400-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/003600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/003800-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/004200-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/004400-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/004600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/004800-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/005200-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/005400-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/005600-803137_in_xl_812294_in_xl.jpg ADDED
vividfuxian_valid/stage1/005800-803137_in_xl_812294_in_xl.jpg ADDED