yinbq committed
Commit 0341b51 · verified · 1 Parent(s): b35d13a

Add files using upload-large-folder tool

.gitignore ADDED
@@ -0,0 +1,16 @@
1
+ wandb
2
+ __pycache__
3
+ .vscode
4
+ notebooks
5
+ results
6
+ *.ipynb_checkpoints
7
+ eval_results
8
+ tests
9
+ .DS_Store
10
+ gradio.sh
11
+ models
12
+ bagel_example
13
+ Zebra-CoT
14
+ model_bf16.safetensors
15
+ zebra-cot.tar.gz
16
+ reasoning_output*
EVAL.md ADDED
@@ -0,0 +1,387 @@
1
+ # VLM
2
+ We follow [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html) to evaluate the performance on MME, MMBench, MMMU, MMVet, MathVista and MMVP.
3
+
4
+ ## Data preparation
5
+ Please follow [InternVL2](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) to prepare the corresponding data, then link the data under `vlm`.
6
+
7
+ The final directory structure is:
8
+ ```shell
9
+ data
10
+ ├── MathVista
11
+ ├── mmbench
12
+ ├── mme
13
+ ├── MMMU
14
+ ├── mm-vet
15
+ └── MMVP
16
+ ```
17
+
18
+ ## Evaluation
19
+
20
+ Directly run `scripts/eval/run_eval_vlm.sh` to evaluate the different benchmarks; the output will be saved in `$output_path`. A sketch of a typical invocation is shown after the list below.
21
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
22
+ - Increase `GPUS` if you want to run faster.
23
+ - For MMBench, please use the official [evaluation server](https://mmbench.opencompass.org.cn/mmbench-submission).
24
+ - For MMVet, please use the official [evaluation server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator).
25
+ - For MathVista, please set `$openai_api_key` in `scripts/eval/run_eval_vlm.sh` and `your_api_url` in `eval/vlm/eval/mathvista/utilities.py`. The default GPT version is `gpt-4o-2024-11-20`.
26
+ - For MMMU, we use CoT in the report, which improves the accuracy by about 2%. For evaluation of open-ended answers, we use GPT-4o as the judge.
27
+
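+ As a rough sketch, the variables referenced in the list above might be filled in as follows before launching the script. This is only an illustration: whether they are exported in the shell or edited directly inside `scripts/eval/run_eval_vlm.sh` depends on how the script reads them.
+
+ ```shell
+ # Placeholder values; adjust to your environment.
+ export model_path=/path/to/checkpoint
+ export output_path=/path/to/eval_logs
+ export GPUS=8                    # more GPUs -> faster evaluation
+ export openai_api_key=sk-xxxx    # only needed for the GPT-judged benchmarks
+
+ bash scripts/eval/run_eval_vlm.sh
+ ```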
28
+
29
+ # GenEval
30
+ We modify the code in [GenEval](https://github.com/djghosh13/geneval/tree/main) for faster evaluation.
31
+
32
+ ## Setup
33
+ Install the following dependencies:
34
+ ```shell
35
+ pip install open-clip-torch
36
+ pip install clip-benchmark
37
+ pip install --upgrade setuptools
38
+
39
+ sudo pip install -U openmim
40
+ sudo mim install mmengine mmcv-full==1.7.2
41
+
42
+ git clone https://github.com/open-mmlab/mmdetection.git
43
+ cd mmdetection; git checkout 2.x
44
+ pip install -v -e .
45
+ ```
46
+
47
+ Download Detector:
48
+ ```shell
49
+ cd ./eval/gen/geneval
50
+ mkdir model
51
+
52
+ bash ./evaluation/download_models.sh ./model
53
+ ```
54
+
55
+ ## Evaluation
56
+ Directly run `scripts/eval/run_geneval.sh` to evaluate GenEval. The output will be saved in `$output_path`.
57
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
58
+ - Set `metadata_file` to `./eval/gen/geneval/prompts/evaluation_metadata.jsonl` to use the original GenEval prompts.
59
+
60
+
61
+ # WISE
62
+ We modify the code in [WISE](https://github.com/PKU-YuanGroup/WISE/tree/main) for faster evaluation.
63
+
64
+
65
+ ## Evaluation
66
+ Directly run `scripts/eval/run_wise.sh` to evaluate WISE. The output will be saved in `$output_path`.
67
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
68
+ - Set `$openai_api_key` in `scripts/eval/run_wise.sh` and `your_api_url` in `eval/gen/wise/gpt_eval_mp.py`. The default GPT version is `gpt-4o-2024-11-20`.
69
+ - Use `think` for thinking mode.
70
+
71
+
72
+
73
+ # GEdit-Bench
74
+ We adopt the code in [GEdit-Bench](https://github.com/stepfun-ai/Step1X-Edit/blob/main/GEdit-Bench/EVAL.md) for evaluation.
75
+
76
+ ## Evaluation
77
+
78
+ Modify the model path, the output path, the API key, and the API URL in `scripts/eval/run_gedit.sh`. Then run the following command:
79
+ ```shell
80
+ bash scripts/eval/run_gedit.sh
81
+ ```
82
+ The GPT version for evaluation is `gpt-4.1-2025-04-14`.
83
+
84
+
85
+ # IntelligentBench
86
+ TBD
87
+
88
+
89
+ # KRIS
90
+ We modify the code in [KRIS-Bench](https://github.com/mercurystraw/Kris_Bench) for faster evaluation.
91
+
92
+ ## Data preparation
93
+ Please download the benchmark data from [KRIS-Bench](https://huggingface.co/datasets/Liang0223/KRIS_Bench) and place it in the `KRIS_Bench` directory.
94
+
95
+ The final directory structure is:
96
+ ```shell
97
+ KRIS_Bench
98
+ ├── abstract_reasoning
99
+ ├── anomaly_correction
100
+ ├── biology
101
+ ├── chemistry
102
+ ├── color_change
103
+ ├── count_change
104
+ ├── geography
105
+ ├── humanities
106
+ ├── mathematics
107
+ ├── medicine
108
+ ├── multi-element_composition
109
+ ├── multi-instruction_execution
110
+ ├── part_completion
111
+ ├── physics
112
+ ├── position_movement
113
+ ├── practical_knowledge
114
+ ├── rule-based_reasoning
115
+ ├── size_adjustment
116
+ ├── temporal_prediction
117
+ └── viewpoint_change
118
+ ```
119
+
120
+ ## Evaluation
121
+ Directly run `scripts/eval/run_kris.sh` to evaluate KRIS-Bench. The output will be saved in `$output_path`.
122
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
123
+ - Set `$openai_api_key` in `scripts/eval/run_kris.sh` and `your_api_url` in `eval/gen/kris/metrics_xx.py`. The default GPT version is `gpt-4o-2024-11-20`.
124
+ - Use `think` for thinking mode.
125
+ - We set `cfg_text_scale=4` and `cfg_img_scale=1.5` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
126
+
127
+ <details>
128
+ <summary><b>Results</b></summary>
129
+ <pre>
130
+ Category, meta-category, and overall average scores (100-point scale):
131
+ Attribute Perception:
132
+ VC: 76.64
133
+ VQ: 74.45
134
+ IF: 41.73
135
+ AVG: 64.27
136
+ Spatial Perception:
137
+ VC: 70.25
138
+ VQ: 80.00
139
+ IF: 37.00
140
+ AVG: 62.42
141
+ Temporal Prediction:
142
+ VC: 36.49
143
+ VQ: 61.82
144
+ IF: 29.05
145
+ AVG: 42.45
146
+ Social Science:
147
+ VC: 76.20
148
+ VQ: 78.80
149
+ IF: 37.00
150
+ KP: 29.60
151
+ AVG: 55.40
152
+ Natural Science:
153
+ VC: 69.59
154
+ VQ: 84.03
155
+ IF: 40.27
156
+ KP: 30.15
157
+ AVG: 56.01
158
+ Logical Reasoning:
159
+ VC: 80.17
160
+ VQ: 85.67
161
+ IF: 26.33
162
+ KP: 18.00
163
+ AVG: 52.54
164
+ Instruction Decomposition:
165
+ VC: 40.17
166
+ VQ: 69.50
167
+ IF: 42.00
168
+ AVG: 50.56
169
+ Factual Knowledge:
170
+ AVG: 60.26
171
+ Conceptual Knowledge:
172
+ AVG: 55.86
173
+ Procedural Knowledge:
174
+ AVG: 51.69
175
+ Overall:
176
+ AVG: 56.21
177
+ </pre>
178
+ </details>
179
+
180
+ <details>
181
+ <summary><b>Results w/ CoT</b></summary>
182
+ <pre>
183
+ Category, meta-category, and overall average scores (100-point scale):
184
+ Attribute Perception:
185
+ VC: 75.09
186
+ VQ: 74.00
187
+ IF: 53.18
188
+ AVG: 67.42
189
+ Spatial Perception:
190
+ VC: 78.75
191
+ VQ: 87.25
192
+ IF: 39.00
193
+ AVG: 68.33
194
+ Temporal Prediction:
195
+ VC: 48.31
196
+ VQ: 81.08
197
+ IF: 46.62
198
+ AVG: 58.67
199
+ Social Science:
200
+ VC: 80.40
201
+ VQ: 79.40
202
+ IF: 51.60
203
+ KP: 42.80
204
+ AVG: 63.55
205
+ Natural Science:
206
+ VC: 67.68
207
+ VQ: 82.95
208
+ IF: 52.10
209
+ KP: 42.88
210
+ AVG: 61.40
211
+ Logical Reasoning:
212
+ VC: 62.83
213
+ VQ: 79.67
214
+ IF: 28.33
215
+ KP: 21.67
216
+ AVG: 48.12
217
+ Instruction Decomposition:
218
+ VC: 47.83
219
+ VQ: 66.83
220
+ IF: 36.00
221
+ AVG: 50.22
222
+ Factual Knowledge:
223
+ AVG: 66.18
224
+ Conceptual Knowledge:
225
+ AVG: 61.92
226
+ Procedural Knowledge:
227
+ AVG: 49.02
228
+ Overall:
229
+ AVG: 60.18
230
+ </pre>
231
+ </details>
232
+
233
+
234
+ # RISE
235
+ We modify the code in [RISEBench](https://github.com/PhoenixZ810/RISEBench) for faster evaluation.
236
+
237
+ ## Data preparation
238
+ Please download the benchmark data from [RISEBench](https://huggingface.co/datasets/PhoenixZ/RISEBench) and place it in the `data` directory.
239
+
240
+ The final directory structure is:
241
+ ```shell
242
+ data
243
+ ├── datav2_total_w_subtask.json
244
+ ├── causal_reasoning_images
245
+ ├── logical_reasoning_images
246
+ ├── spatial_reasoning_images
247
+ └── temporal_reasoning_images
248
+ ```
249
+
250
+ ## Evaluation
251
+ Directly run `scripts/eval/run_rise.sh` to evaluate RISEBench. The output will be saved in `$output_path`.
252
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
253
+ - Set `$openai_api_key` in `scripts/eval/run_rise.sh` and `your_api_url` in `eval/gen/rise/gpt_eval.py`. The default GPT version is `gpt-4.1-2025-04-14`.
254
+ - Use `think` for thinking mode.
255
+ - We set `cfg_text_scale=4` and `cfg_img_scale=2.0` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
256
+
257
+ <details>
258
+ <summary><b>Results (cfg_img_scale=1.5)</b></summary>
259
+ <pre>
260
+ - Score-Origin Score-Percentage Accuracy
261
+ 0 Overall 2.537778 38.444444 0.061111
262
+ 1 Temporal 2.654118 41.352941 0.023529
263
+ 2 Causal 2.788889 44.722222 0.055556
264
+ 3 Spatial 3.452000 61.300000 0.140000
265
+ 4 Logical 1.080000 2.000000 0.011765
266
+ 5 Overall_Reasoning 2.458333 36.458333 NaN
267
+ 6 Overall_ApprConsistency 3.141643 53.541076 NaN
268
+ 7 Overall_VisualPlausibility_total 3.920000 73.000000 NaN
269
+ 8 Temporal_Reasoning 2.588235 39.705882 NaN
270
+ 9 Temporal_Consistency 3.250000 56.250000 NaN
271
+ 10 Temporal_Quality 3.505882 62.647059 NaN
272
+ 11 Causal_Reasoning 2.733333 43.333333 NaN
273
+ 12 Causal_Consistency 3.579545 64.488636 NaN
274
+ 13 Causal_Quality 3.688889 67.222222 NaN
275
+ 14 Spatial_Reasoning 3.300000 57.500000 NaN
276
+ 15 Spatial_Consistency 3.330000 58.250000 NaN
277
+ 16 Spatial_Quality 4.480000 87.000000 NaN
278
+ 17 Logical_Reasoning 1.047059 1.176471 NaN
279
+ 18 Logical_Consistency 2.364706 34.117647 NaN
280
+ 19 Temp-Life Progression 2.757895 43.947368 0.000000
281
+ 20 Temp-Material Progression 2.500000 37.500000 0.021739
282
+ 21 Temp-Environmental Cycles 3.061538 51.538462 0.076923
283
+ 22 Temp-Societal Transformation 2.628571 40.714286 0.000000
284
+ 23 Causal-Structural Deformation 2.766667 44.166667 0.055556
285
+ 24 Causal-State Transition 3.112000 52.800000 0.080000
286
+ 25 Causal-Chemical and Biological Transformation 2.325000 33.125000 0.062500
287
+ 26 Causal-Physics Manifestation 2.800000 45.000000 0.000000
288
+ 27 Spa-Component Assembly 3.434783 60.869565 0.043478
289
+ 28 Spa-Object Arrangement 2.733333 43.333333 0.000000
290
+ 29 Spa-Viewpoint Generation 3.629630 65.740741 0.222222
291
+ 30 Spa-Structural Inference 4.066667 76.666667 0.133333
292
+ 31 Spa-Layout Reasoning 3.234783 55.869565 0.217391
293
+ 32 Logic-Pattern Prediction 1.035484 0.887097 0.000000
294
+ 33 Logic-Mathematical Derivation 1.350000 8.750000 0.071429
295
+ 34 Logic-Puzzle Solving 1.020000 0.500000 0.000000
296
+ </pre>
297
+ </details>
298
+
299
+ <details>
300
+ <summary><b>Results w/ CoT</b></summary>
301
+ <pre>
302
+ - Score-Origin Score-Percentage Accuracy
303
+ 0 Overall 2.933333 48.333333 0.119444
304
+ 1 Temporal 3.336471 58.411765 0.058824
305
+ 2 Causal 3.608889 65.222222 0.177778
306
+ 3 Spatial 3.492000 62.300000 0.210000
307
+ 4 Logical 1.157647 3.941176 0.011765
308
+ 5 Overall_Reasoning 2.836111 45.902778 NaN
309
+ 6 Overall_ApprConsistency 3.951841 73.796034 NaN
310
+ 7 Overall_VisualPlausibility_total 4.203636 80.090909 NaN
311
+ 8 Temporal_Reasoning 3.188235 54.705882 NaN
312
+ 9 Temporal_Consistency 4.225000 80.625000 NaN
313
+ 10 Temporal_Quality 4.200000 80.000000 NaN
314
+ 11 Causal_Reasoning 3.533333 63.333333 NaN
315
+ 12 Causal_Consistency 4.386364 84.659091 NaN
316
+ 13 Causal_Quality 4.100000 77.500000 NaN
317
+ 14 Spatial_Reasoning 3.350000 58.750000 NaN
318
+ 15 Spatial_Consistency 4.300000 82.500000 NaN
319
+ 16 Spatial_Quality 4.300000 82.500000 NaN
320
+ 17 Logical_Reasoning 1.141176 3.529412 NaN
321
+ 18 Logical_Consistency 2.835294 45.882353 NaN
322
+ 19 Temp-Life Progression 3.526316 63.157895 0.052632
323
+ 20 Temp-Material Progression 3.208696 55.217391 0.086957
324
+ 21 Temp-Environmental Cycles 3.584615 64.615385 0.000000
325
+ 22 Temp-Societal Transformation 3.200000 55.000000 0.000000
326
+ 23 Causal-Structural Deformation 3.750000 68.750000 0.138889
327
+ 24 Causal-State Transition 3.792000 69.800000 0.320000
328
+ 25 Causal-Chemical and Biological Transformation 3.512500 62.812500 0.062500
329
+ 26 Causal-Physics Manifestation 2.984615 49.615385 0.153846
330
+ 27 Spa-Component Assembly 3.652174 66.304348 0.304348
331
+ 28 Spa-Object Arrangement 2.700000 42.500000 0.000000
332
+ 29 Spa-Viewpoint Generation 3.800000 70.000000 0.259259
333
+ 30 Spa-Structural Inference 3.680000 67.000000 0.266667
334
+ 31 Spa-Layout Reasoning 3.260870 56.521739 0.130435
335
+ 32 Logic-Pattern Prediction 1.064516 1.612903 0.000000
336
+ 33 Logic-Mathematical Derivation 1.707143 17.678571 0.071429
337
+ 34 Logic-Puzzle Solving 1.037500 0.937500 0.000000
338
+ </pre>
339
+ </details>
340
+
341
+
342
+ # ImgEdit
343
+ We modify the code in [ImgEdit](https://github.com/PKU-YuanGroup/ImgEdit) for faster evaluation.
344
+
345
+ ## Data preparation
346
+ Please download the benchmark data from [ImgEdit-Bench](https://huggingface.co/datasets/sysuyy/ImgEdit/blob/main/Benchmark.tar) and place it in the `Benchmark` directory.
347
+
348
+ The final directory structure is:
349
+ ```shell
350
+ Benchmark
351
+ ├── hard
352
+ ├── multiturn
353
+ └── singleturn
354
+ ├── judge_prompt.json
355
+ ├── singleturn.json
356
+ ├── animal
357
+ ├── architecture
358
+ ├── clothes
359
+ ├── compose
360
+ ├── daily object
361
+ ├── for_add
362
+ ├── human
363
+ ├── style
364
+ └── transport
365
+ ```
366
+
367
+ ## Evaluation
368
+ Directly run `scripts/eval/run_imgedit.sh` to evaluate ImgEdit-Bench. The output will be saved in `$output_path`.
369
+ - Set `$model_path` and `$output_path` to the checkpoint and log paths.
370
+ - Set `$openai_api_key` in `scripts/eval/run_imgedit.sh` and `your_api_url` in `eval/gen/imgedit/basic_bench.py`. The default GPT version is `gpt-4o-2024-11-20`.
371
+ - We set `cfg_text_scale=4` and `cfg_img_scale=1.5` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
372
+
373
+ <details>
374
+ <summary><b>Results</b></summary>
375
+ <pre>
376
+ background: 3.28
377
+ adjust: 3.23
378
+ style: 4.26
379
+ extract: 1.48
380
+ remove: 2.99
381
+ add: 3.45
382
+ replace: 3.76
383
+ compose: 3.18
384
+ action: 4.38
385
+ overall: 3.28
386
+ </pre>
387
+ </details>
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,139 @@
1
+ # Zebra-CoT: A Dataset for Interleaved Vision-Language Reasoning
2
+
3
+ ![Image](assets/zebra_cot_datacard.png)
4
+ ### Training BAGEL on Zebra-CoT
5
+
6
+ This repository is adapted from the [Bagel](https://github.com/ByteDance-Seed/Bagel) repository.
7
+ ### Setup
8
+
9
+ ```bash
10
+ git clone https://github.com/multimodal-reasoning-lab/Bagel-Zebra-CoT.git
11
+ cd Bagel-Zebra-CoT
12
+ conda create -n bagel python=3.10 -y
13
+ conda activate bagel
14
+ pip install -r requirements.txt
15
+ pip install flash_attn --no-build-isolation
16
+ ```
17
+
18
+ ### Download checkpoint
19
+
20
+ Set `HF_HOME` in `download_model.py` to the directory where the checkpoint should be downloaded.
21
+
22
+ ```bash
23
+ python download_model.py
24
+ ```
25
+
26
+ You can also do this directly from Python if your `HF_HOME` is already set.
27
+ ```python
28
+ from huggingface_hub import snapshot_download
29
+
30
+ snapshot_download(
31
+ repo_id="multimodal-reasoning-lab/Bagel-Zebra-CoT",
32
+ local_dir_use_symlinks=False,
33
+ resume_download=True,
34
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
35
+ )
36
+ ```
37
+
38
+ ### Inference
39
+
40
+ ![Image](assets/bagel-cot-example.png)
41
+
42
+ The inference script (`infz_bf16.py`) natively supports interleaved text and visual reasoning. To customize it for your
43
+ specific use case:
44
+
45
+ ##### 1. Model Checkpoint Path
46
+
47
+ Update the checkpoint path to point to your model:
48
+
49
+ ```python
50
+ checkpoint_dir = "/path/to/your/HF_HOME/models/Bagel-Zebra-CoT"
51
+ ```
52
+
53
+ For example, under the `HF_HOME`, the path to the checkpoint folder is:
54
+
55
+ ```python
56
+ checkpoint_dir = f"{HF_HOME}/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/c1ff3c56dd5909841523e3a6b554c77d919c2b28"
57
+ ```
58
+
59
+ You can also use the local dir:
60
+
61
+ ```python
62
+ checkpoint_dir = f"{HF_HOME}/models/Bagel-Zebra-CoT"
63
+ ```
64
+
65
+ ##### 2. Setting up prompt and images
66
+
67
+ Edit the prompt and image variables in `infz_bf16.py` (around lines 203-211):
68
+
69
+ **For single image problems:**
70
+ ```python
71
+ prompt = "Your question here"
72
+ image = Image.open('path/to/your/image.png')
73
+ ```
74
+
75
+ **For multiple image problems:**
76
+ ```python
77
+ prompt = "Your question about multiple images"
78
+ image_1 = Image.open('path/to/image1.jpg')
79
+ image_2 = Image.open('path/to/image2.jpg')
80
+ image_3 = Image.open('path/to/image3.jpg')
81
+ image = [image_1, image_2, image_3] # List of images
82
+ ```
83
+
84
+ **For text-only problems:**
85
+ ```python
86
+ prompt = "Your text-only question"
87
+ image = None
88
+ ```
89
+
90
+ ##### 3. Inference Parameters
91
+
92
+ You can adjust the generation parameters in the `inference_hyper` dictionary:
93
+
94
+ ```python
95
+ inference_hyper = dict(
96
+ do_sample=True,
97
+ text_temperature=0.3,
98
+ cfg_text_scale=4.0,
99
+ cfg_img_scale=2.0,
100
+ cfg_interval=[0.0, 1.0],
101
+ timestep_shift=3.0,
102
+ num_timesteps=50,
103
+ cfg_renorm_min=0.0,
104
+ cfg_renorm_type="text_channel",
105
+ )
106
+ ```
107
+
108
+ For details, refer to the original Jupyter notebook [here](inference.ipynb).
109
+
110
+ #### Example Use Cases
111
+
112
+ ```python
113
+ prompt = "Subtract all cylinders. Add 1 red sphere. How many objects are left?"
114
+ image = Image.open('test_images/image.png')
115
+ ```
116
+
117
+ ### Training
118
+ For training, run
119
+
120
+ ```bash
121
+ bash scripts/train.sh
122
+ ```
123
+
124
+ For details, please refer to the original repo [README](https://github.com/bytedance-seed/BAGEL).
125
+
126
+ The interleaved reasoning data customized for Zebra-CoT can be found in [think_trace_dataset.py](data/interleave_datasets/think_trace_dataset.py).
127
+
128
+ ### Cite
129
+ ```bibtex
130
+ @misc{li2025zebracot,
131
+ title={Zebra-CoT: A Dataset for Interleaved Vision Language Reasoning},
132
+ author={Ang Li and Charles Wang and Kaiyu Yue and Zikui Cai and Ollie Liu and Deqing Fu and Peng Guo and Wang Bill Zhu and Vatsal Sharan and Robin Jia and Willie Neiswanger and Furong Huang and Tom Goldstein and Micah Goldblum},
133
+ year={2025},
134
+ eprint={2507.16746},
135
+ archivePrefix={arXiv},
136
+ primaryClass={cs.CV},
137
+ url={https://arxiv.org/abs/2507.16746},
138
+ }
139
+ ```
TRAIN.md ADDED
@@ -0,0 +1,168 @@
1
+ # Data preparation
2
+
3
+ We provide data examples for **T2I**, **Editing**, and **VLM** tasks. The T2I dataset is generated using [FLUX.1‑dev](https://huggingface.co/black-forest-labs/FLUX.1-dev); the editing examples are randomly sampled from [SEED‑Data‑Edit‑Part3](https://huggingface.co/datasets/AILab-CVC/SEED-Data-Edit-Part2-3); and the VLM set is sourced from [LLaVA‑OneVision‑Data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data).
4
+
5
+ We offer examples in both raw-image folder and parquet shard formats. For other data formats, you can use our dataset code as a template and extend it as needed.
6
+
7
+
8
+ 1. **Download the sample dataset**
9
+
10
+ ```bash
11
+ wget -O bagel_example.zip \
12
+ https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/bagel_example.zip
13
+ unzip bagel_example.zip -d /data
14
+ ```
15
+ 2. **Expected hierarchy**
16
+
17
+ ```text
18
+ bagel_example
19
+ ├── t2i/ # text-to-image (parquet)
20
+ ├── editing/ # image editing (parquet)
21
+ │ ├── seedxedit_multi/
22
+ │ └── parquet_info/
23
+ └── vlm/
24
+ ├── images/ # JPEG / PNG frames
25
+ └── llava_ov_si.jsonl # vision‑language SFT conversations
26
+ ```
27
+ 3. Edit every `your_data_path` placeholder in **`data/dataset_info.py`**.
28
+ 4. *(Optional)* Extend `DATASET_INFO` with your own parquet shards or JSONL files to mix extra data.
29
+
30
+ ---
31
+
32
+ # Training
33
+
34
+ The baseline training recipe looks like this (replace environment variables with real paths or values):
35
+
36
+ ```shell
37
+ # Pre-training
38
+ torchrun \
39
+ --nnodes=$num_nodes \
40
+ --node_rank=$node_rank \
41
+ --nproc_per_node=8 \
42
+ --master_addr=$master_addr \
43
+ --master_port=$master_port \
44
+ train/pretrain_unified_navit.py \
45
+ --dataset_config_file ./data/configs/example.yaml \
46
+ --llm_path $llm_path \
47
+ --vae_path $vae_path \
48
+ --vit_path $vit_path \
49
+ --layer_module Qwen2MoTDecoderLayer \
50
+ --use_flex True \
51
+ --resume_from $resume_from \
52
+ --results_dir $output_path \
53
+ --checkpoint_dir $ckpt_path \
54
+ --max_latent_size 64 # 32 for low-resolution pre-training
55
+
56
+ # Fine-tuning
57
+ torchrun \
58
+ --nnodes=$num_nodes \
59
+ --node_rank=$node_rank \
60
+ --nproc_per_node=8 \
61
+ --master_addr=$master_addr \
62
+ --master_port=$master_port \
63
+ train/pretrain_unified_navit.py \
64
+ --dataset_config_file ./data/configs/example.yaml \
65
+ --model_path $model_path \
66
+ --layer_module Qwen2MoTDecoderLayer \
67
+ --max_latent_size 64 \
68
+ --resume-from $model_path \
69
+ --finetune_from_hf True \
70
+ --auto_resume True \
71
+ --resume-model-only True \
72
+ --finetune-from-ema True \
73
+ --log_every 1 \
74
+ --lr 2e-5 \
75
+ --num_worker 1 \
76
+ --expected_num_tokens 10240 \
77
+ --max_num_tokens 11520 \
78
+ --max_num_tokens_per_sample 10240
79
+ ```
80
+
81
+ - **When fine-tuning BAGEL, set `max_latent_size=64` to ensure the correct pretrained weights are loaded.** If this is not set, an out-of-bounds error may occur.
82
+ - The total value of `num_used_data` should be greater than `NUM_GPUS × NUM_WORKERS`. (For toy data, use `num_worker=1`.)
83
+ - For T2I-only fine-tuning, set `visual_und=False`. For VLM-only fine-tuning, set `visual_gen=False`.
84
+ - For debugging purposes, use smaller values for `expected_num_tokens`, `max_num_tokens`, and `max_num_tokens_per_sample`.
85
+ - When fine-tuning on toy data, the loss behaves as follows:
86
+ ```shell
87
+ [2025-05-25 17:01:37] (step=0000000) Train Loss mse: 0.4063, Train Loss ce: 0.5504, Train Steps/Sec: 0.01,
88
+ [2025-05-25 17:01:40] (step=0000001) Train Loss mse: 0.4121, Train Loss ce: 0.8152, Train Steps/Sec: 0.44,
89
+ [2025-05-25 17:01:42] (step=0000002) Train Loss mse: 0.3876, Train Loss ce: 1.3411, Train Steps/Sec: 0.40,
90
+ [2025-05-25 17:01:45] (step=0000003) Train Loss mse: 0.3825, Train Loss ce: 0.7360, Train Steps/Sec: 0.44,
91
+ ```
92
+
93
+
94
+ You are encouraged to adjust any of these hyperparameters to fit your GPU budget and the scale of your dataset. If you encounter any issues, please open an issue for assistance. 🎉
95
+
96
+
97
+ ## Model config
98
+
99
+
100
+ | Argument | Default | Description |
101
+ | ---------------------------- | ------------------------------------------- | --------------------------------------------------------------- |
102
+ | `llm_path` | `hf/Qwen2.5-0.5B-Instruct` | Language‑model backbone (HuggingFace repo or local folder). |
103
+ | `vae_path` | `flux/vae/ae.safetensors` | Pre‑trained VAE checkpoint for latent diffusion. |
104
+ | `vit_path` | `hf/siglip-so400m-14-980-flash-attn2-navit` | SigLIP ViT used for image understanding. |
105
+ | `max_latent_size` | `32` | Maximum latent grid side; defines highest generable resolution. |
106
+ | `latent_patch_size` | `2` | VAE pixels represented by one latent patch. |
107
+ | `vit_max_num_patch_per_side` | `70` | Max ViT patches per image side after resizing. |
108
+ | `text_cond_dropout_prob` | `0.1` | Probability to drop text conditioning while training. |
109
+ | `vae_cond_dropout_prob` | `0.3` | Dropout on VAE latent inputs. |
110
+ | `vit_cond_dropout_prob` | `0.3` | Dropout on visual features. |
111
+
112
+ *(See `ModelArguments` for many more options.)*
113
+
114
+
115
+ ## Data config
116
+
117
+
118
+ | Argument | Default | Description |
119
+ | --------------------------- | --------------------------- | --------------------------------------------------------- |
120
+ | `dataset_config_file` | `data/configs/example.yaml` | YAML that groups datasets and assigns sampling weights. |
121
+ | `num_workers` | `4` | Background workers per rank for the PyTorch `DataLoader`. |
122
+ | `prefetch_factor` | `2` | Batches pre‑fetched by each worker. |
123
+ | `max_num_tokens_per_sample` | `16384` | Skip raw samples longer than this. |
124
+ | `max_num_tokens` | `36864` | Hard cap for a packed batch (prevents OOM). |
125
+ | `max_buffer_size` | `50` | Overflow buffer length for oversized samples. |
126
+ | `data_seed` | `42` | Seed for reproducible shuffling and sampling. |
127
+
128
+
129
+ ## Training config
130
+
131
+ | Argument | Default | Description |
132
+ | -------------------------------------- | ---------------------- | ------------------------------------------------------ |
133
+ | `total_steps` | `500_000` | Optimiser steps to run. |
134
+ | `lr` | `1e-4` | Peak learning rate after warm‑up. |
135
+ | `lr_scheduler` | `constant` | Learning‑rate schedule (`constant` or `cosine`). |
136
+ | `warmup_steps` | `2000` | Linear warm‑up duration. |
137
+ | `ema` | `0.9999` | Exponential moving‑average decay for model weights. |
138
+ | `max_grad_norm` | `1.0` | Gradient‑clipping threshold. |
139
+ | `save_every` | `2000` | Checkpoint frequency (steps). |
140
+ | `visual_gen / visual_und` | `True` | Enable image generation / understanding branches. |
141
+ | `freeze_llm / freeze_vit / freeze_vae` | `False / False / True` | Freeze selected modules to save VRAM or for ablations. |
142
+ | `use_flex` | `True` (in example) | Enable FLEX packing for higher GPU utilisation. |
143
+ | `sharding_strategy` | `HYBRID_SHARD` | FSDP sharding mode. |
144
+ | `num_shard` | `8` | Parameter shards per rank in HYBRID mode. |
145
+
146
+ **Distributed‑launch environment variables**
147
+
148
+ | Var | Meaning |
149
+ | ----------------------------- | --------------------------------- |
150
+ | `num_nodes` / `node_rank` | Multi‑node orchestration indices. |
151
+ | `nproc_per_node` | Number of GPUs per node. |
152
+ | `master_addr` / `master_port` | NCCL rendezvous endpoint. |
153
+
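+ For a single-node run, a minimal setup of these variables (values are placeholders) could look like:
+
+ ```shell
+ num_nodes=1
+ node_rank=0
+ master_addr=127.0.0.1
+ master_port=29500   # any free port
+ ```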
154
+
155
+ ## Logging config
156
+
157
+
158
+ | Argument | Default | Description |
159
+ | ---------------- | --------------------- | ---------------------------------------------------- |
160
+ | `results_dir` | `results` | Root directory for logs and metrics. |
161
+ | `checkpoint_dir` | `results/checkpoints` | Checkpoints are saved here. |
162
+ | `log_every` | `10` | Steps between console / W\&B logs. |
163
+ | `wandb_project` | `bagel` | Weights & Biases project name. |
164
+ | `wandb_name` | `run` | Run name inside the project. |
165
+ | `wandb_offline` | `False` | Switch to offline mode (logs locally, sync later). |
166
+ | `wandb_resume` | `allow` | Resumption policy if an existing run ID is detected. |
167
+
168
+ > **Tip** Export `WANDB_API_KEY` before launching if you want online dashboards.
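+
+ A minimal example (the key value is a placeholder):
+
+ ```shell
+ export WANDB_API_KEY=your_api_key_here   # enables online W&B dashboards
+ ```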
app.py ADDED
@@ -0,0 +1,613 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ import random
6
+
7
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
8
+ from PIL import Image
9
+
10
+ from data.data_utils import add_special_tokens, pil_img2rgb
11
+ from data.transforms import ImageTransform
12
+ from inferencer import InterleaveInferencer
13
+ from modeling.autoencoder import load_ae
14
+ from modeling.bagel.qwen2_navit import NaiveCache
15
+ from modeling.bagel import (
16
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
17
+ SiglipVisionConfig, SiglipVisionModel
18
+ )
19
+ from modeling.qwen2 import Qwen2Tokenizer
20
+
21
+ import argparse
22
+ from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
23
+
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument("--server_name", type=str, default="127.0.0.1")
27
+ parser.add_argument("--server_port", type=int, default=7860)
28
+ parser.add_argument("--share", action="store_true")
29
+ parser.add_argument("--model_path", type=str, default="models/BAGEL-7B-MoT")
30
+ parser.add_argument("--mode", type=int, default=1)
31
+ parser.add_argument("--zh", action="store_true")
32
+ args = parser.parse_args()
33
+
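+ # Example launch (flags defined above; path, address and port are placeholders):
+ #   python app.py --model_path models/BAGEL-7B-MoT --mode 1 --server_name 0.0.0.0 --server_port 7860
+ # --mode 2 loads NF4-quantized weights and --mode 3 loads INT8 weights (see below).
+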
34
+ # Model Initialization
35
+ model_path = args.model_path #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT to models/BAGEL-7B-MoT
36
+
38
+
39
+ llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
40
+ llm_config.qk_norm = True
41
+ llm_config.tie_word_embeddings = False
42
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
43
+
44
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
45
+ vit_config.rope = False
46
+ vit_config.num_hidden_layers -= 1
47
+
48
+ vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
49
+
50
+ config = BagelConfig(
51
+ visual_gen=True,
52
+ visual_und=True,
53
+ llm_config=llm_config,
54
+ vit_config=vit_config,
55
+ vae_config=vae_config,
56
+ vit_max_num_patch_per_side=70,
57
+ connector_act='gelu_pytorch_tanh',
58
+ latent_patch_size=2,
59
+ max_latent_size=64,
60
+ )
61
+
62
+ with init_empty_weights():
63
+ language_model = Qwen2ForCausalLM(llm_config)
64
+ vit_model = SiglipVisionModel(vit_config)
65
+ model = Bagel(language_model, vit_model, config)
66
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
67
+
68
+ tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
69
+ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
70
+
71
+ vae_transform = ImageTransform(1024, 512, 16)
72
+ vit_transform = ImageTransform(980, 224, 14)
73
+
74
+ # Model Loading and Multi-GPU Inference Preparation
75
+ device_map = infer_auto_device_map(
76
+ model,
77
+ max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
78
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
79
+ )
80
+
81
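+ # These small modules are forced onto the same device as the token embeddings below,
+ # so their activations never need cross-device transfers when the model is sharded.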
+ same_device_modules = [
82
+ 'language_model.model.embed_tokens',
83
+ 'time_embedder',
84
+ 'latent_pos_embed',
85
+ 'vae2llm',
86
+ 'llm2vae',
87
+ 'connector',
88
+ 'vit_pos_embed'
89
+ ]
90
+
91
+ if torch.cuda.device_count() == 1:
92
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
93
+ for k in same_device_modules:
94
+ if k in device_map:
95
+ device_map[k] = first_device
96
+ else:
97
+ device_map[k] = "cuda:0"
98
+ else:
99
+ first_device = device_map.get(same_device_modules[0])
100
+ for k in same_device_modules:
101
+ if k in device_map:
102
+ device_map[k] = first_device
103
+
104
+ if args.mode == 1:
105
+ model = load_checkpoint_and_dispatch(
106
+ model,
107
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
108
+ device_map=device_map,
109
+ offload_buffers=True,
110
+ offload_folder="offload",
111
+ dtype=torch.bfloat16,
112
+ force_hooks=True,
113
+ ).eval()
114
+ elif args.mode == 2: # NF4
115
+ bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4")
116
+ model = load_and_quantize_model(
117
+ model,
118
+ weights_location=os.path.join(model_path, "ema.safetensors"),
119
+ bnb_quantization_config=bnb_quantization_config,
120
+ device_map=device_map,
121
+ offload_folder="offload",
122
+ ).eval()
123
+ elif args.mode == 3: # INT8
124
+ bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, torch_dtype=torch.float32)
125
+ model = load_and_quantize_model(
126
+ model,
127
+ weights_location=os.path.join(model_path, "ema.safetensors"),
128
+ bnb_quantization_config=bnb_quantization_config,
129
+ device_map=device_map,
130
+ offload_folder="offload",
131
+ ).eval()
132
+ else:
133
+ raise NotImplementedError
134
+
135
+ # Inferencer Preparation
136
+ inferencer = InterleaveInferencer(
137
+ model=model,
138
+ vae_model=vae_model,
139
+ tokenizer=tokenizer,
140
+ vae_transform=vae_transform,
141
+ vit_transform=vit_transform,
142
+ new_token_ids=new_token_ids,
143
+ )
144
+
145
+
146
+ def set_seed(seed):
147
+ """Set random seeds for reproducibility"""
148
+ if seed > 0:
149
+ random.seed(seed)
150
+ np.random.seed(seed)
151
+ torch.manual_seed(seed)
152
+ if torch.cuda.is_available():
153
+ torch.cuda.manual_seed(seed)
154
+ torch.cuda.manual_seed_all(seed)
155
+ torch.backends.cudnn.deterministic = True
156
+ torch.backends.cudnn.benchmark = False
157
+ return seed
158
+
159
+
160
+ # Text to Image function with thinking option and hyperparameters
161
+ def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
162
+ timestep_shift=3.0, num_timesteps=50,
163
+ cfg_renorm_min=0.0, cfg_renorm_type="global",
164
+ max_think_token_n=1024, do_sample=False, text_temperature=0.3,
165
+ seed=0, image_ratio="1:1"):
166
+ # Set seed for reproducibility
167
+ set_seed(seed)
168
+
169
+ if image_ratio == "1:1":
170
+ image_shapes = (1024, 1024)
171
+ elif image_ratio == "4:3":
172
+ image_shapes = (768, 1024)
173
+ elif image_ratio == "3:4":
174
+ image_shapes = (1024, 768)
175
+ elif image_ratio == "16:9":
176
+ image_shapes = (576, 1024)
177
+ elif image_ratio == "9:16":
178
+ image_shapes = (1024, 576)
179
+
180
+ # Set hyperparameters
181
+ inference_hyper = dict(
182
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
183
+ do_sample=do_sample if show_thinking else False,
184
+ text_temperature=text_temperature if show_thinking else 0.3,
185
+ cfg_text_scale=cfg_text_scale,
186
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
187
+ timestep_shift=timestep_shift,
188
+ num_timesteps=num_timesteps,
189
+ cfg_renorm_min=cfg_renorm_min,
190
+ cfg_renorm_type=cfg_renorm_type,
191
+ image_shapes=image_shapes,
192
+ )
193
+
194
+ # Call inferencer with or without think parameter based on user choice
195
+ result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
196
+ return result["image"], result.get("text", None)
197
+
198
+
199
+ # Image Understanding function with thinking option and hyperparameters
200
+ def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
201
+ do_sample=False, text_temperature=0.3, max_new_tokens=512):
202
+ if image is None:
203
+ return "Please upload an image."
204
+
205
+ if isinstance(image, np.ndarray):
206
+ image = Image.fromarray(image)
207
+
208
+ image = pil_img2rgb(image)
209
+
210
+ # Set hyperparameters
211
+ inference_hyper = dict(
212
+ do_sample=do_sample,
213
+ text_temperature=text_temperature,
214
+ max_think_token_n=max_new_tokens, # Set max_length
215
+ )
216
+
217
+ # Use show_thinking parameter to control thinking process
218
+ result = inferencer(image=image, text=prompt, think=show_thinking,
219
+ understanding_output=True, **inference_hyper)
220
+ return result["text"]
221
+
222
+
223
+ # Image Editing function with thinking option and hyperparameters
224
+ def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
225
+ cfg_img_scale=2.0, cfg_interval=0.0,
226
+ timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=0.0,
227
+ cfg_renorm_type="text_channel", max_think_token_n=1024,
228
+ do_sample=False, text_temperature=0.3, seed=0):
229
+ # Set seed for reproducibility
230
+ set_seed(seed)
231
+
232
+ if image is None:
233
+ return "Please upload an image.", ""
234
+
235
+ if isinstance(image, np.ndarray):
236
+ image = Image.fromarray(image)
237
+
238
+ image = pil_img2rgb(image)
239
+
240
+ # Set hyperparameters
241
+ inference_hyper = dict(
242
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
243
+ do_sample=do_sample if show_thinking else False,
244
+ text_temperature=text_temperature if show_thinking else 0.3,
245
+ cfg_text_scale=cfg_text_scale,
246
+ cfg_img_scale=cfg_img_scale,
247
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
248
+ timestep_shift=timestep_shift,
249
+ num_timesteps=num_timesteps,
250
+ cfg_renorm_min=cfg_renorm_min,
251
+ cfg_renorm_type=cfg_renorm_type,
252
+ )
253
+
254
+ # Include thinking parameter based on user choice
255
+ result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
256
+ return result["image"], result.get("text", "")
257
+
258
+
259
+ # Helper function to load example images
260
+ def load_example_image(image_path):
261
+ try:
262
+ return Image.open(image_path)
263
+ except Exception as e:
264
+ print(f"Error loading example image: {e}")
265
+ return None
266
+
267
+
268
+ # Gradio UI
269
+ with gr.Blocks() as demo:
270
+ gr.Markdown("""
271
+ <div>
272
+ <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
273
+ </div>
274
+ """)
275
+
276
+ with gr.Tab("📝 Text to Image"):
277
+ txt_input = gr.Textbox(
278
+ label="Prompt",
279
+ value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
280
+ )
281
+
282
+ with gr.Row():
283
+ show_thinking = gr.Checkbox(label="Thinking", value=False)
284
+
285
+ # Add hyperparameter controls in an accordion
286
+ with gr.Accordion("Inference Hyperparameters", open=False):
287
+ with gr.Group():
288
+ with gr.Row():
289
+ seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
290
+ label="Seed", info="0 for random seed, positive for reproducible results")
291
+ image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
292
+ value="1:1", label="Image Ratio",
293
+ info="The longer size is fixed to 1024")
294
+
295
+ with gr.Row():
296
+ cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
297
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
298
+ cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
299
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
300
+
301
+ with gr.Row():
302
+ cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
303
+ value="global", label="CFG Renorm Type",
304
+ info="If the generated image is blurry, use 'global'")
305
+ cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
306
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
307
+
308
+ with gr.Row():
309
+ num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
310
+ label="Timesteps", info="Total denoising steps")
311
+ timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
312
+ label="Timestep Shift", info="Higher values for layout, lower for details")
313
+
314
+ # Thinking parameters in a single row
315
+ thinking_params = gr.Group(visible=False)
316
+ with thinking_params:
317
+ with gr.Row():
318
+ do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
319
+ max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
320
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
321
+ text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
322
+ label="Temperature", info="Controls randomness in text generation")
323
+
324
+ thinking_output = gr.Textbox(label="Thinking Process", visible=False)
325
+ img_output = gr.Image(label="Generated Image")
326
+ gen_btn = gr.Button("Generate", variant="primary")
327
+
328
+ # Dynamically show/hide thinking process box and parameters
329
+ def update_thinking_visibility(show):
330
+ return gr.update(visible=show), gr.update(visible=show)
331
+
332
+ show_thinking.change(
333
+ fn=update_thinking_visibility,
334
+ inputs=[show_thinking],
335
+ outputs=[thinking_output, thinking_params]
336
+ )
337
+
338
+ # Process function based on thinking option and hyperparameters
339
+ def process_text_to_image(prompt, show_thinking, cfg_text_scale,
340
+ cfg_interval, timestep_shift,
341
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
342
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio):
343
+ image, thinking = text_to_image(
344
+ prompt, show_thinking, cfg_text_scale, cfg_interval,
345
+ timestep_shift, num_timesteps,
346
+ cfg_renorm_min, cfg_renorm_type,
347
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
348
+ )
349
+ return image, thinking if thinking else ""
350
+
351
+ gr.on(
352
+ triggers=[gen_btn.click, txt_input.submit],
353
+ fn=process_text_to_image,
354
+ inputs=[
355
+ txt_input, show_thinking, cfg_text_scale,
356
+ cfg_interval, timestep_shift,
357
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
358
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
359
+ ],
360
+ outputs=[img_output, thinking_output]
361
+ )
362
+
363
+ with gr.Tab("🖌️ Image Edit"):
364
+ with gr.Row():
365
+ with gr.Column(scale=1):
366
+ edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
367
+ edit_prompt = gr.Textbox(
368
+ label="Prompt",
369
+ value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
370
+ )
371
+
372
+ with gr.Column(scale=1):
373
+ edit_image_output = gr.Image(label="Result")
374
+ edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
375
+
376
+ with gr.Row():
377
+ edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
378
+
379
+ # Add hyperparameter controls in an accordion
380
+ with gr.Accordion("Inference Hyperparameters", open=False):
381
+ with gr.Group():
382
+ with gr.Row():
383
+ edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
384
+ label="Seed", info="0 for random seed, positive for reproducible results")
385
+ edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
386
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
387
+
388
+ with gr.Row():
389
+ edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
390
+ label="CFG Image Scale", info="Controls how much the model preserves input image details")
391
+ edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
392
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
393
+
394
+ with gr.Row():
395
+ edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
396
+ value="text_channel", label="CFG Renorm Type",
397
+ info="If the generated image is blurry, use 'global'")
398
+ edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
399
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
400
+
401
+ with gr.Row():
402
+ edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
403
+ label="Timesteps", info="Total denoising steps")
404
+ edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
405
+ label="Timestep Shift", info="Higher values for layout, lower for details")
406
+
407
+
408
+ # Thinking parameters in a single row
409
+ edit_thinking_params = gr.Group(visible=False)
410
+ with edit_thinking_params:
411
+ with gr.Row():
412
+ edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
413
+ edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
414
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
415
+ edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
416
+ label="Temperature", info="Controls randomness in text generation")
417
+
418
+ edit_btn = gr.Button("Submit", variant="primary")
419
+
420
+ # Dynamically show/hide thinking process box for editing
421
+ def update_edit_thinking_visibility(show):
422
+ return gr.update(visible=show), gr.update(visible=show)
423
+
424
+ edit_show_thinking.change(
425
+ fn=update_edit_thinking_visibility,
426
+ inputs=[edit_show_thinking],
427
+ outputs=[edit_thinking_output, edit_thinking_params]
428
+ )
429
+
430
+ # Process editing with thinking option and hyperparameters
431
+ def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
432
+ cfg_img_scale, cfg_interval,
433
+ timestep_shift, num_timesteps, cfg_renorm_min,
434
+ cfg_renorm_type, max_think_token_n, do_sample,
435
+ text_temperature, seed):
436
+ edited_image, thinking = edit_image(
437
+ image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
438
+ cfg_interval, timestep_shift,
439
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
440
+ max_think_token_n, do_sample, text_temperature, seed
441
+ )
442
+
443
+ return edited_image, thinking if thinking else ""
444
+
445
+ gr.on(
446
+ triggers=[edit_btn.click, edit_prompt.submit],
447
+ fn=process_edit_image,
448
+ inputs=[
449
+ edit_image_input, edit_prompt, edit_show_thinking,
450
+ edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
451
+ edit_timestep_shift, edit_num_timesteps,
452
+ edit_cfg_renorm_min, edit_cfg_renorm_type,
453
+ edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
454
+ ],
455
+ outputs=[edit_image_output, edit_thinking_output]
456
+ )
457
+
458
+ with gr.Tab("🖼️ Image Understanding"):
459
+ with gr.Row():
460
+ with gr.Column(scale=1):
461
+ img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
462
+ understand_prompt = gr.Textbox(
463
+ label="Prompt",
464
+ value="Can someone explain what's funny about this meme??"
465
+ )
466
+
467
+ with gr.Column(scale=1):
468
+ txt_output = gr.Textbox(label="Result", lines=20)
469
+
470
+ with gr.Row():
471
+ understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
472
+
473
+ # Add hyperparameter controls in an accordion
474
+ with gr.Accordion("Inference Hyperparameters", open=False):
475
+ with gr.Row():
476
+ understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
477
+ understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
478
+ label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
479
+ understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
480
+ label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
481
+
482
+ img_understand_btn = gr.Button("Submit", variant="primary")
483
+
484
+ # Process understanding with thinking option and hyperparameters
485
+ def process_understanding(image, prompt, show_thinking, do_sample,
486
+ text_temperature, max_new_tokens):
487
+ result = image_understanding(
488
+ image, prompt, show_thinking, do_sample,
489
+ text_temperature, max_new_tokens
490
+ )
491
+ return result
492
+
493
+ gr.on(
494
+ triggers=[img_understand_btn.click, understand_prompt.submit],
495
+ fn=process_understanding,
496
+ inputs=[
497
+ img_input, understand_prompt, understand_show_thinking,
498
+ understand_do_sample, understand_text_temperature, understand_max_new_tokens
499
+ ],
500
+ outputs=txt_output
501
+ )
502
+
503
+ gr.Markdown("""
504
+ <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
505
+ <a href="https://bagel-ai.org/">
506
+ <img
507
+ src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
508
+ alt="BAGEL Website"
509
+ />
510
+ </a>
511
+ <a href="https://arxiv.org/abs/2505.14683">
512
+ <img
513
+ src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
514
+ alt="BAGEL Paper on arXiv"
515
+ />
516
+ </a>
517
+ <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
518
+ <img
519
+ src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
520
+ alt="BAGEL on Hugging Face"
521
+ />
522
+ </a>
523
+ <a href="https://demo.bagel-ai.org/">
524
+ <img
525
+ src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
526
+ alt="BAGEL Demo"
527
+ />
528
+ </a>
529
+ <a href="https://discord.gg/Z836xxzy">
530
+ <img
531
+ src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
532
+ alt="BAGEL Discord"
533
+ />
534
+ </a>
535
+ <a href="mailto:[email protected]">
536
+ <img
537
+ src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
538
+ alt="BAGEL Email"
539
+ />
540
+ </a>
541
+ </div>
542
+ """)
543
+
544
+ UI_TRANSLATIONS = {
545
+ "📝 Text to Image":"📝 文生图",
546
+ "Prompt":"提示词",
547
+ "Thinking":"思考模式",
548
+ "Inference Hyperparameters":"推理参数",
549
+ "Seed":"随机种子",
550
+ "0 for random seed, positive for reproducible results":"0为随机种子,正数表示可重复结果",
551
+ "Image Ratio":"图片比例",
552
+ "The longer size is fixed to 1024":"长边固定为1024",
553
+ "CFG Text Scale":"文本CFG强度",
554
+ "Controls how strongly the model follows the text prompt (4.0-8.0)":"控制模型是否遵循文本提示(4.0-8.0)",
555
+ "CFG Interval":"CFG应用间隔",
556
+ "Start of CFG application interval (end is fixed at 1.0)":"CFG应用间隔的开始(结束固定为1.0)",
557
+ "CFG Renorm Type":"CFG 重归一化类型",
558
+ "If the generated image is blurry, use 'global'":"如果生成的图像模糊,请使用'global'",
559
+ "CFG Renorm Min":"CFG 重归一化最小值",
560
+ "1.0 disables CFG-Renorm":"1.0 禁用 CFG 重归一化",
561
+ "Timesteps":"时间步数",
562
+ "Total denoising steps":"总去噪步数",
563
+ "Timestep Shift":"时间步偏移",
564
+ "Higher values for layout, lower for details":"值更大更倾向于调整布局,值更小更倾向于调整细节",
565
+ "Sampling":"采样",
566
+ "Enable sampling for text generation":"为文本生成启用采样",
567
+ "Max Think Tokens":"最大思考token数",
568
+ "Maximum number of tokens for thinking":"思考的最大token数",
569
+ "Temperature":"温度系数",
570
+ "Controls randomness in text generation":"控制文本生成的随机性",
571
+ "Thinking Process":"思考过程",
572
+ "Generated Image":"生成图像",
573
+ "Generate":"开始生成",
574
+ "🖌️ Image Edit":"🖌️ 图像编辑",
575
+ "Input Image":"图像输入",
576
+ "Result":"结果",
577
+ "Controls how strongly the model follows the text prompt":"控制模型是否遵循文本提示的强度",
578
+ "CFG Image Scale":"图像CFG强度",
579
+ "Controls how much the model preserves input image details":"控制模型保留输入图像细节的强度",
580
+ "Submit":"开始生成",
581
+ "🖼️ Image Understanding":"🖼️ 图像理解",
582
+ "Controls randomness in text generation (0=deterministic, 1=creative)":"控制文本生成的随机性(0=确定性,1=创造性)",
583
+ "Max New Tokens":"最大新token数",
584
+ "Maximum length of generated text, including potential thinking":"生成文本的最大长度,包括可能的思考",
585
+ }
586
+
587
+ def apply_localization(block):
588
+ def process_component(component):
589
+ if not component:
590
+ return
591
+
592
+ for attr in ['label', 'info', 'placeholder']:
593
+ if hasattr(component, attr):
594
+ text = getattr(component, attr)
595
+ if text in UI_TRANSLATIONS:
596
+ setattr(component, attr, UI_TRANSLATIONS[text])
597
+
598
+ if hasattr(component, 'children'):
599
+ for child in component.children:
600
+ process_component(child)
601
+
602
+ process_component(block)
603
+ return block
604
+
605
+ if __name__ == "__main__":
606
+ if args.zh:
607
+ demo = apply_localization(demo)
608
+ demo.launch(
609
+ server_name=args.server_name,
610
+ server_port=args.server_port,
611
+ share=args.share,
612
+ inbrowser=True,
613
+ )
bug.log ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Now I see the problem: the model directory contains model_bf16.safetensors, but the code is looking for model.safetensors. Let me check the loading logic in fsdp_utils.py:
2
+
3
+ Read fsdp_utils.py, lines 155 to 165
4
+
5
+ Found it: the code looks for model.safetensors, but the actual filename is model_bf16.safetensors. The loading logic needs to be updated to handle the actual filename.
6
+
7
+
8
+
9
+
10
+ cd /scratch/by2593/Bagel-Zebra-CoT-origin && PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin
11
+
12
+
13
+
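The fix described in the log above amounts to a filename fallback when resolving the checkpoint. A minimal sketch of that idea, assuming a hypothetical `resolve_and_load_checkpoint` helper and candidate list rather than the actual `fsdp_utils.py` code:

```python
# Illustrative sketch only: try the expected checkpoint name first, then fall back
# to the bf16 variant. Helper name and candidate list are assumptions, not repo code.
import os
from safetensors.torch import load_file


def resolve_and_load_checkpoint(model_dir, candidates=("model.safetensors", "model_bf16.safetensors")):
    for name in candidates:
        path = os.path.join(model_dir, name)
        if os.path.exists(path):
            return load_file(path)
    raise FileNotFoundError(f"None of {candidates} found in {model_dir}")
```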
data/configs/example_smm_semantic.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ block_dataset:
2
+ dataset_names:
3
+ - block_dataset
4
+ jsonl_path_list: ["/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl"]
5
+ num_used_data: None
6
+ image_prefix_dir: "/scratch/by2593/project/SMM/semantic_blocks_part1"
7
+ image_transform_args:
8
+ image_stride: 16
9
+ max_image_size: 512 # VAE uses stride=16, 512/16 = 32 patches
10
+ min_image_size: 512
11
+ vit_image_transform_args:
12
+ image_stride: 14
13
+ max_image_size: 512 # ViT uses stride=14, 512/14 ≈ 36 patches (matches the model's capacity)
14
+ min_image_size: 512
15
+ weight: 1.0
16
+ is_mandatory: true
17
+
18
+ # unified_edit:
19
+ # dataset_names:
20
+ # - seedxedit_multi
21
+ # image_transform_args:
22
+ # image_stride: 16
23
+ # max_image_size: 1024
24
+ # min_image_size: 512
25
+ # vit_image_transform_args:
26
+ # image_stride: 14
27
+ # max_image_size: 518
28
+ # min_image_size: 224
29
+ # is_mandatory: true
30
+ # num_used_data:
31
+ # - 10
32
+ # weight: 1
33
+
34
+ # vlm_sft:
35
+ # dataset_names:
36
+ # - llava_ov
37
+ # image_transform_args:
38
+ # image_stride: 14
39
+ # max_image_size: 980
40
+ # min_image_size: 378
41
+ # max_pixels: 2_007_040
42
+ # frame_sampler_args:
43
+ # max_num_frames: 12
44
+ # min_num_frames: 8
45
+ # is_mandatory: true
46
+ # shuffle_lines: True
47
+ # shuffle_seed: 0
48
+ # num_used_data:
49
+ # - 1000
50
+ # weight: 1
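The stride comments in the config above can be sanity-checked with a couple of lines of arithmetic (illustrative only):

```python
# Patch counts implied by the config comments above, for 512 px inputs.
image_size = 512
print(image_size // 16)  # 32 patches per side for the VAE branch (stride 16)
print(image_size // 14)  # 36 patches per side for the ViT branch (stride 14, floor of 512/14)
```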
data/data_utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import math
6
+ import random
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch.nn.attention.flex_attention import or_masks, and_masks
11
+
12
+
13
+ def create_sparse_mask(document_lens, split_lens, attn_modes, device):
14
+ def causal_mask(b, h, q_idx, kv_idx):
15
+ return q_idx >= kv_idx
16
+
17
+ def full_and_noise_mask(b, h, q_idx, kv_idx):
18
+ return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
19
+
20
+ def remove_noise_mask(b, h, q_idx, kv_idx):
21
+ return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
22
+
23
+ def sample_mask(b, h, q_idx, kv_idx):
24
+ return document_id[q_idx] == document_id[kv_idx]
25
+
26
+ full_and_noise_tmp = []
27
+ noise_tmp = []
28
+
29
+ for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
30
+ value = i if model in ['full', 'noise'] else -1
31
+ full_and_noise_tmp.extend([value] * length)
32
+ value_noise = i if model == 'noise' else -1
33
+ noise_tmp.extend([value_noise] * length)
34
+
35
+ full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
36
+ noise_seq_id = torch.Tensor(noise_tmp).to(device)
37
+
38
+ document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
39
+
40
+ return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
41
+
42
+
43
+ def patchify(image, patch_size):
44
+ p = patch_size
45
+ c, h, w = image.shape
46
+ assert h % p == 0 and w % p == 0
47
+ image = image.reshape(c, h // p, p, w // p, p)
48
+ image = torch.einsum("chpwq->hwpqc", image)
49
+ image = image.reshape(-1, p**2 * c)
50
+ return image
51
+
52
+
53
+ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
54
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
55
+ coords_h = torch.arange(0, num_patches_h)
56
+ coords_w = torch.arange(0, num_patches_w)
57
+ pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
58
+ return pos_ids
59
+
60
+
61
+ def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
62
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
63
+ boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
64
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
65
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
66
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
67
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
68
+ pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
69
+ return pos_ids
70
+
71
+
72
+ def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
73
+ """
74
+ split_lens: A list of ints. Each int indicates the length of a split within
75
+ a sample, where each sample contains multiple splits with different attn modes.
76
+ attn_modes: the attention mode ('causal', 'full', or 'noise') used within each split.
77
+ """
78
+ sample_len = sum(split_lens)
79
+ attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
80
+
81
+ csum = 0
82
+ for s, attn_mode in zip(split_lens, attn_modes):
83
+ assert attn_mode in ['causal', 'full', 'noise']
84
+ if attn_mode == "causal":
85
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
86
+ attention_mask[csum:csum + s, :csum] = 1
87
+ else:
88
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
89
+ attention_mask[csum:csum + s, :csum] = 1
90
+ csum += s
91
+
92
+ csum = 0
93
+ for s, attn_mode in zip(split_lens, attn_modes):
94
+ if attn_mode == "noise":
95
+ attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
96
+ attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
97
+ csum += s
98
+
99
+ attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
100
+ ~attention_mask, float("-inf")
101
+ )
102
+
103
+ return attention_mask
104
+
105
+
106
+ def split_integer_exp_decay(S, ng_sample_decay=1.0):
107
+ if ng_sample_decay == 1.0:
108
+ N = random.randint(1, S)
109
+ else:
110
+ base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
111
+ p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
112
+ N = random.choices(list(range(1, S + 1)), p, k=1)[0]
113
+ cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
114
+ result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
115
+ return result, cumsum
116
+
117
+
118
+ def pil_img2rgb(image):
119
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
120
+ image = image.convert("RGBA")
121
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
122
+ white.paste(image, mask=image.split()[3])
123
+ image = white
124
+ else:
125
+ image = image.convert("RGB")
126
+
127
+ return image
128
+
129
+
130
+ def add_special_tokens(tokenizer):
131
+ all_special_tokens = []
132
+ for k, v in tokenizer.special_tokens_map.items():
133
+ if isinstance(v, str):
134
+ all_special_tokens.append(v)
135
+ elif isinstance(v, list):
136
+ all_special_tokens += v
137
+
138
+ new_tokens = []
139
+
140
+ if '<|im_start|>' not in all_special_tokens:
141
+ new_tokens.append('<|im_start|>')
142
+
143
+ if '<|im_end|>' not in all_special_tokens:
144
+ new_tokens.append('<|im_end|>')
145
+
146
+ if '<|vision_start|>' not in all_special_tokens:
147
+ new_tokens.append('<|vision_start|>')
148
+
149
+ if '<|vision_end|>' not in all_special_tokens:
150
+ new_tokens.append('<|vision_end|>')
151
+
152
+ num_new_tokens = tokenizer.add_tokens(new_tokens)
153
+ bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
154
+ eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
155
+ start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
156
+ end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
157
+
158
+ new_token_ids = dict(
159
+ bos_token_id=bos_token_id,
160
+ eos_token_id=eos_token_id,
161
+ start_of_image=start_of_image,
162
+ end_of_image=end_of_image,
163
+ )
164
+
165
+ return tokenizer, new_token_ids, num_new_tokens
166
+
167
+
168
+ def len2weight(x, loss_reduction='square'):
169
+ if x == 0:
170
+ return x
171
+ if loss_reduction == 'token':
172
+ return 1
173
+ if loss_reduction == 'sample':
174
+ return 1 / x
175
+ if loss_reduction == 'square':
176
+ return 1 / (x ** 0.5)
177
+ raise NotImplementedError(loss_reduction)
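As a quick illustration of the `prepare_attention_mask_per_sample` helper defined above, a toy call could look like this (the split lengths and modes are made-up values, not taken from the training pipeline):

```python
# Toy example: a 4-token causal text split, a 2-token full-attention split, a 3-token noise split.
mask = prepare_attention_mask_per_sample(
    split_lens=[4, 2, 3],
    attn_modes=["causal", "full", "noise"],
)
print(mask.shape)  # torch.Size([9, 9]); 0.0 where attention is allowed, -inf where it is masked
```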
data/interleave_datasets/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .edit_dataset import UnifiedEditIterableDataset
5
+ from .think_trace_dataset import ThinkTraceJSONLIterableDataset
6
+
data/parquet_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import os
6
+ import subprocess
7
+ import logging
8
+
9
+ import pyarrow.fs as pf
10
+ import torch.distributed as dist
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
16
+ num_data_dirs = len(data_dir_list)
17
+ if world_size > 1:
18
+ chunk_size = (num_data_dirs + world_size - 1) // world_size
19
+ start_idx = rank * chunk_size
20
+ end_idx = min(start_idx + chunk_size, num_data_dirs)
21
+ local_data_dir_list = data_dir_list[start_idx:end_idx]
22
+ local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
23
+ else:
24
+ local_data_dir_list = data_dir_list
25
+ local_num_sampled_data_paths = num_sampled_data_paths
26
+
27
+ local_data_paths = []
28
+ for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
29
+ if data_dir.startswith("hdfs://"):
30
+ files = hdfs_ls_cmd(data_dir)
31
+ data_paths_per_dir = [
32
+ file for file in files if file.endswith(".parquet")
33
+ ]
34
+ else:
35
+ files = os.listdir(data_dir)
36
+ data_paths_per_dir = [
37
+ os.path.join(data_dir, name)
38
+ for name in files
39
+ if name.endswith(".parquet")
40
+ ]
41
+ repeat = num_data_path // len(data_paths_per_dir)
42
+ data_paths_per_dir = data_paths_per_dir * (repeat + 1)
43
+ local_data_paths.extend(data_paths_per_dir[:num_data_path])
44
+
45
+ if world_size > 1:
46
+ gather_list = [None] * world_size
47
+ dist.all_gather_object(gather_list, local_data_paths)
48
+
49
+ combined_chunks = []
50
+ for chunk_list in gather_list:
51
+ if chunk_list is not None:
52
+ combined_chunks.extend(chunk_list)
53
+ else:
54
+ combined_chunks = local_data_paths
55
+
56
+ return combined_chunks
57
+
58
+
59
+ # NOTE: customize this function for your cluster
60
+ def get_hdfs_host():
61
+ return "hdfs://xxx"
62
+
63
+
64
+ # NOTE: customize this function for your cluster
65
+ def get_hdfs_block_size():
66
+ return 134217728
67
+
68
+
69
+ # NOTE: customize this function for your cluster
70
+ def get_hdfs_extra_conf():
71
+ return None
72
+
73
+
74
+ def init_arrow_pf_fs(parquet_file_path):
75
+ if parquet_file_path.startswith("hdfs://"):
76
+ fs = pf.HadoopFileSystem(
77
+ host=get_hdfs_host(),
78
+ port=0,
79
+ buffer_size=get_hdfs_block_size(),
80
+ extra_conf=get_hdfs_extra_conf(),
81
+ )
82
+ else:
83
+ fs = pf.LocalFileSystem()
84
+ return fs
85
+
86
+
87
+ def hdfs_ls_cmd(dir):
88
+ result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
89
+ return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
data/t2i_dataset.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import io
5
+ import json
6
+ import pyarrow.parquet as pq
7
+ import random
8
+ from PIL import Image
9
+
10
+ from .data_utils import pil_img2rgb
11
+ from .distributed_iterable_dataset import DistributedIterableDataset
12
+ from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
13
+
14
+ Image.MAX_IMAGE_PIXELS = 20_000_000
15
+
16
+
17
+ class T2IIterableDataset(DistributedIterableDataset):
18
+ def __init__(
19
+ self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
20
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
21
+ ):
22
+ """
23
+ data_dir_list: list of data directories contains parquet files
24
+ num_used_data: list of number of sampled data paths for each data directory
25
+ """
26
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
27
+ self.transform = transform
28
+ self.tokenizer = tokenizer
29
+ self.data_status = data_status
30
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
31
+ self.set_epoch()
32
+
33
+ def get_data_paths(self, data_dir_list, num_used_data):
34
+ return get_parquet_data_paths(data_dir_list, num_used_data)
35
+
36
+ def __iter__(self):
37
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
38
+ if self.data_status is not None:
39
+ parquet_start_id = self.data_status[worker_id][0]
40
+ row_group_start_id = self.data_status[worker_id][1]
41
+ row_start_id = self.data_status[worker_id][2] + 1
42
+ else:
43
+ parquet_start_id = 0
44
+ row_group_start_id = 0
45
+ row_start_id = 0
46
+ transform_stride = self.transform.stride
47
+
48
+ print(
49
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
50
+ f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
51
+ )
52
+
53
+ while True:
54
+ data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
55
+ for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
56
+ fs = init_arrow_pf_fs(parquet_file_path)
57
+ with fs.open_input_file(parquet_file_path) as f:
58
+ fr = pq.ParquetFile(f)
59
+ row_group_ids = list(range(fr.num_row_groups))
60
+ row_group_ids_ = row_group_ids[row_group_start_id:]
61
+
62
+ for row_group_id in row_group_ids_:
63
+ df = fr.read_row_group(row_group_id).to_pandas()
64
+ df = df.iloc[row_start_id:]
65
+
66
+ for row_idx, row in df.iterrows():
67
+ num_tokens = 0
68
+ try:
69
+ image_byte = row['image']
70
+ image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
71
+ except Exception as e:
72
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
73
+ continue
74
+ image_tensor = self.transform(image)
75
+ height, width = image_tensor.shape[1:]
76
+ num_tokens += width * height // transform_stride ** 2
77
+
78
+ try:
79
+ caption_dict = row['captions']
80
+ caption_dict = json.loads(caption_dict)
81
+ except Exception as e:
82
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
83
+ continue
84
+
85
+ caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
86
+ if len(caps_token) == 0:
87
+ print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
88
+ caption_token = self.tokenizer.encode(' ')
89
+ else:
90
+ caption_token = random.choice(caps_token)
91
+
92
+ sequence_plan, text_ids_list = [], []
93
+ text_ids = caption_token
94
+ num_tokens += len(caption_token)
95
+ text_ids_list.append(text_ids)
96
+ sequence_plan.append({
97
+ 'type': 'text',
98
+ 'enable_cfg': 1,
99
+ 'loss': 0,
100
+ 'special_token_loss': 0,
101
+ 'special_token_label': None,
102
+ })
103
+
104
+ sequence_plan.append({
105
+ 'type': 'vae_image',
106
+ 'enable_cfg': 0,
107
+ 'loss': 1,
108
+ 'special_token_loss': 0,
109
+ 'special_token_label': None,
110
+ })
111
+
112
+ sample = dict(
113
+ image_tensor_list=[image_tensor],
114
+ text_ids_list=text_ids_list,
115
+ num_tokens=num_tokens,
116
+ sequence_plan=sequence_plan,
117
+ data_indexes={
118
+ "data_indexes": [parquet_idx, row_group_id, row_idx],
119
+ "worker_id": worker_id,
120
+ "dataset_name": self.dataset_name,
121
+ }
122
+ )
123
+ yield sample
124
+
125
+ row_start_id = 0
126
+ row_group_start_id = 0
127
+ parquet_start_id = 0
128
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
data/transforms.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import random
5
+ from PIL import Image
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from torchvision import transforms
11
+ from torchvision.transforms import functional as F
12
+ from torchvision.transforms import InterpolationMode
13
+
14
+
15
+ class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
16
+ """Resize the input image so that its longest side and shortest side are within a specified range,
17
+ ensuring that both sides are divisible by a specified stride.
18
+
19
+ Args:
20
+ max_size (int): Maximum size for the longest edge of the image.
21
+ min_size (int): Minimum size for the shortest edge of the image.
22
+ stride (int): Value by which the height and width of the image must be divisible.
23
+ max_pixels (int): Maximum pixels for the full image.
24
+ interpolation (InterpolationMode): Desired interpolation enum defined by
25
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
26
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
27
+ ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
28
+ The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
29
+ antialias (bool, optional): Whether to apply antialiasing (default is True).
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ max_size: int,
35
+ min_size: int,
36
+ stride: int,
37
+ max_pixels: int,
38
+ interpolation=InterpolationMode.BICUBIC,
39
+ antialias=True
40
+ ):
41
+ super().__init__()
42
+ self.max_size = max_size
43
+ self.min_size = min_size
44
+ self.stride = stride
45
+ self.max_pixels = max_pixels
46
+ self.interpolation = interpolation
47
+ self.antialias = antialias
48
+
49
+ def _make_divisible(self, value, stride):
50
+ """Ensure the value is divisible by the stride."""
51
+ return max(stride, int(round(value / stride) * stride))
52
+
53
+ def _apply_scale(self, width, height, scale):
54
+ new_width = round(width * scale)
55
+ new_height = round(height * scale)
56
+ new_width = self._make_divisible(new_width, self.stride)
57
+ new_height = self._make_divisible(new_height, self.stride)
58
+ return new_width, new_height
59
+
60
+ def forward(self, img, img_num=1):
61
+ """
62
+ Args:
63
+ img (PIL Image): Image to be resized.
64
+ img_num (int): Number of images, used to change max_tokens.
65
+ Returns:
66
+ PIL Image or Tensor: Rescaled image with divisible dimensions.
67
+ """
68
+ if isinstance(img, torch.Tensor):
69
+ height, width = img.shape[-2:]
70
+ else:
71
+ width, height = img.size
72
+
73
+ scale = min(self.max_size / max(width, height), 1.0)
74
+ scale = max(scale, self.min_size / min(width, height))
75
+ new_width, new_height = self._apply_scale(width, height, scale)
76
+
77
+ # Ensure the number of pixels does not exceed max_pixels
78
+ if new_width * new_height > self.max_pixels / img_num:
79
+ scale = self.max_pixels / img_num / (new_width * new_height)
80
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
81
+
82
+ # Ensure longest edge does not exceed max_size
83
+ if max(new_width, new_height) > self.max_size:
84
+ scale = self.max_size / max(new_width, new_height)
85
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
86
+
87
+ return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
88
+
89
+
90
+ class ImageTransform:
91
+ def __init__(
92
+ self,
93
+ max_image_size,
94
+ min_image_size,
95
+ image_stride,
96
+ max_pixels=14*14*9*1024,
97
+ image_mean=[0.5, 0.5, 0.5],
98
+ image_std=[0.5, 0.5, 0.5]
99
+ ):
100
+ self.stride = image_stride
101
+
102
+ self.resize_transform = MaxLongEdgeMinShortEdgeResize(
103
+ max_size=max_image_size,
104
+ min_size=min_image_size,
105
+ stride=image_stride,
106
+ max_pixels=max_pixels,
107
+ )
108
+ self.to_tensor_transform = transforms.ToTensor()
109
+ self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
110
+
111
+ def __call__(self, img, img_num=1):
112
+ img = self.resize_transform(img, img_num=img_num)
113
+ img = self.to_tensor_transform(img)
114
+ img = self.normalize_transform(img)
115
+ return img
116
+
117
+
118
+ def decolorization(image):
119
+ gray_image = image.convert('L')
120
+ return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image
121
+
122
+
123
+ def downscale(image, scale_factor):
124
+ new_width = int(round(image.width * scale_factor))
125
+ new_height = int(round(image.height * scale_factor))
126
+ new_width = max(1, new_width)
127
+ new_height = max(1, new_height)
128
+ return image.resize((new_width, new_height), resample=Image.BICUBIC)
129
+
130
+
131
+ def crop(image, crop_factors):
132
+ target_h, target_w = crop_factors
133
+ img_w, img_h = image.size
134
+
135
+ if target_h > img_h or target_w > img_w:
136
+ raise ValueError("Crop size exceeds image dimensions")
137
+
138
+ x = random.randint(0, img_w - target_w)
139
+ y = random.randint(0, img_h - target_h)
140
+
141
+ return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
142
+
143
+
144
+ def motion_blur_opencv(image, kernel_size=15, angle=0):
145
+ # build a linear (horizontal) motion kernel
146
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
147
+ kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
148
+
149
+ # rotate the kernel to the requested angle
150
+ center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
151
+ M = cv2.getRotationMatrix2D(center, angle, 1)
152
+ rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
153
+
154
+ # normalize the kernel
155
+ rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
156
+
157
+ img = np.array(image)
158
+ if img.ndim == 2:
159
+ blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
160
+ else:
161
+ # for color images, filter each channel independently
162
+ blurred = np.zeros_like(img)
163
+ for c in range(img.shape[2]):
164
+ blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
165
+
166
+ return Image.fromarray(blurred.astype(np.uint8))
167
+
168
+
169
+ def shuffle_patch(image, num_splits, gap_size=2):
170
+ """Split the image into patches (sizes need not divide evenly), shuffle them, and reassemble with gaps between patches."""
171
+ h_splits, w_splits = num_splits
172
+ img_w, img_h = image.size
173
+
174
+ base_patch_h = img_h // h_splits
175
+ patch_heights = [base_patch_h] * (h_splits - 1)
176
+ patch_heights.append(img_h - sum(patch_heights))
177
+
178
+ base_patch_w = img_w // w_splits
179
+ patch_widths = [base_patch_w] * (w_splits - 1)
180
+ patch_widths.append(img_w - sum(patch_widths))
181
+
182
+ patches = []
183
+ current_y = 0
184
+ for i in range(h_splits):
185
+ current_x = 0
186
+ patch_h = patch_heights[i]
187
+ for j in range(w_splits):
188
+ patch_w = patch_widths[j]
189
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
190
+ patches.append(patch)
191
+ current_x += patch_w
192
+ current_y += patch_h
193
+
194
+ random.shuffle(patches)
195
+
196
+ total_width = sum(patch_widths) + (w_splits - 1) * gap_size
197
+ total_height = sum(patch_heights) + (h_splits - 1) * gap_size
198
+ new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
199
+
200
+ current_y = 0 # starting Y coordinate of the current row
201
+ patch_idx = 0 # index of the patch currently being placed
202
+ for i in range(h_splits):
203
+ current_x = 0 # starting X coordinate of the current column
204
+ patch_h = patch_heights[i] # height of patches in the current row
205
+ for j in range(w_splits):
206
+ # take the next shuffled patch
207
+ patch = patches[patch_idx]
208
+ patch_w = patch_widths[j] # width of patches in the current column
209
+ # paste the patch with its top-left corner at (current_x, current_y)
210
+ new_image.paste(patch, (current_x, current_y))
211
+ # advance X (next patch starts after the current patch width plus the gap)
212
+ current_x += patch_w + gap_size
213
+ patch_idx += 1
214
+ # advance Y (next row starts after the current row height plus the gap)
215
+ current_y += patch_h + gap_size
216
+
217
+ return new_image
218
+
219
+
220
+ def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
221
+ """
222
+ Split the image into patches, blank out a random subset, and reassemble; used for inpainting-style tasks.
223
+
224
+ Args:
225
+ image: PIL.Image, input image (RGB mode)
226
+ h_splits: int, number of patch rows (vertical splits); first element of num_splits
227
+ w_splits: int, number of patch columns (horizontal splits); second element of num_splits
228
+ blank_ratio: float, fraction of patches to blank out (0-1)
229
+ blank_color: tuple, RGB color used for blanked patches, e.g. white (255, 255, 255)
230
+
231
+ Returns:
232
+ PIL.Image: the reassembled image with the selected patches blanked
233
+ """
234
+ h_splits, w_splits = num_splits
235
+ img_w, img_h = image.size
236
+
237
+ base_patch_h = img_h // h_splits
238
+ patch_heights = [base_patch_h] * (h_splits - 1)
239
+ patch_heights.append(img_h - sum(patch_heights))
240
+
241
+ base_patch_w = img_w // w_splits
242
+ patch_widths = [base_patch_w] * (w_splits - 1)
243
+ patch_widths.append(img_w - sum(patch_widths))
244
+
245
+ patches = []
246
+ current_y = 0
247
+ for i in range(h_splits):
248
+ current_x = 0
249
+ patch_h = patch_heights[i]
250
+ for j in range(w_splits):
251
+ patch_w = patch_widths[j]
252
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
253
+ patches.append(patch)
254
+ current_x += patch_w
255
+ current_y += patch_h
256
+
257
+ total_patches = h_splits * w_splits
258
+ num_blank = int(total_patches * blank_ratio)
259
+ num_blank = max(0, min(num_blank, total_patches))
260
+ blank_indices = random.sample(range(total_patches), num_blank)
261
+
262
+ processed_patches = []
263
+ for idx, patch in enumerate(patches):
264
+ if idx in blank_indices:
265
+ blank_patch = Image.new("RGB", patch.size, color=blank_color)
266
+ processed_patches.append(blank_patch)
267
+ else:
268
+ processed_patches.append(patch)
269
+
270
+ # create the output image (same size as the input)
271
+ result_image = Image.new("RGB", (img_w, img_h))
272
+ current_y = 0
273
+ patch_idx = 0
274
+ for i in range(h_splits):
275
+ current_x = 0
276
+ patch_h = patch_heights[i]
277
+ for j in range(w_splits):
278
+ # take the processed patch
279
+ patch = processed_patches[patch_idx]
280
+ patch_w = patch_widths[j]
281
+ # paste it back at its original position
282
+ result_image.paste(patch, (current_x, current_y))
283
+ current_x += patch_w
284
+ patch_idx += 1
285
+ current_y += patch_h
286
+
287
+ return result_image
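For reference, the `ImageTransform` pipeline above can be exercised on its own roughly as follows (the 512/16 settings mirror the config earlier; the image path is a placeholder):

```python
# Illustrative usage of ImageTransform; "example.jpg" is a placeholder path.
from PIL import Image

vae_transform = ImageTransform(max_image_size=512, min_image_size=512, image_stride=16)
img = Image.open("example.jpg").convert("RGB")
tensor = vae_transform(img)   # stride-aligned resize, then normalization to roughly [-1, 1]
print(tensor.shape)           # e.g. torch.Size([3, 512, 512])
```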
data/video_utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 OpenGVLab
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under MIT, with the full license text
8
+ # available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ import io
14
+ import os
15
+ import random
16
+ import re
17
+
18
+ import numpy as np
19
+ import decord
20
+ from PIL import Image
21
+
22
+
23
+ def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
24
+ if sample in ['rand', 'middle']: # uniform sampling
25
+ acc_samples = min(num_frames, vlen)
26
+ # split the video into `acc_samples` intervals, and sample from each interval.
27
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
28
+ ranges = []
29
+ for idx, interv in enumerate(intervals[:-1]):
30
+ ranges.append((interv, intervals[idx + 1] - 1))
31
+ if sample == 'rand':
32
+ try:
33
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
34
+ except:
35
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
36
+ frame_indices.sort()
37
+ frame_indices = list(frame_indices)
38
+ elif fix_start is not None:
39
+ frame_indices = [x[0] + fix_start for x in ranges]
40
+ elif sample == 'middle':
41
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
42
+ else:
43
+ raise NotImplementedError
44
+
45
+ if len(frame_indices) < num_frames: # padded with last frame
46
+ padded_frame_indices = [frame_indices[-1]] * num_frames
47
+ padded_frame_indices[:len(frame_indices)] = frame_indices
48
+ frame_indices = padded_frame_indices
49
+ elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
50
+ output_fps = float(sample[3:])
51
+ duration = float(vlen) / input_fps
52
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
53
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
54
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
55
+ frame_indices = [e for e in frame_indices if e < vlen]
56
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
57
+ frame_indices = frame_indices[:max_num_frames]
58
+ else:
59
+ raise ValueError
60
+ return frame_indices
61
+
62
+
63
+ def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
64
+ video_reader = decord.VideoReader(video_path, num_threads=1)
65
+ vlen = len(video_reader)
66
+ fps = video_reader.get_avg_fps()
67
+ duration = vlen / float(fps)
68
+ if clip:
69
+ start, end = clip
70
+ duration = end - start
71
+ vlen = int(duration * fps)
72
+ start_index = int(start * fps)
73
+
74
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
75
+
76
+ frame_indices = get_frame_indices(
77
+ t_num_frames, vlen, sample=sample, fix_start=fix_start,
78
+ input_fps=fps
79
+ )
80
+ if clip:
81
+ frame_indices = [f + start_index for f in frame_indices]
82
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
83
+ frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
84
+ return frames
85
+
86
+
87
+ def extract_frame_number(filename):
88
+ # Extract the numeric part from the filename using regular expressions
89
+ match = re.search(r'_(\d+).jpg$', filename)
90
+ return int(match.group(1)) if match else -1
91
+
92
+
93
+ def sort_frames(frame_paths):
94
+ # Extract filenames from each path and sort by their numeric part
95
+ return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
96
+
97
+
98
+ def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
99
+ image_list = sort_frames(list(os.listdir(video_path)))
100
+ frames = []
101
+ for image in image_list:
102
+ fp = os.path.join(video_path, image)
103
+ frame = Image.open(fp).convert('RGB')
104
+ frames.append(frame)
105
+ vlen = len(frames)
106
+
107
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
108
+
109
+ if vlen > t_num_frames:
110
+ frame_indices = get_frame_indices(
111
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
112
+ )
113
+ frames = [frames[i] for i in frame_indices]
114
+ return frames
115
+
116
+
117
+ class FrameSampler:
118
+ def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
119
+ self.max_num_frames = max_num_frames
120
+ self.min_num_frames = min_num_frames
121
+ self.sample = sample
122
+
123
+ def __call__(self, file_name):
124
+ fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
125
+ frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
126
+ return frames
127
+
128
+
129
+ def decode_video_byte(video_bytes):
130
+ video_stream = io.BytesIO(video_bytes)
131
+ vr = decord.VideoReader(video_stream)
132
+ return vr
133
+
134
+
135
+ def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
136
+ if isinstance(mp4_p, str):
137
+ vr = decord.VideoReader(mp4_p, num_threads=1)
138
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
139
+ vr = mp4_p
140
+ video_fps = vr.get_avg_fps() # get the average video frame rate
141
+ video_duration = len(vr) / video_fps
142
+ if n_frames is not None:
143
+ if random_sample:
144
+ frame_indices = sorted(random.sample(range(len(vr)), n_frames))
145
+ else:
146
+ frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
147
+ else:
148
+ frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
149
+ frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array
150
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
151
+ if not return_frame_indices:
152
+ return frames, video_duration
153
+ else:
154
+ return frames, video_duration, frame_indices
155
+
156
+
157
+ def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
158
+ if isinstance(mp4_p, str):
159
+ vr = decord.VideoReader(mp4_p, num_threads=1)
160
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
161
+ vr = mp4_p
162
+ # sample the frames in frame_indices
163
+ frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array
164
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
165
+ return frames
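A minimal usage sketch for the `FrameSampler` helper above (the video path is a placeholder):

```python
# Illustrative usage of FrameSampler; "example_video.mp4" is a placeholder path.
sampler = FrameSampler(max_num_frames=12, min_num_frames=8, sample="rand")
frames = sampler("example_video.mp4")  # a list of PIL.Image frames, between 8 and 12 of them
print(len(frames))
```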
data/vlm_dataset.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import json
5
+ import os
6
+ import traceback
7
+ from PIL import Image, ImageFile, PngImagePlugin
8
+
9
+ from .data_utils import pil_img2rgb
10
+ from .distributed_iterable_dataset import DistributedIterableDataset
11
+
12
+
13
+ Image.MAX_IMAGE_PIXELS = 200000000
14
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
15
+ MaximumDecompressedSize = 1024
16
+ MegaByte = 2 ** 20
17
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
18
+
19
+
20
+ class SftJSONLIterableDataset(DistributedIterableDataset):
21
+ def __init__(
22
+ self, dataset_name, transform, tokenizer, frame_sampler,
23
+ jsonl_path_list, data_dir_list, num_used_data,
24
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
25
+ shuffle_lines=False, shuffle_seed=0,
26
+ ):
27
+ """
28
+ jsonl_path_list: list of jsonl file paths
29
+ data_dir_list: list of image directories containing the images of each jsonl file
30
+ num_used_data: list of number of sampled data points for each jsonl
31
+ """
32
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
33
+ self.transform = transform
34
+ self.tokenizer = tokenizer
35
+ self.frame_sampler = frame_sampler
36
+ self.data_status = data_status
37
+ self.data_paths = self.get_data_paths(
38
+ jsonl_path_list,
39
+ data_dir_list,
40
+ num_used_data,
41
+ shuffle_lines,
42
+ shuffle_seed,
43
+ )
44
+ self.set_epoch()
45
+
46
+ def get_data_paths(
47
+ self,
48
+ jsonl_path_list,
49
+ data_dir_list,
50
+ num_used_data,
51
+ shuffle_lines,
52
+ shuffle_seed,
53
+ ):
54
+ data_paths = []
55
+ for jsonl_path, image_dir, num_data_point in zip(
56
+ jsonl_path_list, data_dir_list, num_used_data
57
+ ):
58
+ with open(jsonl_path, 'r') as f:
59
+ raw_data = f.readlines()
60
+ if shuffle_lines:
61
+ self.rng.seed(shuffle_seed)
62
+ self.rng.shuffle(raw_data)
63
+ raw_data = raw_data[:num_data_point]
64
+ data_paths.extend([(json_data, image_dir) for json_data in raw_data])
65
+ return data_paths
66
+
67
+ def change_format(self, data, num_images):
68
+ elements = []
69
+ for conversation in data['conversations']:
70
+ if conversation['from'] == 'human':
71
+ if '<image>' not in conversation['value']:
72
+ elements.append({
73
+ 'type': 'text',
74
+ 'has_loss': 0,
75
+ 'text': conversation['value'],
76
+ })
77
+ else:
78
+ text_list = conversation['value'].split('<image>')
79
+ for idx, text in enumerate(text_list):
80
+ if text.strip() != '':
81
+ elements.append({
82
+ 'type': 'text',
83
+ 'has_loss': 0,
84
+ 'text': text.strip(),
85
+ })
86
+ if (idx != len(text_list) - 1) and (idx < num_images):
87
+ elements.append({'type': 'image',})
88
+ elif conversation['from'] == 'gpt':
89
+ elements.append({
90
+ 'type': 'text',
91
+ 'has_loss': 1,
92
+ 'text': conversation['value'],
93
+ })
94
+ return elements
95
+
96
+ def __iter__(self):
97
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
98
+ if self.data_status is not None:
99
+ row_start_id = self.data_status[worker_id] + 1
100
+ else:
101
+ row_start_id = 0
102
+ transform_stride = self.transform.stride
103
+
104
+ print(
105
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
106
+ f"resuming data at row#{row_start_id}"
107
+ )
108
+
109
+ while True:
110
+ data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
111
+ for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
112
+ num_tokens = 0
113
+ image_tensor_list = []
114
+ text_ids_list = []
115
+ sequence_plan = []
116
+
117
+ try:
118
+ data_item = json.loads(data)
119
+ raw_images = None
120
+ if 'image' in data_item:
121
+ if type(data_item['image']) == list:
122
+ raw_images = [
123
+ pil_img2rgb(Image.open(os.path.join(image_dir, image)))
124
+ for image in data_item['image']
125
+ ]
126
+ else:
127
+ raw_images = [
128
+ pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
129
+ ]
130
+ elif 'video' in data_item:
131
+ raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
132
+ special_tokens = '<image>' * len(raw_images)
133
+ for item in data_item['conversations']:
134
+ if '<video>' in item['value']:
135
+ item['value'] = item['value'].replace('<video>', special_tokens)
136
+ break
137
+ else:
138
+ raise ValueError("Cannot find <video> in the conversation!")
139
+ except:
140
+ traceback.print_exc()
141
+ continue
142
+
143
+ if raw_images:
144
+ for raw_image in raw_images:
145
+ image_tensor = self.transform(raw_image, img_num=len(raw_images))
146
+ image_tensor_list.append(image_tensor)
147
+ height, width = image_tensor.shape[1:]
148
+ num_tokens += width * height // transform_stride ** 2
149
+
150
+ elements = self.change_format(data_item, len(image_tensor_list))
151
+
152
+ for item in elements:
153
+ if item['type'] == 'text':
154
+ text_data = item['text']
155
+ text_ids = self.tokenizer.encode(text_data)
156
+ if len(text_ids) > 0:
157
+ text_ids_list.append(text_ids)
158
+ num_tokens += len(text_ids)
159
+ current_plan = {
160
+ 'type': 'text',
161
+ 'enable_cfg': 0,
162
+ 'loss': item['has_loss'],
163
+ 'special_token_loss': 0,
164
+ 'special_token_label': None,
165
+ }
166
+ sequence_plan.append(current_plan)
167
+ elif item['type'] == 'image':
168
+ current_plan = {
169
+ 'type': 'vit_image',
170
+ 'enable_cfg': 0,
171
+ 'loss': 0,
172
+ 'special_token_loss': 0,
173
+ 'special_token_label': None,
174
+ }
175
+ sequence_plan.append(current_plan)
176
+
177
+ has_loss = [item['loss'] for item in sequence_plan]
178
+ if sum(has_loss) == 0:
179
+ print(f'No loss defined, skipped.')
180
+ continue
181
+
182
+ yield dict(
183
+ image_tensor_list=image_tensor_list,
184
+ text_ids_list=text_ids_list,
185
+ sequence_plan=sequence_plan,
186
+ num_tokens=num_tokens,
187
+ data_indexes={
188
+ "data_indexes": row_idx,
189
+ "worker_id": worker_id,
190
+ "dataset_name": self.dataset_name,
191
+ }
192
+ )
193
+
194
+ row_start_id = 0
195
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
download_model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import snapshot_download
2
+
3
+ HF_HOME = "/mnt/wsfuse/kaiyuyue/cache/huggingface"
4
+ repo_id = "multimodal-reasoning-lab/Bagel-Zebra-CoT"
5
+
6
+ snapshot_download(
7
+ cache_dir=HF_HOME,
8
+ repo_id=repo_id,
9
+ local_dir_use_symlinks=False,
10
+ resume_download=True,
11
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
12
+ )
inference.ipynb ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "tags": []
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "# Copyright 2025 Bytedance Ltd. and/or its affiliates.\n",
12
+ "# SPDX-License-Identifier: Apache-2.0"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {
19
+ "tags": []
20
+ },
21
+ "outputs": [],
22
+ "source": [
23
+ "%load_ext autoreload\n",
24
+ "%autoreload 2"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {
31
+ "tags": []
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "import os\n",
36
+ "from copy import deepcopy\n",
37
+ "from typing import (\n",
38
+ " Any,\n",
39
+ " AsyncIterable,\n",
40
+ " Callable,\n",
41
+ " Dict,\n",
42
+ " Generator,\n",
43
+ " List,\n",
44
+ " NamedTuple,\n",
45
+ " Optional,\n",
46
+ " Tuple,\n",
47
+ " Union,\n",
48
+ ")\n",
49
+ "import requests\n",
50
+ "from io import BytesIO\n",
51
+ "\n",
52
+ "from PIL import Image\n",
53
+ "import torch\n",
54
+ "from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights\n",
55
+ "\n",
56
+ "from data.transforms import ImageTransform\n",
57
+ "from data.data_utils import pil_img2rgb, add_special_tokens\n",
58
+ "from modeling.bagel import (\n",
59
+ " BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel\n",
60
+ ")\n",
61
+ "from modeling.qwen2 import Qwen2Tokenizer\n",
62
+ "from modeling.bagel.qwen2_navit import NaiveCache\n",
63
+ "from modeling.autoencoder import load_ae\n",
64
+ "from safetensors.torch import load_file"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## Model Initialization"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {
78
+ "tags": []
79
+ },
80
+ "outputs": [],
81
+ "source": [
82
+ "model_path = \"/path/to/BAGEL-7B-MoT/weights\" # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT\n",
83
+ "\n",
84
+ "# LLM config preparing\n",
85
+ "llm_config = Qwen2Config.from_json_file(os.path.join(model_path, \"llm_config.json\"))\n",
86
+ "llm_config.qk_norm = True\n",
87
+ "llm_config.tie_word_embeddings = False\n",
88
+ "llm_config.layer_module = \"Qwen2MoTDecoderLayer\"\n",
89
+ "\n",
90
+ "# ViT config preparing\n",
91
+ "vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, \"vit_config.json\"))\n",
92
+ "vit_config.rope = False\n",
93
+ "vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1\n",
94
+ "\n",
95
+ "# VAE loading\n",
96
+ "vae_model, vae_config = load_ae(local_path=os.path.join(model_path, \"ae.safetensors\"))\n",
97
+ "\n",
98
+ "# Bagel config preparing\n",
99
+ "config = BagelConfig(\n",
100
+ " visual_gen=True,\n",
101
+ " visual_und=True,\n",
102
+ " llm_config=llm_config, \n",
103
+ " vit_config=vit_config,\n",
104
+ " vae_config=vae_config,\n",
105
+ " vit_max_num_patch_per_side=70,\n",
106
+ " connector_act='gelu_pytorch_tanh',\n",
107
+ " latent_patch_size=2,\n",
108
+ " max_latent_size=64,\n",
109
+ ")\n",
110
+ "\n",
111
+ "with init_empty_weights():\n",
112
+ " language_model = Qwen2ForCausalLM(llm_config)\n",
113
+ " vit_model = SiglipVisionModel(vit_config)\n",
114
+ " model = Bagel(language_model, vit_model, config)\n",
115
+ " model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)\n",
116
+ "\n",
117
+ "# Tokenizer Preparing\n",
118
+ "tokenizer = Qwen2Tokenizer.from_pretrained(model_path)\n",
119
+ "tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)\n",
120
+ "\n",
121
+ "# Image Transform Preparing\n",
122
+ "vae_transform = ImageTransform(1024, 512, 16)\n",
123
+ "vit_transform = ImageTransform(980, 224, 14)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Model Loading and Multi-GPU Inference Preparation"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {
137
+ "tags": []
138
+ },
139
+ "outputs": [],
140
+ "source": [
141
+ "max_mem_per_gpu = \"80GiB\" # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.\n",
142
+ "\n",
143
+ "device_map = infer_auto_device_map(\n",
144
+ " model,\n",
145
+ " max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},\n",
146
+ " no_split_module_classes=[\"Bagel\", \"Qwen2MoTDecoderLayer\"],\n",
147
+ ")\n",
148
+ "print(device_map)\n",
149
+ "\n",
150
+ "same_device_modules = [\n",
151
+ " 'language_model.model.embed_tokens',\n",
152
+ " 'time_embedder',\n",
153
+ " 'latent_pos_embed',\n",
154
+ " 'vae2llm',\n",
155
+ " 'llm2vae',\n",
156
+ " 'connector',\n",
157
+ " 'vit_pos_embed'\n",
158
+ "]\n",
159
+ "\n",
160
+ "if torch.cuda.device_count() == 1:\n",
161
+ " first_device = device_map.get(same_device_modules[0], \"cuda:0\")\n",
162
+ " for k in same_device_modules:\n",
163
+ " if k in device_map:\n",
164
+ " device_map[k] = first_device\n",
165
+ " else:\n",
166
+ " device_map[k] = \"cuda:0\"\n",
167
+ "else:\n",
168
+ " first_device = device_map.get(same_device_modules[0])\n",
169
+ " for k in same_device_modules:\n",
170
+ " if k in device_map:\n",
171
+ " device_map[k] = first_device\n",
172
+ "\n",
173
+ "# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8\n",
174
+ "model = load_checkpoint_and_dispatch(\n",
175
+ " model,\n",
176
+ " checkpoint=os.path.join(model_path, \"ema.safetensors\"),\n",
177
+ " device_map=device_map,\n",
178
+ " offload_buffers=True,\n",
179
+ " dtype=torch.bfloat16,\n",
180
+ " force_hooks=True,\n",
181
+ " offload_folder=\"/tmp/offload\"\n",
182
+ ")\n",
183
+ "\n",
184
+ "model = model.eval()\n",
185
+ "print('Model loaded')"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": []
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {},
198
+ "source": [
199
+ "## Inferencer Preparing "
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {
206
+ "tags": []
207
+ },
208
+ "outputs": [],
209
+ "source": [
210
+ "from inferencer import InterleaveInferencer\n",
211
+ "\n",
212
+ "inferencer = InterleaveInferencer(\n",
213
+ " model=model, \n",
214
+ " vae_model=vae_model, \n",
215
+ " tokenizer=tokenizer, \n",
216
+ " vae_transform=vae_transform, \n",
217
+ " vit_transform=vit_transform, \n",
218
+ " new_token_ids=new_token_ids\n",
219
+ ")"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {
226
+ "tags": []
227
+ },
228
+ "outputs": [],
229
+ "source": [
230
+ "import random\n",
231
+ "import numpy as np\n",
232
+ "\n",
233
+ "seed = 42\n",
234
+ "random.seed(seed)\n",
235
+ "np.random.seed(seed)\n",
236
+ "torch.manual_seed(seed)\n",
237
+ "if torch.cuda.is_available():\n",
238
+ " torch.cuda.manual_seed(seed)\n",
239
+ " torch.cuda.manual_seed_all(seed)\n",
240
+ "torch.backends.cudnn.deterministic = True\n",
241
+ "torch.backends.cudnn.benchmark = False"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "markdown",
246
+ "metadata": {},
247
+ "source": [
248
+ "**About Inference Hyperparameters:**\n",
249
+ "- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.\n",
250
+ "- **`cfg_image_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.\n",
251
+ "- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.\n",
252
+ "- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).\n",
253
+ "- **`num_timesteps`:** Total denoising steps. Typical: `50`.\n",
254
+ "- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. Typical: `0`.\n",
255
+ "- **`cfg_renorm_type`:** CFG-Renorm method: \n",
256
+ " - `global`: Normalize over all tokens and channels (default for T2I).\n",
257
+ " - `channel`: Normalize across channels for each token.\n",
258
+ " - `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).\n",
259
+ "- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min` or decrease `cfg_scale`.**\n"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "metadata": {},
265
+ "source": [
266
+ "## Image Generation"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {
273
+ "tags": []
274
+ },
275
+ "outputs": [],
276
+ "source": [
277
+ "inference_hyper=dict(\n",
278
+ " cfg_text_scale=4.0,\n",
279
+ " cfg_img_scale=1.0,\n",
280
+ " cfg_interval=[0.4, 1.0],\n",
281
+ " timestep_shift=3.0,\n",
282
+ " num_timesteps=50,\n",
283
+ " cfg_renorm_min=0.0,\n",
284
+ " cfg_renorm_type=\"global\",\n",
285
+ ")"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": null,
291
+ "metadata": {
292
+ "tags": []
293
+ },
294
+ "outputs": [],
295
+ "source": [
296
+ "prompt = \"A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.\"\n",
297
+ "\n",
298
+ "print(prompt)\n",
299
+ "print('-' * 10)\n",
300
+ "output_dict = inferencer(text=prompt, **inference_hyper)\n",
301
+ "display(output_dict['image'])"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "markdown",
306
+ "metadata": {
307
+ "tags": []
308
+ },
309
+ "source": [
310
+ "## Image Generation with Think"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": null,
316
+ "metadata": {
317
+ "tags": []
318
+ },
319
+ "outputs": [],
320
+ "source": [
321
+ "inference_hyper=dict(\n",
322
+ " max_think_token_n=1000,\n",
323
+ " do_sample=False,\n",
324
+ " # text_temperature=0.3,\n",
325
+ " cfg_text_scale=4.0,\n",
326
+ " cfg_img_scale=1.0,\n",
327
+ " cfg_interval=[0.4, 1.0],\n",
328
+ " timestep_shift=3.0,\n",
329
+ " num_timesteps=50,\n",
330
+ " cfg_renorm_min=0.0,\n",
331
+ " cfg_renorm_type=\"global\",\n",
332
+ ")"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": null,
338
+ "metadata": {
339
+ "tags": []
340
+ },
341
+ "outputs": [],
342
+ "source": [
343
+ "prompt = 'a car made of small cars'\n",
344
+ "\n",
345
+ "print(prompt)\n",
346
+ "print('-' * 10)\n",
347
+ "output_dict = inferencer(text=prompt, think=True, **inference_hyper)\n",
348
+ "print(output_dict['text'])\n",
349
+ "display(output_dict['image'])"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": []
358
+ },
359
+ {
360
+ "cell_type": "markdown",
361
+ "metadata": {},
362
+ "source": [
363
+ "## Editing"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "metadata": {
370
+ "tags": []
371
+ },
372
+ "outputs": [],
373
+ "source": [
374
+ "inference_hyper=dict(\n",
375
+ " cfg_text_scale=4.0,\n",
376
+ " cfg_img_scale=2.0,\n",
377
+ " cfg_interval=[0.0, 1.0],\n",
378
+ " timestep_shift=3.0,\n",
379
+ " num_timesteps=50,\n",
380
+ " cfg_renorm_min=0.0,\n",
381
+ " cfg_renorm_type=\"text_channel\",\n",
382
+ ")"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {
389
+ "tags": []
390
+ },
391
+ "outputs": [],
392
+ "source": [
393
+ "image = Image.open('test_images/women.jpg')\n",
394
+ "prompt = 'She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.'\n",
395
+ "\n",
396
+ "display(image)\n",
397
+ "print(prompt)\n",
398
+ "print('-'*10)\n",
399
+ "output_dict = inferencer(image=image, text=prompt, **inference_hyper)\n",
400
+ "display(output_dict['image'])"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {},
407
+ "outputs": [],
408
+ "source": []
409
+ },
410
+ {
411
+ "cell_type": "markdown",
412
+ "metadata": {},
413
+ "source": [
414
+ "## Edit with Think"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "metadata": {
421
+ "tags": []
422
+ },
423
+ "outputs": [],
424
+ "source": [
425
+ "inference_hyper=dict(\n",
426
+ " max_think_token_n=1000,\n",
427
+ " do_sample=False,\n",
428
+ " # text_temperature=0.3,\n",
429
+ " cfg_text_scale=4.0,\n",
430
+ " cfg_img_scale=2.0,\n",
431
+ " cfg_interval=[0.0, 1.0],\n",
432
+ " timestep_shift=3.0,\n",
433
+ " num_timesteps=50,\n",
434
+ " cfg_renorm_min=0.0,\n",
435
+ " cfg_renorm_type=\"text_channel\",\n",
436
+ ")"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {
443
+ "tags": []
444
+ },
445
+ "outputs": [],
446
+ "source": [
447
+ "image = Image.open('test_images/octupusy.jpg')\n",
448
+ "prompt = 'Could you display the sculpture that takes after this design?'\n",
449
+ "\n",
450
+ "display(image)\n",
451
+ "print('-'*10)\n",
452
+ "output_dict = inferencer(image=image, text=prompt, think=True, **inference_hyper)\n",
453
+ "print(output_dict['text'])\n",
454
+ "display(output_dict['image'])"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "metadata": {},
461
+ "outputs": [],
462
+ "source": []
463
+ },
464
+ {
465
+ "cell_type": "markdown",
466
+ "metadata": {},
467
+ "source": [
468
+ "## Understanding"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "code",
473
+ "execution_count": null,
474
+ "metadata": {
475
+ "tags": []
476
+ },
477
+ "outputs": [],
478
+ "source": [
479
+ "inference_hyper=dict(\n",
480
+ " max_think_token_n=1000,\n",
481
+ " do_sample=False,\n",
482
+ " # text_temperature=0.3,\n",
483
+ ")"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "metadata": {
490
+ "tags": []
491
+ },
492
+ "outputs": [],
493
+ "source": [
494
+ "image = Image.open('test_images/meme.jpg')\n",
495
+ "prompt = \"Can someone explain what’s funny about this meme??\"\n",
496
+ "\n",
497
+ "display(image)\n",
498
+ "print(prompt)\n",
499
+ "print('-'*10)\n",
500
+ "output_dict = inferencer(image=image, text=prompt, understanding_output=True, **inference_hyper)\n",
501
+ "print(output_dict['text'])"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": []
510
+ }
511
+ ],
512
+ "metadata": {
513
+ "fileId": "1bfaa82d-51b0-4c13-9e4c-295ba28bcd8a",
514
+ "filePath": "/mnt/bn/seed-aws-va/chaorui/code/cdt-hf/notebooks/chat.ipynb",
515
+ "kernelspec": {
516
+ "display_name": "Python 3 (ipykernel)",
517
+ "language": "python",
518
+ "name": "python3"
519
+ },
520
+ "language_info": {
521
+ "codemirror_mode": {
522
+ "name": "ipython",
523
+ "version": 3
524
+ },
525
+ "file_extension": ".py",
526
+ "mimetype": "text/x-python",
527
+ "name": "python",
528
+ "nbconvert_exporter": "python",
529
+ "pygments_lexer": "ipython3",
530
+ "version": "3.11.2"
531
+ }
532
+ },
533
+ "nbformat": 4,
534
+ "nbformat_minor": 4
535
+ }
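For quick reference, the notebook above exercises two inference presets that differ only in a few fields. The sketch below restates them side by side; the values are copied from the cells, and the names `t2i_hyper` / `edit_hyper` are illustrative rather than part of the repo.

```python
# Presets used in the notebook above (values copied from the cells; names are illustrative).
t2i_hyper = dict(                      # text-to-image generation
    cfg_text_scale=4.0,                # follow the text prompt
    cfg_img_scale=1.0,                 # 1.0 disables image guidance (no input image)
    cfg_interval=[0.4, 1.0],           # apply CFG only on the later denoising steps
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="global",          # default renorm for T2I
)

edit_hyper = dict(                     # image editing
    cfg_text_scale=4.0,
    cfg_img_scale=2.0,                 # preserve details of the input image
    cfg_interval=[0.0, 1.0],           # CFG over the whole denoising schedule
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="text_channel",    # editing-friendly; switch to "global" if results blur
)

# e.g. output_dict = inferencer(text=prompt, **t2i_hyper)
#      output_dict = inferencer(image=image, text=prompt, **edit_hyper)
```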
inferencer.py ADDED
@@ -0,0 +1,300 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from copy import deepcopy
5
+ from typing import List, Dict, Optional, Union, Any
6
+
7
+ from PIL import Image
8
+ import torch
9
+
10
+ from data.data_utils import pil_img2rgb
11
+ from modeling.bagel.qwen2_navit import NaiveCache
12
+
13
+
14
+
15
+ VLM_THINK_SYSTEM_PROMPT = '''Generation Instructions: You should first think about the reasoning process in the mind and then provide the user with the answer.
16
+ The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
17
+
18
+ GEN_THINK_SYSTEM_PROMPT = '''Generation Instructions: You should first think about the planning process in the mind and then generate the image.
19
+ The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
20
+
21
+
22
+ class InterleaveInferencer:
23
+ def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
24
+ self.model = model
25
+ self.vae_model = vae_model
26
+ self.tokenizer = tokenizer
27
+ self.vae_transform = vae_transform
28
+ self.vit_transform = vit_transform
29
+ self.new_token_ids = new_token_ids
30
+
31
+ def init_gen_context(self):
32
+ gen_context = {
33
+ 'kv_lens': [0],
34
+ 'ropes': [0],
35
+ 'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
36
+ }
37
+ return gen_context
38
+
39
+ @torch.no_grad()
40
+ def update_context_text(self, text, gen_context):
41
+         # used for interleaved data; currently only supports single-sample inference
42
+
43
+ past_key_values = gen_context['past_key_values']
44
+ kv_lens = gen_context['kv_lens']
45
+ ropes = gen_context['ropes']
46
+ generation_input, kv_lens, ropes = self.model.prepare_prompts(
47
+ curr_kvlens=kv_lens,
48
+ curr_rope=ropes,
49
+ prompts=[text],
50
+ tokenizer=self.tokenizer,
51
+ new_token_ids=self.new_token_ids,
52
+ )
53
+
54
+ past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
55
+ gen_context['kv_lens'] = kv_lens
56
+ gen_context['ropes'] = ropes
57
+ gen_context['past_key_values'] = past_key_values
58
+
59
+ return gen_context
60
+
61
+ @torch.no_grad()
62
+ def update_context_image(self, image, gen_context, vae=True, vit=True):
63
+         # used for interleaved data; currently only supports single-sample inference
64
+
65
+ assert vae or vit
66
+ past_key_values = gen_context['past_key_values']
67
+ kv_lens = gen_context['kv_lens']
68
+ ropes = gen_context['ropes']
69
+
70
+ if vae:
71
+ ## update vae
72
+ generation_input, kv_lens, ropes = self.model.prepare_vae_images(
73
+ curr_kvlens=kv_lens,
74
+ curr_rope=ropes,
75
+ images=[image],
76
+ transforms=self.vae_transform,
77
+ new_token_ids=self.new_token_ids,
78
+ )
79
+ past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
80
+
81
+ if vit:
82
+ ## update vit
83
+ generation_input, kv_lens, ropes = self.model.prepare_vit_images(
84
+ curr_kvlens=kv_lens,
85
+ curr_rope=ropes,
86
+ images=[image],
87
+ transforms=self.vit_transform,
88
+ new_token_ids=self.new_token_ids,
89
+ )
90
+ past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
91
+
92
+ gen_context['kv_lens'] = kv_lens
93
+ gen_context['ropes'] = ropes
94
+ gen_context['past_key_values'] = past_key_values
95
+
96
+ return gen_context
97
+
98
+ @torch.no_grad()
99
+ def gen_image(
100
+ self,
101
+ image_shape,
102
+ gen_context,
103
+ cfg_text_scale=4.0,
104
+ cfg_img_scale=1.5,
105
+
106
+ cfg_text_precontext=None,
107
+ cfg_img_precontext=None,
108
+ cfg_interval=(0.4, 1.0),
109
+ cfg_renorm_min=0.0,
110
+ cfg_renorm_type="global",
111
+
112
+ num_timesteps=50,
113
+ timestep_shift=3.0
114
+ ):
115
+ # print(cfg_renorm_type)
116
+ past_key_values = gen_context['past_key_values']
117
+ kv_lens = gen_context['kv_lens']
118
+ ropes = gen_context['ropes']
119
+ generation_input = self.model.prepare_vae_latent(
120
+ curr_kvlens=kv_lens,
121
+ curr_rope=ropes,
122
+ image_sizes=[image_shape],
123
+ new_token_ids=self.new_token_ids,
124
+ )
125
+
126
+ # text cfg
127
+ cfg_text_past_key_values = cfg_text_precontext['past_key_values']
128
+ kv_lens_cfg = cfg_text_precontext['kv_lens']
129
+ ropes_cfg = cfg_text_precontext['ropes']
130
+ generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
131
+ curr_kvlens=kv_lens_cfg,
132
+ curr_rope=ropes_cfg,
133
+ image_sizes=[image_shape],
134
+ )
135
+
136
+ # img cfg
137
+ cfg_img_past_key_values = cfg_img_precontext['past_key_values']
138
+ kv_lens_cfg = cfg_img_precontext['kv_lens']
139
+ ropes_cfg = cfg_img_precontext['ropes']
140
+ generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
141
+ curr_kvlens=kv_lens_cfg,
142
+ curr_rope=ropes_cfg,
143
+ image_sizes=[image_shape],
144
+ )
145
+
146
+ unpacked_latent = self.model.generate_image(
147
+ past_key_values=past_key_values,
148
+ cfg_text_past_key_values=cfg_text_past_key_values,
149
+ cfg_img_past_key_values=cfg_img_past_key_values,
150
+ num_timesteps=num_timesteps,
151
+ cfg_text_scale=cfg_text_scale,
152
+ cfg_img_scale=cfg_img_scale,
153
+ cfg_interval=cfg_interval,
154
+ cfg_renorm_min=cfg_renorm_min,
155
+ cfg_renorm_type=cfg_renorm_type,
156
+ timestep_shift=timestep_shift,
157
+ **generation_input,
158
+ cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
159
+ cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
160
+ cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
161
+ cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
162
+ cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
163
+ cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
164
+ cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
165
+ cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
166
+ )
167
+
168
+ image = self.decode_image(unpacked_latent[0], image_shape)
169
+ return image
170
+
171
+
172
+ def decode_image(self, latent, image_shape):
173
+ H, W = image_shape
174
+ h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
175
+
176
+ latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
177
+ latent = torch.einsum("nhwpqc->nchpwq", latent)
178
+ latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
179
+ image = self.vae_model.decode(latent)
180
+ image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
181
+ image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
182
+
183
+ return image
184
+
185
+ @torch.no_grad()
186
+ def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
187
+ gen_context = deepcopy(gen_context)
188
+ past_key_values = gen_context['past_key_values']
189
+ kv_lens = gen_context['kv_lens']
190
+ ropes = gen_context['ropes']
191
+
192
+ generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
193
+ unpacked_latent = self.model.generate_text(
194
+ past_key_values=past_key_values,
195
+ max_length=max_length,
196
+ do_sample=do_sample,
197
+ temperature=temperature,
198
+ end_token_id=self.new_token_ids['eos_token_id'],
199
+ # end_token_id=151652,
200
+ **generation_input,
201
+ )
202
+
203
+ output = self.tokenizer.decode(unpacked_latent[:,0])
204
+ return output
205
+
206
+ @torch.no_grad()
207
+ def interleave_inference(
208
+ self,
209
+ input_lists: List[Union[str, Image.Image]],
210
+ understanding_output=False,
211
+ system_prompt=None,
212
+ max_think_token_n=1000,
213
+ do_sample=False,
214
+ text_temperature=0.3,
215
+ cfg_text_scale=3.0,
216
+ cfg_img_scale=1.5,
217
+ cfg_interval=[0.4, 1.0],
218
+ timestep_shift=3.0,
219
+ num_timesteps=50,
220
+ cfg_renorm_min=0.0,
221
+ cfg_renorm_type="global",
222
+ image_shapes=(1024, 1024),
223
+ ) -> List[Union[str, Image.Image]]:
224
+
225
+ output_list = []
226
+ gen_context = self.init_gen_context()
227
+ cfg_text_context = deepcopy(gen_context)
228
+ cfg_img_context = deepcopy(gen_context)
229
+
230
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
231
+ if system_prompt:
232
+ gen_context = self.update_context_text(system_prompt, gen_context)
233
+ cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
234
+
235
+ for input_term in input_lists:
236
+ if isinstance(input_term, str):
237
+ cfg_text_context = deepcopy(gen_context)
238
+ gen_context = self.update_context_text(input_term, gen_context)
239
+ cfg_img_context = self.update_context_text(input_term, cfg_img_context)
240
+
241
+ elif isinstance(input_term, Image.Image):
242
+ input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
243
+ gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
244
+
245
+ image_shapes = input_term.size[::-1]
246
+ cfg_text_context = deepcopy(gen_context)
247
+
248
+ else:
249
+ raise ValueError(f"Unsupported input type: {type(input_term)}")
250
+
251
+ if understanding_output:
252
+ gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
253
+ output_list.append(gen_text)
254
+
255
+ else:
256
+ img = self.gen_image(
257
+ image_shapes,
258
+ gen_context,
259
+ cfg_text_precontext=cfg_text_context,
260
+ cfg_img_precontext=cfg_img_context,
261
+
262
+ cfg_text_scale=cfg_text_scale,
263
+ cfg_img_scale=cfg_img_scale,
264
+ cfg_interval=cfg_interval,
265
+ timestep_shift=timestep_shift,
266
+ num_timesteps=num_timesteps,
267
+ cfg_renorm_min=cfg_renorm_min,
268
+ cfg_renorm_type=cfg_renorm_type,
269
+ )
270
+
271
+ output_list.append(img)
272
+
273
+ return output_list
274
+
275
+ def __call__(
276
+ self,
277
+ image: Optional[Image.Image] = None,
278
+ text: Optional[str] = None,
279
+ **kargs
280
+ ) -> Dict[str, Any]:
281
+ output_dict = {'image': None, 'text': None}
282
+
283
+ if image is None and text is None:
284
+ print('Please provide at least one input: either an image or text.')
285
+ return output_dict
286
+
287
+ input_list = []
288
+ if image is not None:
289
+ input_list.append(image)
290
+ if text is not None:
291
+ input_list.append(text)
292
+
293
+ output_list = self.interleave_inference(input_list, **kargs)
294
+
295
+ for i in output_list:
296
+ if isinstance(i, Image.Image):
297
+ output_dict['image'] = i
298
+ elif isinstance(i, str):
299
+ output_dict['text'] = i
300
+ return output_dict
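Assuming the model, VAE, tokenizer, transforms, and `new_token_ids` have already been prepared as in the notebook or `infz_bf16.py`, a minimal sketch of how `InterleaveInferencer` is driven looks like this (the file and prompt names are only examples):

```python
# Minimal usage sketch; model, vae_model, tokenizer, transforms and new_token_ids
# are assumed to come from the loading code shown elsewhere in this commit.
from PIL import Image
from inferencer import InterleaveInferencer

inferencer = InterleaveInferencer(
    model=model, vae_model=vae_model, tokenizer=tokenizer,
    vae_transform=vae_transform, vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)

# Text-to-image: __call__ returns {'image': PIL.Image or None, 'text': str or None}.
out = inferencer(text="a car made of small cars", cfg_text_scale=4.0, num_timesteps=50)
out["image"].save("t2i_result.png")

# Understanding: pass understanding_output=True to get text instead of an image.
img = Image.open("test_images/meme.jpg")
out = inferencer(image=img, text="What is funny about this meme?", understanding_output=True)
print(out["text"])

# Lower-level API: interleave_inference takes a mixed list of strings/images and
# returns a list containing either a generated image or a decoded text string.
outputs = inferencer.interleave_inference([img, "Describe this image."], understanding_output=True)
```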
infz_bf16.py ADDED
@@ -0,0 +1,704 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from datetime import datetime
5
+ from copy import deepcopy
6
+ from typing import (
7
+ Any,
8
+ AsyncIterable,
9
+ Callable,
10
+ Dict,
11
+ Generator,
12
+ List,
13
+ NamedTuple,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ )
18
+ import requests
19
+ from io import BytesIO
20
+
21
+ from PIL import Image
22
+ import torch
23
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
24
+
25
+ from data.transforms import ImageTransform
26
+ from data.data_utils import pil_img2rgb, add_special_tokens
27
+ from modeling.bagel import (
28
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
29
+ )
30
+ from modeling.qwen2 import Qwen2Tokenizer
31
+ from modeling.bagel.qwen2_navit import NaiveCache
32
+ from modeling.autoencoder import load_ae
33
+
34
+ # Set paths for your trained checkpoint
35
+ # checkpoint_dir = "/scratch/by2593/merged_checkpoint_final"
36
+ origin_checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"
37
+ checkpoint_dir = "/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_questionimage/0000150"
38
+
39
+
40
+ checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2_test/000010'
41
+ checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2/000150'
42
+ checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_v1_final/0000500'
43
+ checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"
44
+
45
+ checkpoint_file = "model.safetensors"
46
+ # checkpoint_file = "model_bf16.safetensors"
47
+
48
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
49
+ checkpoint_path = "/scratch/by2593/Bagel-Zebra-CoT-origin/results/checkpoints_smm_semantic_part1_v1_origin/0000050/model.safetensors"
50
+
51
+ print(f"Available GPUs: {torch.cuda.device_count()}")
52
+ print("GPU memory per device:")
53
+ for i in range(torch.cuda.device_count()):
54
+ props = torch.cuda.get_device_properties(i)
55
+ print(f" GPU {i}: {props.name}, {props.total_memory / 1e9:.1f} GB")
56
+
57
+ # LLM config preparing (use base model configs)
58
+ llm_config = Qwen2Config.from_json_file(os.path.join(checkpoint_dir, "llm_config.json"))
59
+ llm_config.qk_norm = True
60
+ llm_config.tie_word_embeddings = False
61
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
62
+
63
+ # ViT config preparing (use base model configs)
64
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(checkpoint_dir, "vit_config.json"))
65
+ vit_config.rope = False
66
+ vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
67
+
68
+ # VAE loading (use base model VAE)
69
+ vae_model, vae_config = load_ae(local_path=os.path.join(origin_checkpoint_dir, "ae.safetensors"))
70
+
71
+ # Bagel config preparing
72
+ config = BagelConfig(
73
+ visual_gen=True,
74
+ visual_und=True,
75
+ llm_config=llm_config,
76
+ vit_config=vit_config,
77
+ vae_config=vae_config,
78
+ vit_max_num_patch_per_side=70,
79
+ connector_act='gelu_pytorch_tanh',
80
+ latent_patch_size=2,
81
+     max_latent_size=64,  # default is 64; set to the actual latent size
82
+ )
83
+
84
+ # Import the position embedding function first
85
+ from modeling.bagel.modeling_utils import get_2d_sincos_pos_embed
86
+
87
+ # Create model with empty weights
88
+ with init_empty_weights():
89
+ language_model = Qwen2ForCausalLM(llm_config)
90
+ vit_model = SiglipVisionModel(vit_config)
91
+ model = Bagel(language_model, vit_model, config)
92
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
93
+
94
+ # Initialize position embeddings with proper values BEFORE loading checkpoint
95
+ print("Initializing position embeddings before loading...")
96
+
97
+ # Initialize latent_pos_embed if it exists
98
+ if hasattr(model, 'latent_pos_embed'):
99
+ print("Initializing latent_pos_embed...")
100
+ pos_embed = get_2d_sincos_pos_embed(model.latent_pos_embed.hidden_size, model.latent_pos_embed.max_num_patch_per_side)
101
+ # Create parameter with actual values, not meta
102
+ model.latent_pos_embed.pos_embed = torch.nn.Parameter(
103
+ torch.from_numpy(pos_embed).float(), requires_grad=False
104
+ )
105
+ print(f"latent_pos_embed initialized with shape {model.latent_pos_embed.pos_embed.shape}")
106
+
107
+ # Initialize vit_pos_embed if it exists
108
+ if hasattr(model, 'vit_pos_embed'):
109
+ print("Initializing vit_pos_embed...")
110
+ pos_embed = get_2d_sincos_pos_embed(model.vit_pos_embed.hidden_size, model.vit_pos_embed.max_num_patch_per_side)
111
+ # Create parameter with actual values, not meta
112
+ model.vit_pos_embed.pos_embed = torch.nn.Parameter(
113
+ torch.from_numpy(pos_embed).float(), requires_grad=False
114
+ )
115
+ print(f"vit_pos_embed initialized with shape {model.vit_pos_embed.pos_embed.shape}")
116
+
117
+ print("Position embeddings initialized successfully")
118
+
119
+ # Tokenizer Preparing (use base model tokenizer)
120
+ tokenizer = Qwen2Tokenizer.from_pretrained(checkpoint_dir)
121
+ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
122
+
123
+ # Image Transform Preparing
124
+ vae_transform = ImageTransform(1024, 512, 16)
125
+ vit_transform = ImageTransform(980, 512, 14)
126
+
127
+ # Device mapping for 8x80GB GPUs - use bf16 directly
128
+ max_mem_per_gpu = "80GiB"
129
+
130
+ print("Setting up device mapping...")
131
+ device_map = infer_auto_device_map(
132
+ model,
133
+ max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
134
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
135
+ dtype=torch.bfloat16, # Use bf16 for device mapping
136
+ )
137
+
138
+ print("Device map:", device_map)
139
+
140
+ # Handle same-device modules
141
+ same_device_modules = [
142
+ 'language_model.model.embed_tokens',
143
+ 'time_embedder',
144
+ 'latent_pos_embed',
145
+ 'vae2llm',
146
+ 'llm2vae',
147
+ 'connector',
148
+ 'vit_pos_embed'
149
+ ]
150
+
151
+ if torch.cuda.device_count() == 1:
152
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
153
+ for k in same_device_modules:
154
+ if k in device_map:
155
+ device_map[k] = first_device
156
+ else:
157
+ device_map[k] = "cuda:0"
158
+ else:
159
+ first_device = device_map.get(same_device_modules[0])
160
+ if first_device is not None:
161
+ for k in same_device_modules:
162
+ if k in device_map:
163
+ device_map[k] = first_device
164
+
165
+ print("Final device map:", device_map)
166
+
167
+ # Load checkpoint directly in bf16
168
+ print(f"Loading checkpoint directly in bfloat16: {checkpoint_path}")
169
+ print("Loading model from safetensors file...")
170
+
171
+ # Load model directly in bf16
172
+ model = load_checkpoint_and_dispatch(
173
+ model,
174
+ checkpoint=checkpoint_path,
175
+ device_map=device_map,
176
+ offload_buffers=False,
177
+ dtype=torch.bfloat16, # Load directly as bf16
178
+ force_hooks=True,
179
+ )
180
+
181
+ model = model.eval()
182
+
183
+ print('Model loaded directly in bfloat16!')
184
+ print(f"Model dtype: {next(model.parameters()).dtype}")
185
+
186
+ # Position embeddings were already initialized before model loading
187
+ print("Position embeddings were pre-initialized before loading checkpoint")
188
+
189
+ print("Model loading completed successfully!")
190
+
191
+ # Check memory usage
192
+ print("GPU memory usage after loading:")
193
+ for i in range(torch.cuda.device_count()):
194
+ if torch.cuda.memory_allocated(i) > 0:
195
+ allocated = torch.cuda.memory_allocated(i) / 1e9
196
+ cached = torch.cuda.memory_reserved(i) / 1e9
197
+ print(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
198
+
199
+ # Rest of inference code
200
+ from inferencer import InterleaveInferencer
201
+
202
+ inferencer = InterleaveInferencer(
203
+ model=model,
204
+ vae_model=vae_model,
205
+ tokenizer=tokenizer,
206
+ vae_transform=vae_transform,
207
+ vit_transform=vit_transform,
208
+ new_token_ids=new_token_ids
209
+ )
210
+
211
+ import random
212
+ import numpy as np
213
+
214
+ seed = 42
215
+ random.seed(seed)
216
+ np.random.seed(seed)
217
+ torch.manual_seed(seed)
218
+ if torch.cuda.is_available():
219
+ torch.cuda.manual_seed(seed)
220
+ torch.cuda.manual_seed_all(seed)
221
+ torch.backends.cudnn.deterministic = True
222
+ torch.backends.cudnn.benchmark = False
223
+
224
+ inference_hyper=dict(
225
+ do_sample=False,
226
+ text_temperature=0.0,
227
+ cfg_text_scale=4.0,
228
+ cfg_img_scale=2.0,
229
+ cfg_interval=[0.0, 1.0],
230
+ timestep_shift=3.0,
231
+ num_timesteps=50,
232
+ cfg_renorm_min=0.0,
233
+ cfg_renorm_type="text_channel",
234
+ )
235
+
236
+ INTERLEAVED_SYSTEM_PROMPT = '''You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and use visual aids to enhance your problem-solving.'''
237
+ INTERLEAVED_SYSTEM_PROMPT = ''  # overrides the prompt above: this run uses an empty system prompt
238
+
239
+ # Original example (004 case) - commented out
240
+ # prompt = '''My goal is to generate a visual guide for constructing a specific shape using a set of blocks. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input includes 2 images of multiple blocks that will be used <image_start>[problem_image_1]<image_end><image_start>[problem_image_2]<image_end> and an image of the final desired shape<image_start>[problem_image_3]<image_end>. I need to imagine and generate images of intermediate steps, leading up to the final construction. Step 0 has been completed: a red arch block has been placed on top of the ground. The image after step 0 is provided<image_start>[problem_image_4]<image_end>. Now I need to generate the image for step 1, considering spatial relationships and stability.'''
241
+
242
+ # Use the new example data (145 case)
243
+ prompt = '''Based on the construction task shown below, follow the instructions to complete the build. Given the final desired shape of blocks shown in the first image<image_start>[problem_image_1]<image_end> which is viewed from a Front45 angle, perform a series of specified manipulations. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input also includes 3 images of multiple blocks that will be used.<image_start>[problem_image_2]<image_end><image_start>[problem_image_3]<image_end><image_start>[problem_image_4]<image_end> Step 0 has been completed: a orange cylinder block has been placed on top of the ground. The image after step 0 is provided.<image_start>[problem_image_5]<image_end>'''
244
+
245
+ # Load images from the new example paths (145 case)
246
+ image = []
247
+ base_path = '/scratch/by2593/project/SMM'
248
+ image_paths = [
249
+ f'{base_path}/semantic_blocks_part1/145/final_state/145_final_1.png', # problem_image_1 - final desired shape
250
+ f'{base_path}/SMM_data/each_block_views_diffposes/cylinder_orange.png', # problem_image_2 - orange cylinder
251
+ f'{base_path}/SMM_data/each_block_views_diffposes/cuboid3_yellow.png', # problem_image_3 - yellow cuboid3
252
+ f'{base_path}/SMM_data/each_block_views_diffposes/triangle_orange.png', # problem_image_4 - orange triangle
253
+ f'{base_path}/semantic_blocks_part1/145/steps/view_1/145_step0_1.png', # problem_image_5 - image after step 0
254
+ ]
255
+
256
+ print("Loading input images:")
257
+ for i, img_path in enumerate(image_paths):
258
+ try:
259
+ img = Image.open(img_path).convert('RGB')
260
+ image.append(img)
261
+ print(f" ✓ Loaded problem_image_{i+1}: {img_path}")
262
+ print(f" Image size: {img.size}")
263
+ except Exception as e:
264
+ print(f" ✗ Failed to load {img_path}: {e}")
265
+ # Create a placeholder image if file not found
266
+ img = Image.new('RGB', (512, 512), color='gray')
267
+ image.append(img)
268
+ print(f" ⚠ Using placeholder for problem_image_{i+1}")
269
+
270
+ print(prompt)
271
+ print('-'*50)
272
+
273
+ # Create output folder with timestamp
274
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
275
+ output_folder = f"reasoning_output_example_145_{timestamp}"
276
+ images_folder = os.path.join(output_folder, "images")
277
+ os.makedirs(images_folder, exist_ok=True)
278
+
279
+ print(f"Output will be saved to: {output_folder}")
280
+
281
+ # Save the original problem images if they exist
282
+ problem_image_paths = []
283
+ if image is not None:
284
+ if isinstance(image, list):
285
+ # Handle multiple images
286
+ for i, img in enumerate(image):
287
+ problem_image_path = os.path.join(images_folder, f"problem_image_{i+1}.png")
288
+ relative_path = os.path.join("images", f"problem_image_{i+1}.png")
289
+ img.save(problem_image_path)
290
+ problem_image_paths.append(relative_path)
291
+ print(f"Problem image {i+1} saved at '{problem_image_path}'")
292
+ else:
293
+ # Handle single image
294
+ problem_image_path = os.path.join(images_folder, "problem_image.png")
295
+ relative_path = os.path.join("images", "problem_image.png")
296
+ image.save(problem_image_path)
297
+ problem_image_paths.append(relative_path)
298
+ print(f"Problem image saved at '{problem_image_path}'")
299
+
300
+ reasoning_text = []
301
+ reasoning_images = []
302
+ generated_image_paths = [] # Store relative paths to generated reasoning images
303
+
304
+ # Create input with multiple images properly flattened
305
+ if image is not None:
306
+ if isinstance(image, list):
307
+ current_input = [prompt] + image # Flatten the list of images
308
+ else:
309
+ current_input = [prompt, image]
310
+ else:
311
+ current_input = [prompt]
312
+
313
+ # Loop until no more vision_start tokens
314
+ iteration = 0
315
+ while True:
316
+ # Get understanding output
317
+ print(f"iteration: {iteration}")
318
+ output = inferencer.interleave_inference(current_input, understanding_output=True, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
319
+
320
+ # Check for stopping conditions
321
+ has_final_answer = 'Final Answer:' in output[0] or '<answer>' in output[0]
322
+
323
+ # Stop if we have a final answer OR if there's no vision token (no more images to generate)
324
+ # should_stop = has_final_answer or not has_vision_token
325
+ should_stop = has_final_answer
326
+
327
+
328
+ if should_stop:
329
+ if output[0].strip():
330
+ extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
331
+ reasoning_text.append(extracted_text)
332
+ print(f"{extracted_text}")
333
+ current_input = current_input + [extracted_text]
334
+ break
335
+
336
+ extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
337
+ reasoning_text.append(extracted_text)
338
+ print(f"{extracted_text}")
339
+
340
+ # Generate image based on current reasoning
341
+ current_input_with_reasoning = current_input + [extracted_text]
342
+ output = inferencer.interleave_inference(current_input_with_reasoning, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
343
+ image_output = output[0]
344
+
345
+ # Save and collect the generated image
346
+ reasoning_images.append(image_output)
347
+ image_filename = f'reasoning_image_{iteration + 1}.png'
348
+ image_path = os.path.join(images_folder, image_filename)
349
+ relative_image_path = os.path.join("images", image_filename) # Relative path for JSON
350
+
351
+ image_output.save(image_path)
352
+ generated_image_paths.append(relative_image_path)
353
+ print(f"Image saved at '{image_path}'")
354
+
355
+ # Update input for next iteration
356
+ current_input = current_input_with_reasoning + [image_output]
357
+
358
+ iteration += 1
359
+ print('-'*50)
360
+
361
+ # Save reasoning data to JSON
362
+ reasoning_data = {
363
+ "timestamp": timestamp,
364
+ "prompt": prompt,
365
+ "system_prompt": INTERLEAVED_SYSTEM_PROMPT,
366
+ "problem_image_paths": problem_image_paths if problem_image_paths else None,
367
+ "response": [
368
+ {
369
+ "step": i + 1,
370
+ "text": text,
371
+ "image_path": generated_image_paths[i] if i < len(generated_image_paths) else None
372
+ }
373
+ for i, text in enumerate(reasoning_text)
374
+ ],
375
+ "total_steps": len(reasoning_text),
376
+ "total_images": len(generated_image_paths)
377
+ }
378
+
379
+ # Save JSON file
380
+ json_path = os.path.join(output_folder, "reasoning_data.json")
381
+ with open(json_path, 'w', encoding='utf-8') as f:
382
+ json.dump(reasoning_data, f, indent=2, ensure_ascii=False)
383
+
384
+ print(f"\nReasoning complete!")
385
+ print(f"Output folder: {output_folder}")
386
+ print(f"JSON metadata: {json_path}")
387
+ print(f"Generated {len(generated_image_paths)} images and {len(reasoning_text)} text steps")
388
+
389
+ # python infz_bf16.py
390
+
391
+
392
+ # import os
393
+ # import json
394
+ # from datetime import datetime
395
+ # from copy import deepcopy
396
+ # from typing import (
397
+ # Any,
398
+ # AsyncIterable,
399
+ # Callable,
400
+ # Dict,
401
+ # Generator,
402
+ # List,
403
+ # NamedTuple,
404
+ # Optional,
405
+ # Tuple,
406
+ # Union,
407
+ # )
408
+ # import requests
409
+ # from io import BytesIO
410
+
411
+ # from PIL import Image
412
+ # import torch
413
+ # from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
414
+
415
+ # from data.transforms import ImageTransform
416
+ # from data.data_utils import pil_img2rgb, add_special_tokens
417
+ # from modeling.bagel import (
418
+ # BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
419
+ # )
420
+ # from modeling.qwen2 import Qwen2Tokenizer
421
+ # from modeling.bagel.qwen2_navit import NaiveCache
422
+ # from modeling.autoencoder import load_ae
423
+
424
+ # # Set paths for your trained checkpoint
425
+ # checkpoint_dir = "path/to/your/HF_HOME/models/Bagel-Zebra-CoT"
426
+ # checkpoint_file = "model_bf16.safetensors"
427
+ # checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
428
+
429
+
430
+ # print(f"Available GPUs: {torch.cuda.device_count()}")
431
+ # print(f"GPU memory per device:")
432
+ # for i in range(torch.cuda.device_count()):
433
+ # props = torch.cuda.get_device_properties(i)
434
+ # print(f" GPU {i}: {props.name}, {props.total_memory / 1e9:.1f} GB")
435
+
436
+ # # LLM config preparing (use base model configs)
437
+ # llm_config = Qwen2Config.from_json_file(os.path.join(checkpoint_dir, "llm_config.json"))
438
+ # llm_config.qk_norm = True
439
+ # llm_config.tie_word_embeddings = False
440
+ # llm_config.layer_module = "Qwen2MoTDecoderLayer"
441
+
442
+ # # ViT config preparing (use base model configs)
443
+ # vit_config = SiglipVisionConfig.from_json_file(os.path.join(checkpoint_dir, "vit_config.json"))
444
+ # vit_config.rope = False
445
+ # vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
446
+
447
+ # # VAE loading (use base model VAE)
448
+ # vae_model, vae_config = load_ae(local_path=os.path.join(checkpoint_dir, "ae.safetensors"))
449
+
450
+ # # Bagel config preparing
451
+ # config = BagelConfig(
452
+ # visual_gen=True,
453
+ # visual_und=True,
454
+ # llm_config=llm_config,
455
+ # vit_config=vit_config,
456
+ # vae_config=vae_config,
457
+ # vit_max_num_patch_per_side=70,
458
+ # connector_act='gelu_pytorch_tanh',
459
+ # latent_patch_size=2,
460
+ # max_latent_size=64,
461
+ # )
462
+
463
+ # # Create model with empty weights - IMPORTANT: Use float32 initially to match checkpoint
464
+ # with init_empty_weights():
465
+ # language_model = Qwen2ForCausalLM(llm_config)
466
+ # vit_model = SiglipVisionModel(vit_config)
467
+ # model = Bagel(language_model, vit_model, config)
468
+ # model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
469
+
470
+ # # Tokenizer Preparing (use base model tokenizer)
471
+ # tokenizer = Qwen2Tokenizer.from_pretrained(checkpoint_dir)
472
+ # tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
473
+
474
+ # # Image Transform Preparing
475
+ # vae_transform = ImageTransform(1024, 512, 16)
476
+ # vit_transform = ImageTransform(980, 512, 14)
477
+
478
+ # # Device mapping for 8x80GB GPUs - use bf16 directly
479
+ # max_mem_per_gpu = "80GiB"
480
+
481
+ # print("Setting up device mapping...")
482
+ # device_map = infer_auto_device_map(
483
+ # model,
484
+ # max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
485
+ # no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
486
+ # dtype=torch.bfloat16, # Use bf16 for device mapping
487
+ # )
488
+
489
+ # print("Device map:", device_map)
490
+
491
+ # # Handle same-device modules
492
+ # same_device_modules = [
493
+ # 'language_model.model.embed_tokens',
494
+ # 'time_embedder',
495
+ # 'latent_pos_embed',
496
+ # 'vae2llm',
497
+ # 'llm2vae',
498
+ # 'connector',
499
+ # 'vit_pos_embed'
500
+ # ]
501
+
502
+ # if torch.cuda.device_count() == 1:
503
+ # first_device = device_map.get(same_device_modules[0], "cuda:0")
504
+ # for k in same_device_modules:
505
+ # if k in device_map:
506
+ # device_map[k] = first_device
507
+ # else:
508
+ # device_map[k] = "cuda:0"
509
+ # else:
510
+ # first_device = device_map.get(same_device_modules[0])
511
+ # if first_device is not None:
512
+ # for k in same_device_modules:
513
+ # if k in device_map:
514
+ # device_map[k] = first_device
515
+
516
+ # print("Final device map:", device_map)
517
+
518
+ # # Load checkpoint directly in bf16
519
+ # print(f"Loading checkpoint directly in bfloat16: {checkpoint_path}")
520
+ # print("Loading model from safetensors file...")
521
+
522
+ # # Load model directly in bf16
523
+ # model = load_checkpoint_and_dispatch(
524
+ # model,
525
+ # checkpoint=checkpoint_path,
526
+ # device_map=device_map,
527
+ # offload_buffers=False,
528
+ # dtype=torch.bfloat16, # Load directly as bf16
529
+ # force_hooks=True,
530
+ # )
531
+
532
+ # model = model.eval()
533
+
534
+ # print('Model loaded directly in bfloat16!')
535
+ # print(f"Model dtype: {next(model.parameters()).dtype}")
536
+ # print("Model loading completed successfully!")
537
+
538
+ # # Check memory usage
539
+ # print("GPU memory usage after loading:")
540
+ # for i in range(torch.cuda.device_count()):
541
+ # if torch.cuda.memory_allocated(i) > 0:
542
+ # allocated = torch.cuda.memory_allocated(i) / 1e9
543
+ # cached = torch.cuda.memory_reserved(i) / 1e9
544
+ # print(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
545
+
546
+ # # Rest of inference code
547
+ # from inferencer import InterleaveInferencer
548
+
549
+ # inferencer = InterleaveInferencer(
550
+ # model=model,
551
+ # vae_model=vae_model,
552
+ # tokenizer=tokenizer,
553
+ # vae_transform=vae_transform,
554
+ # vit_transform=vit_transform,
555
+ # new_token_ids=new_token_ids
556
+ # )
557
+
558
+ # import random
559
+ # import numpy as np
560
+
561
+ # seed = 42
562
+ # random.seed(seed)
563
+ # np.random.seed(seed)
564
+ # torch.manual_seed(seed)
565
+ # if torch.cuda.is_available():
566
+ # torch.cuda.manual_seed(seed)
567
+ # torch.cuda.manual_seed_all(seed)
568
+ # torch.backends.cudnn.deterministic = True
569
+ # torch.backends.cudnn.benchmark = False
570
+
571
+ # inference_hyper=dict(
572
+ # do_sample=True,
573
+ # text_temperature=0.3,
574
+ # cfg_text_scale=4.0,
575
+ # cfg_img_scale=2.0,
576
+ # cfg_interval=[0.0, 1.0],
577
+ # timestep_shift=3.0,
578
+ # num_timesteps=50,
579
+ # cfg_renorm_min=0.0,
580
+ # cfg_renorm_type="text_channel",
581
+ # )
582
+
583
+ # INTERLEAVED_SYSTEM_PROMPT = '''You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and use visual aids to enhance your problem-solving. Provide your final conclusion clearly in the format of "Final Answer: <answer here>"'''
584
+
585
+ # prompt = '''Subtract all cylinders. Add 1 red sphere. How many objects are left?'''
586
+ # image = Image.open('test_images/image.png')
587
+
588
+ # print(prompt)
589
+ # print('-'*50)
590
+
591
+ # # Create output folder with timestamp
592
+ # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
593
+ # output_folder = f"reasoning_output_{timestamp}"
594
+ # images_folder = os.path.join(output_folder, "images")
595
+ # os.makedirs(images_folder, exist_ok=True)
596
+
597
+ # # Save the original problem images if they exist
598
+ # problem_image_paths = []
599
+ # if image is not None:
600
+ # if isinstance(image, list):
601
+ # # Handle multiple images
602
+ # for i, img in enumerate(image):
603
+ # problem_image_path = os.path.join(images_folder, f"problem_image_{i+1}.png")
604
+ # relative_path = os.path.join("images", f"problem_image_{i+1}.png")
605
+ # img.save(problem_image_path)
606
+ # problem_image_paths.append(relative_path)
607
+ # print(f"Problem image {i+1} saved at '{problem_image_path}'")
608
+ # else:
609
+ # # Handle single image
610
+ # problem_image_path = os.path.join(images_folder, "problem_image.png")
611
+ # relative_path = os.path.join("images", "problem_image.png")
612
+ # image.save(problem_image_path)
613
+ # problem_image_paths.append(relative_path)
614
+ # print(f"Problem image saved at '{problem_image_path}'")
615
+
616
+ # reasoning_text = []
617
+ # reasoning_images = []
618
+ # image_paths = [] # Store relative paths to images
619
+
620
+ # # Create input with multiple images properly flattened
621
+ # if image is not None:
622
+ # if isinstance(image, list):
623
+ # current_input = [prompt] + image # Flatten the list of images
624
+ # else:
625
+ # current_input = [prompt, image]
626
+ # else:
627
+ # current_input = [prompt]
628
+
629
+ # # Loop until no more vision_start tokens
630
+ # iteration = 0
631
+ # while True:
632
+ # # Get understanding output
633
+ # print(f"iteration: {iteration}")
634
+ # output = inferencer.interleave_inference(current_input, understanding_output=True, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
635
+
636
+ # # Check for stopping conditions
637
+ # has_final_answer = 'Final Answer:' in output[0] or '<answer>' in output[0]
638
+
639
+ # # Stop if we have a final answer OR if there's no vision token (no more images to generate)
640
+ # # should_stop = has_final_answer or not has_vision_token
641
+ # should_stop = has_final_answer
642
+
643
+
644
+ # if should_stop:
645
+ # if output[0].strip():
646
+ # extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
647
+ # reasoning_text.append(extracted_text)
648
+ # print(f"{extracted_text}")
649
+ # current_input = current_input + [extracted_text]
650
+ # break
651
+
652
+ # extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
653
+ # reasoning_text.append(extracted_text)
654
+ # print(f"{extracted_text}")
655
+
656
+ # # Generate image based on current reasoning
657
+ # current_input_with_reasoning = current_input + [extracted_text]
658
+ # output = inferencer.interleave_inference(current_input_with_reasoning, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
659
+ # image_output = output[0]
660
+
661
+ # # Save and collect the generated image
662
+ # reasoning_images.append(image_output)
663
+ # image_filename = f'reasoning_image_{iteration + 1}.png'
664
+ # image_path = os.path.join(images_folder, image_filename)
665
+ # relative_image_path = os.path.join("images", image_filename) # Relative path for JSON
666
+
667
+ # image_output.save(image_path)
668
+ # image_paths.append(relative_image_path)
669
+ # print(f"Image saved at '{image_path}'")
670
+
671
+ # # Update input for next iteration
672
+ # current_input = current_input_with_reasoning + [image_output]
673
+
674
+ # iteration += 1
675
+ # print('-'*50)
676
+
677
+ # # Save reasoning data to JSON
678
+ # reasoning_data = {
679
+ # "timestamp": timestamp,
680
+ # "prompt": prompt,
681
+ # "system_prompt": INTERLEAVED_SYSTEM_PROMPT,
682
+ # "problem_image_paths": problem_image_paths if problem_image_paths else None,
683
+ # "response": [
684
+ # {
685
+ # "step": i + 1,
686
+ # "text": text,
687
+ # "image_path": image_paths[i] if i < len(image_paths) else None
688
+ # }
689
+ # for i, text in enumerate(reasoning_text)
690
+ # ],
691
+ # "total_steps": len(reasoning_text),
692
+ # "total_images": len(image_paths)
693
+ # }
694
+
695
+ # # Save JSON file
696
+ # json_path = os.path.join(output_folder, "reasoning_data.json")
697
+ # with open(json_path, 'w', encoding='utf-8') as f:
698
+ # json.dump(reasoning_data, f, indent=2, ensure_ascii=False)
699
+
700
+ # print(f"\nReasoning complete!")
701
+ # print(f"Output folder: {output_folder}")
702
+ # print(f"JSON metadata: {json_path}")
703
+ # print(f"Generated {len(image_paths)} images and {len(reasoning_text)} text steps")
704
+
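For reference, the `reasoning_data.json` file that `infz_bf16.py` writes has the following shape. The field names come from the script; the values below are placeholders only.

```python
# Illustrative layout of reasoning_data.json (placeholder values).
example_reasoning_data = {
    "timestamp": "20250101_120000",
    "prompt": "...",                                   # the interleaved problem prompt
    "system_prompt": "",                               # INTERLEAVED_SYSTEM_PROMPT at run time
    "problem_image_paths": ["images/problem_image_1.png"],
    "response": [
        {"step": 1, "text": "first reasoning step", "image_path": "images/reasoning_image_1.png"},
        {"step": 2, "text": "Final Answer: ...", "image_path": None},  # final step carries the answer, no image
    ],
    "total_steps": 2,
    "total_images": 1,
}
```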
modeling/bagel/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from .bagel import BagelConfig, Bagel
6
+ from .qwen2_navit import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
7
+ from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
8
+
9
+
10
+ __all__ = [
11
+ 'BagelConfig',
12
+ 'Bagel',
13
+ 'Qwen2Config',
14
+ 'Qwen2Model',
15
+ 'Qwen2ForCausalLM',
16
+ 'SiglipVisionConfig',
17
+ 'SiglipVisionModel',
18
+ ]
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ decord==0.6.0
2
+ einops==0.8.1
3
+ huggingface_hub==0.29.1
4
+ matplotlib==3.7.0
5
+ numpy==1.24.4
6
+ opencv-python-headless
7
+ pyarrow==11.0.0
8
+ PyYAML==6.0.2
9
+ Requests==2.32.3
10
+ safetensors==0.4.5
11
+ scipy==1.10.1
12
+ sentencepiece==0.1.99
13
+ torch==2.5.1
14
+ torchvision==0.20.1
15
+ transformers==4.49.0
16
+ accelerate>=0.34.0
17
+ wandb
18
+ gradio
19
+ setuptools
20
+ wheel
21
+ ninja
22
+ bitsandbytes
23
+ xlsxwriter
24
+ triton ; sys_platform != 'win32'
25
+ triton-windows ; sys_platform == 'win32'
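Most of the pins above are exact versions. A small sanity check like the one below (an assumed helper, not part of the repository) can confirm that the key pinned packages are installed at the expected versions.

```python
# Assumed convenience check, not part of the repo: compare a few requirements.txt
# pins against the installed environment using the standard library.
from importlib.metadata import PackageNotFoundError, version

pins = {
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "transformers": "4.49.0",
    "safetensors": "0.4.5",
    "numpy": "1.24.4",
}

for pkg, expected in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (expected {expected})")
        continue
    note = "ok" if installed == expected else f"mismatch (expected {expected})"
    print(f"{pkg}: {installed} -> {note}")
```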