Add files using upload-large-folder tool
Browse files- .gitignore +16 -0
- EVAL.md +387 -0
- LICENSE +201 -0
- README.md +139 -0
- TRAIN.md +168 -0
- app.py +613 -0
- bug.log +13 -0
- data/configs/example_smm_semantic.yaml +50 -0
- data/data_utils.py +177 -0
- data/interleave_datasets/__init__.py +6 -0
- data/parquet_utils.py +89 -0
- data/t2i_dataset.py +128 -0
- data/transforms.py +287 -0
- data/video_utils.py +165 -0
- data/vlm_dataset.py +195 -0
- download_model.py +12 -0
- inference.ipynb +535 -0
- inferencer.py +300 -0
- infz_bf16.py +704 -0
- modeling/bagel/__init__.py +18 -0
- requirements.txt +25 -0
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb
|
| 2 |
+
__pycache__
|
| 3 |
+
.vscode
|
| 4 |
+
notebooks
|
| 5 |
+
results
|
| 6 |
+
*.ipynb_checkpoints
|
| 7 |
+
eval_results
|
| 8 |
+
tests
|
| 9 |
+
.DS_Store
|
| 10 |
+
gradio.sh
|
| 11 |
+
models
|
| 12 |
+
bagel_example
|
| 13 |
+
Zebra-CoT
|
| 14 |
+
model_bf16.safetensors
|
| 15 |
+
zebra-cot.tar.gz
|
| 16 |
+
reasoning_output*
|
EVAL.md
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VLM
|
| 2 |
+
We follow [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html) to evaluate the performance on MME, MMBench, MMMU, MMVet, MathVista and MMVP.
|
| 3 |
+
|
| 4 |
+
## Data prepration
|
| 5 |
+
Please follow the [InternVL2](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) to prepare the corresponding data. And the link the data under `vlm`.
|
| 6 |
+
|
| 7 |
+
The final directory structure is:
|
| 8 |
+
```shell
|
| 9 |
+
data
|
| 10 |
+
├── MathVista
|
| 11 |
+
├── mmbench
|
| 12 |
+
├── mme
|
| 13 |
+
├── MMMU
|
| 14 |
+
├── mm-vet
|
| 15 |
+
└── MMVP
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
## Evaluation
|
| 19 |
+
|
| 20 |
+
Directly run `scripts/eval/run_eval_vlm.sh` to evaluate different benchmarks. The output will be saved in `$output_path`.
|
| 21 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 22 |
+
- Increase `GPUS` if you want to run faster.
|
| 23 |
+
- For MMBench, please use the official [evaluation server](https://mmbench.opencompass.org.cn/mmbench-submission).
|
| 24 |
+
- For MMVet, please use the official [evaluation server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator).
|
| 25 |
+
- For MathVista, please set `$openai_api_key` in `scripts/eval/run_eval_vlm.sh` and `your_api_url` in `eval/vlm/eval/mathvista/utilities.py`. The default GPT version is `gpt-4o-2024-11-20`.
|
| 26 |
+
- For MMMU, we use CoT in the report, which improve the accuracy by about 2%. For evaluation of the oprn-ended answer, we use GPT-4o for judgement.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# GenEval
|
| 30 |
+
We modify the code in [GenEval](https://github.com/djghosh13/geneval/tree/main) for faster evaluation.
|
| 31 |
+
|
| 32 |
+
## Setup
|
| 33 |
+
Install the following dependencies:
|
| 34 |
+
```shell
|
| 35 |
+
pip install open-clip-torch
|
| 36 |
+
pip install clip-benchmark
|
| 37 |
+
pip install --upgrade setuptools
|
| 38 |
+
|
| 39 |
+
sudo pip install -U openmim
|
| 40 |
+
sudo mim install mmengine mmcv-full==1.7.2
|
| 41 |
+
|
| 42 |
+
git clone https://github.com/open-mmlab/mmdetection.git
|
| 43 |
+
cd mmdetection; git checkout 2.x
|
| 44 |
+
pip install -v -e .
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Download Detector:
|
| 48 |
+
```shell
|
| 49 |
+
cd ./eval/gen/geneval
|
| 50 |
+
mkdir model
|
| 51 |
+
|
| 52 |
+
bash ./evaluation/download_models.sh ./model
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Evaluation
|
| 56 |
+
Directly run `scripts/eval/run_geneval.sh` to evaluate GenEVAL. The output will be saved in `$output_path`.
|
| 57 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 58 |
+
- Set `metadata_file` to `./eval/gen/geneval/prompts/evaluation_metadata.jsonl` for original GenEval prompts.
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# WISE
|
| 62 |
+
We modify the code in [WISE](https://github.com/PKU-YuanGroup/WISE/tree/main) for faster evaluation.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## Evaluation
|
| 66 |
+
Directly run `scripts/eval/run_wise.sh` to evaluate WISE. The output will be saved in `$output_path`.
|
| 67 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 68 |
+
- Set `$openai_api_key` in `scripts/eval/run_wise.sh` and `your_api_url` in `eval/gen/wise/gpt_eval_mp.py`. The default GPT version is `gpt-4o-2024-11-20`.
|
| 69 |
+
- Use `think` for thinking mode.
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# GEdit-Bench
|
| 74 |
+
We adopt the code in [GEdit-Bench](https://github.com/stepfun-ai/Step1X-Edit/blob/main/GEdit-Bench/EVAL.md) for evaluation.
|
| 75 |
+
|
| 76 |
+
## Evaluation
|
| 77 |
+
|
| 78 |
+
Modify the model path, the output path, the api key, and the api url in `scripts/eval/run_gedit.sh`. Then, run the following command:
|
| 79 |
+
```shell
|
| 80 |
+
bash script/eval/run_gedit.sh
|
| 81 |
+
```
|
| 82 |
+
The GPT version for evaluation is `gpt-4.1-2025-04-14`.
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# IntelligentBench
|
| 86 |
+
TBD
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# KRIS
|
| 90 |
+
We modify the code in [KRIS-Bench](https://github.com/mercurystraw/Kris_Bench) for faster evaluation.
|
| 91 |
+
|
| 92 |
+
## Data prepration
|
| 93 |
+
Please download the benchmark data from [KRIS-Bench](https://huggingface.co/datasets/Liang0223/KRIS_Bench) and and place it in the `KRIS_Bench` directory.
|
| 94 |
+
|
| 95 |
+
The final directory structure is:
|
| 96 |
+
```shell
|
| 97 |
+
KRIS_Bench
|
| 98 |
+
├── abstract_reasoning
|
| 99 |
+
├── anomaly_correction
|
| 100 |
+
├── biology
|
| 101 |
+
├── chemistry
|
| 102 |
+
├── color_change
|
| 103 |
+
├── count_change
|
| 104 |
+
├── geography
|
| 105 |
+
├── humanities
|
| 106 |
+
├── mathematics
|
| 107 |
+
├── medicine
|
| 108 |
+
├── multi-element_composition
|
| 109 |
+
├── multi-instruction_execution
|
| 110 |
+
├── part_completion
|
| 111 |
+
├── physics
|
| 112 |
+
├── position_movement
|
| 113 |
+
├── practical_knowledge
|
| 114 |
+
├── rule-based_reasoning
|
| 115 |
+
├── size_adjustment
|
| 116 |
+
├── temporal_prediction
|
| 117 |
+
└── viewpoint_change
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## Evaluation
|
| 121 |
+
Directly run `scripts/eval/run_kris.sh` to evaluate KRIS-Bench. The output will be saved in `$output_path`.
|
| 122 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 123 |
+
- Set `$openai_api_key` in `scripts/eval/run_kris.sh` and `your_api_url` in `eval/gen/kris/metrics_xx.py`. The default GPT version is `gpt-4o-2024-11-20`.
|
| 124 |
+
- Use `think` for thinking mode.
|
| 125 |
+
- We set `cfg_text_scale=4` and `cfg_img_scale=1.5` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
|
| 126 |
+
|
| 127 |
+
<details>
|
| 128 |
+
<summary><b>Results</b></summary>
|
| 129 |
+
<pre>
|
| 130 |
+
Category, meta-category, and overall average scores (100-point scale):
|
| 131 |
+
Attribute Perception:
|
| 132 |
+
VC: 76.64
|
| 133 |
+
VQ: 74.45
|
| 134 |
+
IF: 41.73
|
| 135 |
+
AVG: 64.27
|
| 136 |
+
Spatial Perception:
|
| 137 |
+
VC: 70.25
|
| 138 |
+
VQ: 80.00
|
| 139 |
+
IF: 37.00
|
| 140 |
+
AVG: 62.42
|
| 141 |
+
Temporal Prediction:
|
| 142 |
+
VC: 36.49
|
| 143 |
+
VQ: 61.82
|
| 144 |
+
IF: 29.05
|
| 145 |
+
AVG: 42.45
|
| 146 |
+
Social Science:
|
| 147 |
+
VC: 76.20
|
| 148 |
+
VQ: 78.80
|
| 149 |
+
IF: 37.00
|
| 150 |
+
KP: 29.60
|
| 151 |
+
AVG: 55.40
|
| 152 |
+
Natural Science:
|
| 153 |
+
VC: 69.59
|
| 154 |
+
VQ: 84.03
|
| 155 |
+
IF: 40.27
|
| 156 |
+
KP: 30.15
|
| 157 |
+
AVG: 56.01
|
| 158 |
+
Logical Reasoning:
|
| 159 |
+
VC: 80.17
|
| 160 |
+
VQ: 85.67
|
| 161 |
+
IF: 26.33
|
| 162 |
+
KP: 18.00
|
| 163 |
+
AVG: 52.54
|
| 164 |
+
Instruction Decomposition:
|
| 165 |
+
VC: 40.17
|
| 166 |
+
VQ: 69.50
|
| 167 |
+
IF: 42.00
|
| 168 |
+
AVG: 50.56
|
| 169 |
+
Factual Knowledge:
|
| 170 |
+
AVG: 60.26
|
| 171 |
+
Conceptual Knowledge:
|
| 172 |
+
AVG: 55.86
|
| 173 |
+
Procedural Knowledge:
|
| 174 |
+
AVG: 51.69
|
| 175 |
+
Overall:
|
| 176 |
+
AVG: 56.21
|
| 177 |
+
</pre>
|
| 178 |
+
</details>
|
| 179 |
+
|
| 180 |
+
<details>
|
| 181 |
+
<summary><b>Results w/ CoT</b></summary>
|
| 182 |
+
<pre>
|
| 183 |
+
Category, meta-category, and overall average scores (100-point scale):
|
| 184 |
+
Attribute Perception:
|
| 185 |
+
VC: 75.09
|
| 186 |
+
VQ: 74.00
|
| 187 |
+
IF: 53.18
|
| 188 |
+
AVG: 67.42
|
| 189 |
+
Spatial Perception:
|
| 190 |
+
VC: 78.75
|
| 191 |
+
VQ: 87.25
|
| 192 |
+
IF: 39.00
|
| 193 |
+
AVG: 68.33
|
| 194 |
+
Temporal Prediction:
|
| 195 |
+
VC: 48.31
|
| 196 |
+
VQ: 81.08
|
| 197 |
+
IF: 46.62
|
| 198 |
+
AVG: 58.67
|
| 199 |
+
Social Science:
|
| 200 |
+
VC: 80.40
|
| 201 |
+
VQ: 79.40
|
| 202 |
+
IF: 51.60
|
| 203 |
+
KP: 42.80
|
| 204 |
+
AVG: 63.55
|
| 205 |
+
Natural Science:
|
| 206 |
+
VC: 67.68
|
| 207 |
+
VQ: 82.95
|
| 208 |
+
IF: 52.10
|
| 209 |
+
KP: 42.88
|
| 210 |
+
AVG: 61.40
|
| 211 |
+
Logical Reasoning:
|
| 212 |
+
VC: 62.83
|
| 213 |
+
VQ: 79.67
|
| 214 |
+
IF: 28.33
|
| 215 |
+
KP: 21.67
|
| 216 |
+
AVG: 48.12
|
| 217 |
+
Instruction Decomposition:
|
| 218 |
+
VC: 47.83
|
| 219 |
+
VQ: 66.83
|
| 220 |
+
IF: 36.00
|
| 221 |
+
AVG: 50.22
|
| 222 |
+
Factual Knowledge:
|
| 223 |
+
AVG: 66.18
|
| 224 |
+
Conceptual Knowledge:
|
| 225 |
+
AVG: 61.92
|
| 226 |
+
Procedural Knowledge:
|
| 227 |
+
AVG: 49.02
|
| 228 |
+
Overall:
|
| 229 |
+
AVG: 60.18
|
| 230 |
+
</pre>
|
| 231 |
+
</details>
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# RISE
|
| 235 |
+
We modify the code in [RISEBench](https://github.com/PhoenixZ810/RISEBench) for faster evaluation.
|
| 236 |
+
|
| 237 |
+
## Data prepration
|
| 238 |
+
Please download the benchmark data from [RISEBench](https://huggingface.co/datasets/PhoenixZ/RISEBench) and and place it in the `data` directory.
|
| 239 |
+
|
| 240 |
+
The final directory structure is:
|
| 241 |
+
```shell
|
| 242 |
+
data
|
| 243 |
+
├── datav2_total_w_subtask.json
|
| 244 |
+
├── causal_reasoning_images
|
| 245 |
+
├── logical_reasoning_images
|
| 246 |
+
├── spatial_reasoning_images
|
| 247 |
+
└── temporal_reasoning_images
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## Evaluation
|
| 251 |
+
Directly run `scripts/eval/run_rise.sh` to evaluate RISEBench. The output will be saved in `$output_path`.
|
| 252 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 253 |
+
- Set `$openai_api_key` in `scripts/eval/run_rise.sh` and `your_api_url` in `eval/gen/rise/gpt_eval.py`. The default GPT version is `gpt-4.1-2025-04-14`.
|
| 254 |
+
- Use `think` for thinking mode.
|
| 255 |
+
- We set `cfg_text_scale=4` and `cfg_img_scale=2.0` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
|
| 256 |
+
|
| 257 |
+
<details>
|
| 258 |
+
<summary><b>Results (cfg_img_scale=1.5)</b></summary>
|
| 259 |
+
<pre>
|
| 260 |
+
- Score-Origin Score-Percentage Accuracy
|
| 261 |
+
0 Overall 2.537778 38.444444 0.061111
|
| 262 |
+
1 Temporal 2.654118 41.352941 0.023529
|
| 263 |
+
2 Causal 2.788889 44.722222 0.055556
|
| 264 |
+
3 Spatial 3.452000 61.300000 0.140000
|
| 265 |
+
4 Logical 1.080000 2.000000 0.011765
|
| 266 |
+
5 Overall_Reasoning 2.458333 36.458333 NaN
|
| 267 |
+
6 Overall_ApprConsistency 3.141643 53.541076 NaN
|
| 268 |
+
7 Overall_VisualPlausibility_total 3.920000 73.000000 NaN
|
| 269 |
+
8 Temporal_Reasoning 2.588235 39.705882 NaN
|
| 270 |
+
9 Temporal_Consistency 3.250000 56.250000 NaN
|
| 271 |
+
10 Temporal_Quality 3.505882 62.647059 NaN
|
| 272 |
+
11 Causal_Reasoning 2.733333 43.333333 NaN
|
| 273 |
+
12 Causal_Consistency 3.579545 64.488636 NaN
|
| 274 |
+
13 Causal_Quality 3.688889 67.222222 NaN
|
| 275 |
+
14 Spatial_Reasoning 3.300000 57.500000 NaN
|
| 276 |
+
15 Spatial_Consistency 3.330000 58.250000 NaN
|
| 277 |
+
16 Spatial_Quality 4.480000 87.000000 NaN
|
| 278 |
+
17 Logical_Reasoning 1.047059 1.176471 NaN
|
| 279 |
+
18 Logical_Consistency 2.364706 34.117647 NaN
|
| 280 |
+
19 Temp-Life Progression 2.757895 43.947368 0.000000
|
| 281 |
+
20 Temp-Material Progression 2.500000 37.500000 0.021739
|
| 282 |
+
21 Temp-Environmental Cycles 3.061538 51.538462 0.076923
|
| 283 |
+
22 Temp-Societal Transformation 2.628571 40.714286 0.000000
|
| 284 |
+
23 Causal-Structural Deformation 2.766667 44.166667 0.055556
|
| 285 |
+
24 Causal-State Transition 3.112000 52.800000 0.080000
|
| 286 |
+
25 Causal-Chemical and Biological Transformation 2.325000 33.125000 0.062500
|
| 287 |
+
26 Causal-Physics Manifestation 2.800000 45.000000 0.000000
|
| 288 |
+
27 Spa-Component Assembly 3.434783 60.869565 0.043478
|
| 289 |
+
28 Spa-Object Arrangement 2.733333 43.333333 0.000000
|
| 290 |
+
29 Spa-Viewpoint Generation 3.629630 65.740741 0.222222
|
| 291 |
+
30 Spa-Structural Inference 4.066667 76.666667 0.133333
|
| 292 |
+
31 Spa-Layout Reasoning 3.234783 55.869565 0.217391
|
| 293 |
+
32 Logic-Pattern Prediction 1.035484 0.887097 0.000000
|
| 294 |
+
33 Logic-Mathematical Derivation 1.350000 8.750000 0.071429
|
| 295 |
+
34 Logic-Puzzle Solving 1.020000 0.500000 0.000000
|
| 296 |
+
</pre>
|
| 297 |
+
</details>
|
| 298 |
+
|
| 299 |
+
<details>
|
| 300 |
+
<summary><b>Results w/ CoT</b></summary>
|
| 301 |
+
<pre>
|
| 302 |
+
- Score-Origin Score-Percentage Accuracy
|
| 303 |
+
0 Overall 2.933333 48.333333 0.119444
|
| 304 |
+
1 Temporal 3.336471 58.411765 0.058824
|
| 305 |
+
2 Causal 3.608889 65.222222 0.177778
|
| 306 |
+
3 Spatial 3.492000 62.300000 0.210000
|
| 307 |
+
4 Logical 1.157647 3.941176 0.011765
|
| 308 |
+
5 Overall_Reasoning 2.836111 45.902778 NaN
|
| 309 |
+
6 Overall_ApprConsistency 3.951841 73.796034 NaN
|
| 310 |
+
7 Overall_VisualPlausibility_total 4.203636 80.090909 NaN
|
| 311 |
+
8 Temporal_Reasoning 3.188235 54.705882 NaN
|
| 312 |
+
9 Temporal_Consistency 4.225000 80.625000 NaN
|
| 313 |
+
10 Temporal_Quality 4.200000 80.000000 NaN
|
| 314 |
+
11 Causal_Reasoning 3.533333 63.333333 NaN
|
| 315 |
+
12 Causal_Consistency 4.386364 84.659091 NaN
|
| 316 |
+
13 Causal_Quality 4.100000 77.500000 NaN
|
| 317 |
+
14 Spatial_Reasoning 3.350000 58.750000 NaN
|
| 318 |
+
15 Spatial_Consistency 4.300000 82.500000 NaN
|
| 319 |
+
16 Spatial_Quality 4.300000 82.500000 NaN
|
| 320 |
+
17 Logical_Reasoning 1.141176 3.529412 NaN
|
| 321 |
+
18 Logical_Consistency 2.835294 45.882353 NaN
|
| 322 |
+
19 Temp-Life Progression 3.526316 63.157895 0.052632
|
| 323 |
+
20 Temp-Material Progression 3.208696 55.217391 0.086957
|
| 324 |
+
21 Temp-Environmental Cycles 3.584615 64.615385 0.000000
|
| 325 |
+
22 Temp-Societal Transformation 3.200000 55.000000 0.000000
|
| 326 |
+
23 Causal-Structural Deformation 3.750000 68.750000 0.138889
|
| 327 |
+
24 Causal-State Transition 3.792000 69.800000 0.320000
|
| 328 |
+
25 Causal-Chemical and Biological Transformation 3.512500 62.812500 0.062500
|
| 329 |
+
26 Causal-Physics Manifestation 2.984615 49.615385 0.153846
|
| 330 |
+
27 Spa-Component Assembly 3.652174 66.304348 0.304348
|
| 331 |
+
28 Spa-Object Arrangement 2.700000 42.500000 0.000000
|
| 332 |
+
29 Spa-Viewpoint Generation 3.800000 70.000000 0.259259
|
| 333 |
+
30 Spa-Structural Inference 3.680000 67.000000 0.266667
|
| 334 |
+
31 Spa-Layout Reasoning 3.260870 56.521739 0.130435
|
| 335 |
+
32 Logic-Pattern Prediction 1.064516 1.612903 0.000000
|
| 336 |
+
33 Logic-Mathematical Derivation 1.707143 17.678571 0.071429
|
| 337 |
+
34 Logic-Puzzle Solving 1.037500 0.937500 0.000000
|
| 338 |
+
</pre>
|
| 339 |
+
</details>
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ImgEdit
|
| 343 |
+
We modify the code in [ImgEdit](https://github.com/PKU-YuanGroup/ImgEdit) for faster evaluation.
|
| 344 |
+
|
| 345 |
+
## Data prepration
|
| 346 |
+
Please download the benchmark data from [ImgEdit-Bench](https://huggingface.co/datasets/sysuyy/ImgEdit/blob/main/Benchmark.tar) and and place it in the `Benchmark` directory.
|
| 347 |
+
|
| 348 |
+
The final directory structure is:
|
| 349 |
+
```shell
|
| 350 |
+
Benchmark
|
| 351 |
+
├── hard
|
| 352 |
+
├── multiturn
|
| 353 |
+
└── singleturn
|
| 354 |
+
├── judge_prompt.json
|
| 355 |
+
├── singleturn.json
|
| 356 |
+
├── animal
|
| 357 |
+
├── architecture
|
| 358 |
+
├── clothes
|
| 359 |
+
├── compose
|
| 360 |
+
├── daily object
|
| 361 |
+
├── for_add
|
| 362 |
+
├── human
|
| 363 |
+
├── style
|
| 364 |
+
└── transport
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
## Evaluation
|
| 368 |
+
Directly run `scripts/eval/run_imgedit.sh` to evaluate ImgEdit-Bench. The output will be saved in `$output_path`.
|
| 369 |
+
- Set `$model_path` and `$output_path` for the path for checkpoint and log.
|
| 370 |
+
- Set `$openai_api_key` in `scripts/eval/run_imgedit.sh` and `your_api_url` in `eval/gen/imgedit/basic_bench.py`. The default GPT version is `gpt-4o-2024-11-20`.
|
| 371 |
+
- We set `cfg_text_scale=4` and `cfg_img_scale=1.5` by default. Additionally, `cfg_renorm_min=0` is specified for CFG Renorm.
|
| 372 |
+
|
| 373 |
+
<details>
|
| 374 |
+
<summary><b>Results</b></summary>
|
| 375 |
+
<pre>
|
| 376 |
+
background: 3.28
|
| 377 |
+
adjust: 3.23
|
| 378 |
+
style: 4.26
|
| 379 |
+
extract: 1.48
|
| 380 |
+
remove: 2.99
|
| 381 |
+
add: 3.45
|
| 382 |
+
replace: 3.76
|
| 383 |
+
compose: 3.18
|
| 384 |
+
action: 4.38
|
| 385 |
+
overall: 3.28
|
| 386 |
+
</pre>
|
| 387 |
+
</details>
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Zebra-CoT: A Dataset for Interleaved Vision-Language Reasoning
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
### BAGEL Training Zebra-CoT
|
| 5 |
+
|
| 6 |
+
This repository is adapted from the [Bagel](https://github.com/ByteDance-Seed/Bagel) repository.
|
| 7 |
+
### Setup
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
git clone https://github.com/multimodal-reasoning-lab/Bagel-Zebra-CoT.git
|
| 11 |
+
cd Bagel-Zebra-CoT
|
| 12 |
+
conda create -n bagel python=3.10 -y
|
| 13 |
+
conda activate bagel
|
| 14 |
+
pip install -r requirements.txt
|
| 15 |
+
pip install flash_attn --no-build-isolation
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
### Download checkpoint
|
| 19 |
+
|
| 20 |
+
Set the `HF_HOME` in `download_model.py` to the path of the checkpoint you want to download.
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
python download_model.py
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
You can also do this straight from python if your `HF_HOME` has already been set.
|
| 27 |
+
```python
|
| 28 |
+
from huggingface_hub import snapshot_download
|
| 29 |
+
|
| 30 |
+
snapshot_download(
|
| 31 |
+
repo_id="multimodal-reasoning-lab/Bagel-Zebra-CoT",
|
| 32 |
+
local_dir_use_symlinks=False,
|
| 33 |
+
resume_download=True,
|
| 34 |
+
allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
|
| 35 |
+
)
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### Inference
|
| 39 |
+
|
| 40 |
+

|
| 41 |
+
|
| 42 |
+
The inference script (`infz_bf16.py`) supports inherent interleaved text and visual reasoning. To customize it for your
|
| 43 |
+
specific use case:
|
| 44 |
+
|
| 45 |
+
##### 1. Model Checkpoint Path
|
| 46 |
+
|
| 47 |
+
Update the checkpoint path to point to your model:
|
| 48 |
+
|
| 49 |
+
```python
|
| 50 |
+
checkpoint_dir = "/path/to/your/HF_HOME/models/Bagel-Zebra-CoT"
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
For example, under the `HF_HOME`, the path to the checkpoint folder is:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
checkpoint_dir = f"{HF_HOME}/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/c1ff3c56dd5909841523e3a6b554c77d919c2b28
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
You can also use the local dir:
|
| 60 |
+
|
| 61 |
+
```
|
| 62 |
+
checkpoint_dir = f"{HF_HOME}/models/Bagel-Zebra-CoT
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
##### 2. Setting up prompt and images
|
| 66 |
+
|
| 67 |
+
Edit the prompt and image variables in `infz_bf16.py` (around lines 203-211):
|
| 68 |
+
|
| 69 |
+
**For single image problems:**
|
| 70 |
+
```python
|
| 71 |
+
prompt = "Your question here"
|
| 72 |
+
image = Image.open('path/to/your/image.png')
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
**For multiple image problems:**
|
| 76 |
+
```python
|
| 77 |
+
prompt = "Your question about multiple images"
|
| 78 |
+
image_1 = Image.open('path/to/image1.jpg')
|
| 79 |
+
image_2 = Image.open('path/to/image2.jpg')
|
| 80 |
+
image_3 = Image.open('path/to/image3.jpg')
|
| 81 |
+
image = [image_1, image_2, image_3] # List of images
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
**For text-only problems:**
|
| 85 |
+
```python
|
| 86 |
+
prompt = "Your text-only question"
|
| 87 |
+
image = None
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
##### 3. Inference Parameters
|
| 91 |
+
|
| 92 |
+
You can adjust the generation parameters in the `inference_hyper` dictionary:
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
inference_hyper = dict(
|
| 96 |
+
do_sample=True,
|
| 97 |
+
text_temperature=0.3,
|
| 98 |
+
cfg_text_scale=4.0,
|
| 99 |
+
cfg_img_scale=2.0,
|
| 100 |
+
cfg_interval=[0.0, 1.0],
|
| 101 |
+
timestep_shift=3.0,
|
| 102 |
+
num_timesteps=50,
|
| 103 |
+
cfg_renorm_min=0.0,
|
| 104 |
+
cfg_renorm_type="text_channel",
|
| 105 |
+
)
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
For details, refer to the original jupyter notebook [here](inference.ipynb).
|
| 109 |
+
|
| 110 |
+
#### Example Use Cases
|
| 111 |
+
|
| 112 |
+
```python
|
| 113 |
+
prompt = "Subtract all cylinders. Add 1 red sphere. How many objects are left?"
|
| 114 |
+
image = Image.open('test_images/image.png')
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Training
|
| 118 |
+
For training, run
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
bash scripts/train.sh
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
For details, please refer to the original repo [README](https://github.com/bytedance-seed/BAGEL).
|
| 125 |
+
|
| 126 |
+
The interleaved reasoning data customized for Zebra-CoT can be found in [think_trace_dataset.py](data/interleave_datasets/think_trace_dataset.py).
|
| 127 |
+
|
| 128 |
+
### Cite
|
| 129 |
+
```bibtex
|
| 130 |
+
@misc{li2025zebracot,
|
| 131 |
+
title={Zebra-CoT: A Dataset for Interleaved Vision Language Reasoning},
|
| 132 |
+
author={Ang Li and Charles Wang and Kaiyu Yue and Zikui Cai and Ollie Liu and Deqing Fu and Peng Guo and Wang Bill Zhu and Vatsal Sharan and Robin Jia and Willie Neiswanger and Furong Huang and Tom Goldstein and Micah Goldblum},
|
| 133 |
+
year={2025},
|
| 134 |
+
eprint={2507.16746},
|
| 135 |
+
archivePrefix={arXiv},
|
| 136 |
+
primaryClass={cs.CV},
|
| 137 |
+
url={https://arxiv.org/abs/2507.16746},
|
| 138 |
+
}
|
| 139 |
+
```
|
TRAIN.md
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data prepration
|
| 2 |
+
|
| 3 |
+
We provide data examples for **T2I**, **Editing**, and **VLM** tasks. The T2I dataset is generated using [FLUX.1‑dev](https://huggingface.co/black-forest-labs/FLUX.1-dev); the editing examples are randomly sampled from [SEED‑Data‑Edit‑Part3](https://huggingface.co/datasets/AILab-CVC/SEED-Data-Edit-Part2-3); and the VLM set is sourced from [LLaVA‑OneVision‑Data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data).
|
| 4 |
+
|
| 5 |
+
We offer examples in both raw-image folder and parquet shard formats. For other data formats, you can use our dataset code as a template and extend it as needed.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
1. **Download the sample dataset**
|
| 9 |
+
|
| 10 |
+
```bash
|
| 11 |
+
wget -O bagel_example.zip \
|
| 12 |
+
https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/bagel_example.zip
|
| 13 |
+
unzip bagel_example.zip -d /data
|
| 14 |
+
```
|
| 15 |
+
2. **Expected hierarchy**
|
| 16 |
+
|
| 17 |
+
```text
|
| 18 |
+
bagel_example
|
| 19 |
+
├── t2i/ # text-to-image (parquet)
|
| 20 |
+
├── editing/ # image editing (parquet)
|
| 21 |
+
│ ├── seedxedit_multi/
|
| 22 |
+
│ └── parquet_info/
|
| 23 |
+
└── vlm/
|
| 24 |
+
├── images/ # JPEG / PNG frames
|
| 25 |
+
└── llava_ov_si.jsonl # vision‑language SFT conversations
|
| 26 |
+
```
|
| 27 |
+
3. Edit every `your_data_path` placeholder in **`data/dataset_info.py`**.
|
| 28 |
+
4. *(Optional)* Extend `DATASET_INFO` with your own parquet shards or JSONL files to mix extra data.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
# Training
|
| 33 |
+
|
| 34 |
+
The baseline training recipe looks like this (replace environment variables with real paths or values):
|
| 35 |
+
|
| 36 |
+
```shell
|
| 37 |
+
# Pre-training
|
| 38 |
+
torchrun \
|
| 39 |
+
--nnodes=$num_nodes \
|
| 40 |
+
--node_rank=$node_rank \
|
| 41 |
+
--nproc_per_node=8 \
|
| 42 |
+
--master_addr=$master_addr \
|
| 43 |
+
--master_port=$master_port \
|
| 44 |
+
train/pretrain_unified_navit.py \
|
| 45 |
+
--dataset_config_file ./data/configs/example.yaml \
|
| 46 |
+
--llm_path $llm_path \
|
| 47 |
+
--vae_path $vae_path \
|
| 48 |
+
--vit_path $vit_path \
|
| 49 |
+
--layer_module Qwen2MoTDecoderLayer \
|
| 50 |
+
--use_flex True \
|
| 51 |
+
--resume_from $resume_from \
|
| 52 |
+
--results_dir $output_path \
|
| 53 |
+
--checkpoint_dir $ckpt_path \
|
| 54 |
+
--max_latent_size 64 # 32 for low-resolution pre-training
|
| 55 |
+
|
| 56 |
+
# Fine-tuning
|
| 57 |
+
torchrun \
|
| 58 |
+
--nnodes=$num_nodes \
|
| 59 |
+
--node_rank=$node_rank \
|
| 60 |
+
--nproc_per_node=8 \
|
| 61 |
+
--master_addr=$master_addr \
|
| 62 |
+
--master_port=$master_port \
|
| 63 |
+
train/pretrain_unified_navit.py \
|
| 64 |
+
--dataset_config_file ./data/configs/example.yaml \
|
| 65 |
+
--model_path $model_path \
|
| 66 |
+
--layer_module Qwen2MoTDecoderLayer \
|
| 67 |
+
--max_latent_size 64 \
|
| 68 |
+
--resume-from $model_path \
|
| 69 |
+
--finetune_from_hf True \
|
| 70 |
+
--auto_resume True \
|
| 71 |
+
--resume-model-only True \
|
| 72 |
+
--finetune-from-ema True \
|
| 73 |
+
--log_every 1 \
|
| 74 |
+
--lr 2e-5 \
|
| 75 |
+
--num_worker 1 \
|
| 76 |
+
--expected_num_tokens 10240 \
|
| 77 |
+
--max_num_tokens 11520 \
|
| 78 |
+
--max_num_tokens_per_sample 10240
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
- **When fine-tuning BAGEL, set `max_latent_size=64` to ensure the correct pretrained weights are loaded.** If this is not set, an out-of-bounds error may occur.
|
| 82 |
+
- The total value of `num_used_data` should be greater than `NUM_GPUS × NUM_WORKERS`. (For toy data, use `num_worker=1`.)
|
| 83 |
+
- For T2I-only fine-tuning, set `visual_und=False`. For VLM-only fine-tuning, set `visual_gen=False`.
|
| 84 |
+
- For debugging purposes, use smaller values for `expected_num_tokens`, `max_num_tokens`, and `max_num_tokens_per_sample`.
|
| 85 |
+
- When fine-tuning on toy data, the loss behaves as follows:
|
| 86 |
+
```shell
|
| 87 |
+
[2025-05-25 17:01:37] (step=0000000) Train Loss mse: 0.4063, Train Loss ce: 0.5504, Train Steps/Sec: 0.01,
|
| 88 |
+
[2025-05-25 17:01:40] (step=0000001) Train Loss mse: 0.4121, Train Loss ce: 0.8152, Train Steps/Sec: 0.44,
|
| 89 |
+
[2025-05-25 17:01:42] (step=0000002) Train Loss mse: 0.3876, Train Loss ce: 1.3411, Train Steps/Sec: 0.40,
|
| 90 |
+
[2025-05-25 17:01:45] (step=0000003) Train Loss mse: 0.3825, Train Loss ce: 0.7360, Train Steps/Sec: 0.44,
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
You are encouraged to adjust any of these hyperparameters to fit your GPU budget and the scale of your dataset. If you encounter any issues, please open an issue for assistance. 🎉
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
## Model config
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
| Argument | Default | Description |
|
| 101 |
+
| ---------------------------- | ------------------------------------------- | --------------------------------------------------------------- |
|
| 102 |
+
| `llm_path` | `hf/Qwen2.5-0.5B-Instruct` | Language‑model backbone (HuggingFace repo or local folder). |
|
| 103 |
+
| `vae_path` | `flux/vae/ae.safetensors` | Pre‑trained VAE checkpoint for latent diffusion. |
|
| 104 |
+
| `vit_path` | `hf/siglip-so400m-14-980-flash-attn2-navit` | SigLIP ViT used for image understanding. |
|
| 105 |
+
| `max_latent_size` | `32` | Maximum latent grid side; defines highest generable resolution. |
|
| 106 |
+
| `latent_patch_size` | `2` | VAE pixels represented by one latent patch. |
|
| 107 |
+
| `vit_max_num_patch_per_side` | `70` | Max ViT patches per image side after resizing. |
|
| 108 |
+
| `text_cond_dropout_prob` | `0.1` | Probability to drop text conditioning while training. |
|
| 109 |
+
| `vae_cond_dropout_prob` | `0.3` | Dropout on VAE latent inputs. |
|
| 110 |
+
| `vit_cond_dropout_prob` | `0.3` | Dropout on visual features. |
|
| 111 |
+
|
| 112 |
+
*(See `ModelArguments` for many more options.)*
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
## Data config
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
| Argument | Default | Description |
|
| 119 |
+
| --------------------------- | --------------------------- | --------------------------------------------------------- |
|
| 120 |
+
| `dataset_config_file` | `data/configs/example.yaml` | YAML that groups datasets and assigns sampling weights. |
|
| 121 |
+
| `num_workers` | `4` | Background workers per rank for the PyTorch `DataLoader`. |
|
| 122 |
+
| `prefetch_factor` | `2` | Batches pre‑fetched by each worker. |
|
| 123 |
+
| `max_num_tokens_per_sample` | `16384` | Skip raw samples longer than this. |
|
| 124 |
+
| `max_num_tokens` | `36864` | Hard cap for a packed batch (prevents OOM). |
|
| 125 |
+
| `max_buffer_size` | `50` | Overflow buffer length for oversized samples. |
|
| 126 |
+
| `data_seed` | `42` | Seed for reproducible shuffling and sampling. |
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
## Training config
|
| 130 |
+
|
| 131 |
+
| Argument | Default | Description |
|
| 132 |
+
| -------------------------------------- | ---------------------- | ------------------------------------------------------ |
|
| 133 |
+
| `total_steps` | `500_000` | Optimiser steps to run. |
|
| 134 |
+
| `lr` | `1e-4` | Peak learning rate after warm‑up. |
|
| 135 |
+
| `lr_scheduler` | `constant` | Learning‑rate schedule (`constant` or `cosine`). |
|
| 136 |
+
| `warmup_steps` | `2000` | Linear warm‑up duration. |
|
| 137 |
+
| `ema` | `0.9999` | Exponential moving‑average decay for model weights. |
|
| 138 |
+
| `max_grad_norm` | `1.0` | Gradient‑clipping threshold. |
|
| 139 |
+
| `save_every` | `2000` | Checkpoint frequency (steps). |
|
| 140 |
+
| `visual_gen / visual_und` | `True` | Enable image generation / understanding branches. |
|
| 141 |
+
| `freeze_llm / freeze_vit / freeze_vae` | `False / False / True` | Freeze selected modules to save VRAM or for ablations. |
|
| 142 |
+
| `use_flex` | `True` (in example) | Enable FLEX packing for higher GPU utilisation. |
|
| 143 |
+
| `sharding_strategy` | `HYBRID_SHARD` | FSDP sharding mode. |
|
| 144 |
+
| `num_shard` | `8` | Parameter shards per rank in HYBRID mode. |
|
| 145 |
+
|
| 146 |
+
**Distributed‑launch environment variables**
|
| 147 |
+
|
| 148 |
+
| Var | Meaning |
|
| 149 |
+
| ----------------------------- | --------------------------------- |
|
| 150 |
+
| `num_nodes` / `node_rank` | Multi‑node orchestration indices. |
|
| 151 |
+
| `nproc_per_node` | Number of GPUs per node. |
|
| 152 |
+
| `master_addr` / `master_port` | NCCL rendezvous endpoint. |
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
## Logging config
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
| Argument | Default | Description |
|
| 159 |
+
| ---------------- | --------------------- | ---------------------------------------------------- |
|
| 160 |
+
| `results_dir` | `results` | Root directory for logs and metrics. |
|
| 161 |
+
| `checkpoint_dir` | `results/checkpoints` | Checkpoints are saved here. |
|
| 162 |
+
| `log_every` | `10` | Steps between console / W\&B logs. |
|
| 163 |
+
| `wandb_project` | `bagel` | Weights & Biases project name. |
|
| 164 |
+
| `wandb_name` | `run` | Run name inside the project. |
|
| 165 |
+
| `wandb_offline` | `False` | Switch to offline mode (logs locally, sync later). |
|
| 166 |
+
| `wandb_resume` | `allow` | Resumption policy if an existing run ID is detected. |
|
| 167 |
+
|
| 168 |
+
> **Tip** Export `WANDB_API_KEY` before launching if you want online dashboards.
|
app.py
ADDED
|
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from data.data_utils import add_special_tokens, pil_img2rgb
|
| 11 |
+
from data.transforms import ImageTransform
|
| 12 |
+
from inferencer import InterleaveInferencer
|
| 13 |
+
from modeling.autoencoder import load_ae
|
| 14 |
+
from modeling.bagel.qwen2_navit import NaiveCache
|
| 15 |
+
from modeling.bagel import (
|
| 16 |
+
BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
|
| 17 |
+
SiglipVisionConfig, SiglipVisionModel
|
| 18 |
+
)
|
| 19 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument("--server_name", type=str, default="127.0.0.1")
|
| 27 |
+
parser.add_argument("--server_port", type=int, default=7860)
|
| 28 |
+
parser.add_argument("--share", action="store_true")
|
| 29 |
+
parser.add_argument("--model_path", type=str, default="models/BAGEL-7B-MoT")
|
| 30 |
+
parser.add_argument("--mode", type=int, default=1)
|
| 31 |
+
parser.add_argument("--zh", action="store_true")
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
|
| 34 |
+
# Model Initialization
|
| 35 |
+
model_path = args.model_path #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT to models/BAGEL-7B-MoT
|
| 36 |
+
|
| 37 |
+
model_path = args.model_path
|
| 38 |
+
|
| 39 |
+
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
|
| 40 |
+
llm_config.qk_norm = True
|
| 41 |
+
llm_config.tie_word_embeddings = False
|
| 42 |
+
llm_config.layer_module = "Qwen2MoTDecoderLayer"
|
| 43 |
+
|
| 44 |
+
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
|
| 45 |
+
vit_config.rope = False
|
| 46 |
+
vit_config.num_hidden_layers -= 1
|
| 47 |
+
|
| 48 |
+
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
|
| 49 |
+
|
| 50 |
+
config = BagelConfig(
|
| 51 |
+
visual_gen=True,
|
| 52 |
+
visual_und=True,
|
| 53 |
+
llm_config=llm_config,
|
| 54 |
+
vit_config=vit_config,
|
| 55 |
+
vae_config=vae_config,
|
| 56 |
+
vit_max_num_patch_per_side=70,
|
| 57 |
+
connector_act='gelu_pytorch_tanh',
|
| 58 |
+
latent_patch_size=2,
|
| 59 |
+
max_latent_size=64,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
with init_empty_weights():
|
| 63 |
+
language_model = Qwen2ForCausalLM(llm_config)
|
| 64 |
+
vit_model = SiglipVisionModel(vit_config)
|
| 65 |
+
model = Bagel(language_model, vit_model, config)
|
| 66 |
+
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
|
| 67 |
+
|
| 68 |
+
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
|
| 69 |
+
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
|
| 70 |
+
|
| 71 |
+
vae_transform = ImageTransform(1024, 512, 16)
|
| 72 |
+
vit_transform = ImageTransform(980, 224, 14)
|
| 73 |
+
|
| 74 |
+
# Model Loading and Multi GPU Infernece Preparing
|
| 75 |
+
device_map = infer_auto_device_map(
|
| 76 |
+
model,
|
| 77 |
+
max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
|
| 78 |
+
no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
same_device_modules = [
|
| 82 |
+
'language_model.model.embed_tokens',
|
| 83 |
+
'time_embedder',
|
| 84 |
+
'latent_pos_embed',
|
| 85 |
+
'vae2llm',
|
| 86 |
+
'llm2vae',
|
| 87 |
+
'connector',
|
| 88 |
+
'vit_pos_embed'
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
if torch.cuda.device_count() == 1:
|
| 92 |
+
first_device = device_map.get(same_device_modules[0], "cuda:0")
|
| 93 |
+
for k in same_device_modules:
|
| 94 |
+
if k in device_map:
|
| 95 |
+
device_map[k] = first_device
|
| 96 |
+
else:
|
| 97 |
+
device_map[k] = "cuda:0"
|
| 98 |
+
else:
|
| 99 |
+
first_device = device_map.get(same_device_modules[0])
|
| 100 |
+
for k in same_device_modules:
|
| 101 |
+
if k in device_map:
|
| 102 |
+
device_map[k] = first_device
|
| 103 |
+
|
| 104 |
+
if args.mode == 1:
|
| 105 |
+
model = load_checkpoint_and_dispatch(
|
| 106 |
+
model,
|
| 107 |
+
checkpoint=os.path.join(model_path, "ema.safetensors"),
|
| 108 |
+
device_map=device_map,
|
| 109 |
+
offload_buffers=True,
|
| 110 |
+
offload_folder="offload",
|
| 111 |
+
dtype=torch.bfloat16,
|
| 112 |
+
force_hooks=True,
|
| 113 |
+
).eval()
|
| 114 |
+
elif args.mode == 2: # NF4
|
| 115 |
+
bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4")
|
| 116 |
+
model = load_and_quantize_model(
|
| 117 |
+
model,
|
| 118 |
+
weights_location=os.path.join(model_path, "ema.safetensors"),
|
| 119 |
+
bnb_quantization_config=bnb_quantization_config,
|
| 120 |
+
device_map=device_map,
|
| 121 |
+
offload_folder="offload",
|
| 122 |
+
).eval()
|
| 123 |
+
elif args.mode == 3: # INT8
|
| 124 |
+
bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, torch_dtype=torch.float32)
|
| 125 |
+
model = load_and_quantize_model(
|
| 126 |
+
model,
|
| 127 |
+
weights_location=os.path.join(model_path, "ema.safetensors"),
|
| 128 |
+
bnb_quantization_config=bnb_quantization_config,
|
| 129 |
+
device_map=device_map,
|
| 130 |
+
offload_folder="offload",
|
| 131 |
+
).eval()
|
| 132 |
+
else:
|
| 133 |
+
raise NotImplementedError
|
| 134 |
+
|
| 135 |
+
# Inferencer Preparing
|
| 136 |
+
inferencer = InterleaveInferencer(
|
| 137 |
+
model=model,
|
| 138 |
+
vae_model=vae_model,
|
| 139 |
+
tokenizer=tokenizer,
|
| 140 |
+
vae_transform=vae_transform,
|
| 141 |
+
vit_transform=vit_transform,
|
| 142 |
+
new_token_ids=new_token_ids,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def set_seed(seed):
|
| 147 |
+
"""Set random seeds for reproducibility"""
|
| 148 |
+
if seed > 0:
|
| 149 |
+
random.seed(seed)
|
| 150 |
+
np.random.seed(seed)
|
| 151 |
+
torch.manual_seed(seed)
|
| 152 |
+
if torch.cuda.is_available():
|
| 153 |
+
torch.cuda.manual_seed(seed)
|
| 154 |
+
torch.cuda.manual_seed_all(seed)
|
| 155 |
+
torch.backends.cudnn.deterministic = True
|
| 156 |
+
torch.backends.cudnn.benchmark = False
|
| 157 |
+
return seed
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Text to Image function with thinking option and hyperparameters
|
| 161 |
+
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
|
| 162 |
+
timestep_shift=3.0, num_timesteps=50,
|
| 163 |
+
cfg_renorm_min=0.0, cfg_renorm_type="global",
|
| 164 |
+
max_think_token_n=1024, do_sample=False, text_temperature=0.3,
|
| 165 |
+
seed=0, image_ratio="1:1"):
|
| 166 |
+
# Set seed for reproducibility
|
| 167 |
+
set_seed(seed)
|
| 168 |
+
|
| 169 |
+
if image_ratio == "1:1":
|
| 170 |
+
image_shapes = (1024, 1024)
|
| 171 |
+
elif image_ratio == "4:3":
|
| 172 |
+
image_shapes = (768, 1024)
|
| 173 |
+
elif image_ratio == "3:4":
|
| 174 |
+
image_shapes = (1024, 768)
|
| 175 |
+
elif image_ratio == "16:9":
|
| 176 |
+
image_shapes = (576, 1024)
|
| 177 |
+
elif image_ratio == "9:16":
|
| 178 |
+
image_shapes = (1024, 576)
|
| 179 |
+
|
| 180 |
+
# Set hyperparameters
|
| 181 |
+
inference_hyper = dict(
|
| 182 |
+
max_think_token_n=max_think_token_n if show_thinking else 1024,
|
| 183 |
+
do_sample=do_sample if show_thinking else False,
|
| 184 |
+
text_temperature=text_temperature if show_thinking else 0.3,
|
| 185 |
+
cfg_text_scale=cfg_text_scale,
|
| 186 |
+
cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
|
| 187 |
+
timestep_shift=timestep_shift,
|
| 188 |
+
num_timesteps=num_timesteps,
|
| 189 |
+
cfg_renorm_min=cfg_renorm_min,
|
| 190 |
+
cfg_renorm_type=cfg_renorm_type,
|
| 191 |
+
image_shapes=image_shapes,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Call inferencer with or without think parameter based on user choice
|
| 195 |
+
result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
|
| 196 |
+
return result["image"], result.get("text", None)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Image Understanding function with thinking option and hyperparameters
|
| 200 |
+
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
|
| 201 |
+
do_sample=False, text_temperature=0.3, max_new_tokens=512):
|
| 202 |
+
if image is None:
|
| 203 |
+
return "Please upload an image."
|
| 204 |
+
|
| 205 |
+
if isinstance(image, np.ndarray):
|
| 206 |
+
image = Image.fromarray(image)
|
| 207 |
+
|
| 208 |
+
image = pil_img2rgb(image)
|
| 209 |
+
|
| 210 |
+
# Set hyperparameters
|
| 211 |
+
inference_hyper = dict(
|
| 212 |
+
do_sample=do_sample,
|
| 213 |
+
text_temperature=text_temperature,
|
| 214 |
+
max_think_token_n=max_new_tokens, # Set max_length
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Use show_thinking parameter to control thinking process
|
| 218 |
+
result = inferencer(image=image, text=prompt, think=show_thinking,
|
| 219 |
+
understanding_output=True, **inference_hyper)
|
| 220 |
+
return result["text"]
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Image Editing function with thinking option and hyperparameters
|
| 224 |
+
def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
|
| 225 |
+
cfg_img_scale=2.0, cfg_interval=0.0,
|
| 226 |
+
timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=0.0,
|
| 227 |
+
cfg_renorm_type="text_channel", max_think_token_n=1024,
|
| 228 |
+
do_sample=False, text_temperature=0.3, seed=0):
|
| 229 |
+
# Set seed for reproducibility
|
| 230 |
+
set_seed(seed)
|
| 231 |
+
|
| 232 |
+
if image is None:
|
| 233 |
+
return "Please upload an image.", ""
|
| 234 |
+
|
| 235 |
+
if isinstance(image, np.ndarray):
|
| 236 |
+
image = Image.fromarray(image)
|
| 237 |
+
|
| 238 |
+
image = pil_img2rgb(image)
|
| 239 |
+
|
| 240 |
+
# Set hyperparameters
|
| 241 |
+
inference_hyper = dict(
|
| 242 |
+
max_think_token_n=max_think_token_n if show_thinking else 1024,
|
| 243 |
+
do_sample=do_sample if show_thinking else False,
|
| 244 |
+
text_temperature=text_temperature if show_thinking else 0.3,
|
| 245 |
+
cfg_text_scale=cfg_text_scale,
|
| 246 |
+
cfg_img_scale=cfg_img_scale,
|
| 247 |
+
cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
|
| 248 |
+
timestep_shift=timestep_shift,
|
| 249 |
+
num_timesteps=num_timesteps,
|
| 250 |
+
cfg_renorm_min=cfg_renorm_min,
|
| 251 |
+
cfg_renorm_type=cfg_renorm_type,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Include thinking parameter based on user choice
|
| 255 |
+
result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
|
| 256 |
+
return result["image"], result.get("text", "")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Helper function to load example images
|
| 260 |
+
def load_example_image(image_path):
|
| 261 |
+
try:
|
| 262 |
+
return Image.open(image_path)
|
| 263 |
+
except Exception as e:
|
| 264 |
+
print(f"Error loading example image: {e}")
|
| 265 |
+
return None
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# Gradio UI
|
| 269 |
+
with gr.Blocks() as demo:
|
| 270 |
+
gr.Markdown("""
|
| 271 |
+
<div>
|
| 272 |
+
<img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
|
| 273 |
+
</div>
|
| 274 |
+
""")
|
| 275 |
+
|
| 276 |
+
with gr.Tab("📝 Text to Image"):
|
| 277 |
+
txt_input = gr.Textbox(
|
| 278 |
+
label="Prompt",
|
| 279 |
+
value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
with gr.Row():
|
| 283 |
+
show_thinking = gr.Checkbox(label="Thinking", value=False)
|
| 284 |
+
|
| 285 |
+
# Add hyperparameter controls in an accordion
|
| 286 |
+
with gr.Accordion("Inference Hyperparameters", open=False):
|
| 287 |
+
with gr.Group():
|
| 288 |
+
with gr.Row():
|
| 289 |
+
seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
|
| 290 |
+
label="Seed", info="0 for random seed, positive for reproducible results")
|
| 291 |
+
image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
|
| 292 |
+
value="1:1", label="Image Ratio",
|
| 293 |
+
info="The longer size is fixed to 1024")
|
| 294 |
+
|
| 295 |
+
with gr.Row():
|
| 296 |
+
cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
|
| 297 |
+
label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
|
| 298 |
+
cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
|
| 299 |
+
label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
|
| 300 |
+
|
| 301 |
+
with gr.Row():
|
| 302 |
+
cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
|
| 303 |
+
value="global", label="CFG Renorm Type",
|
| 304 |
+
info="If the genrated image is blurry, use 'global'")
|
| 305 |
+
cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
|
| 306 |
+
label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
|
| 307 |
+
|
| 308 |
+
with gr.Row():
|
| 309 |
+
num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
|
| 310 |
+
label="Timesteps", info="Total denoising steps")
|
| 311 |
+
timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
|
| 312 |
+
label="Timestep Shift", info="Higher values for layout, lower for details")
|
| 313 |
+
|
| 314 |
+
# Thinking parameters in a single row
|
| 315 |
+
thinking_params = gr.Group(visible=False)
|
| 316 |
+
with thinking_params:
|
| 317 |
+
with gr.Row():
|
| 318 |
+
do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
|
| 319 |
+
max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
|
| 320 |
+
label="Max Think Tokens", info="Maximum number of tokens for thinking")
|
| 321 |
+
text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
|
| 322 |
+
label="Temperature", info="Controls randomness in text generation")
|
| 323 |
+
|
| 324 |
+
thinking_output = gr.Textbox(label="Thinking Process", visible=False)
|
| 325 |
+
img_output = gr.Image(label="Generated Image")
|
| 326 |
+
gen_btn = gr.Button("Generate", variant="primary")
|
| 327 |
+
|
| 328 |
+
# Dynamically show/hide thinking process box and parameters
|
| 329 |
+
def update_thinking_visibility(show):
|
| 330 |
+
return gr.update(visible=show), gr.update(visible=show)
|
| 331 |
+
|
| 332 |
+
show_thinking.change(
|
| 333 |
+
fn=update_thinking_visibility,
|
| 334 |
+
inputs=[show_thinking],
|
| 335 |
+
outputs=[thinking_output, thinking_params]
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Process function based on thinking option and hyperparameters
|
| 339 |
+
def process_text_to_image(prompt, show_thinking, cfg_text_scale,
|
| 340 |
+
cfg_interval, timestep_shift,
|
| 341 |
+
num_timesteps, cfg_renorm_min, cfg_renorm_type,
|
| 342 |
+
max_think_token_n, do_sample, text_temperature, seed, image_ratio):
|
| 343 |
+
image, thinking = text_to_image(
|
| 344 |
+
prompt, show_thinking, cfg_text_scale, cfg_interval,
|
| 345 |
+
timestep_shift, num_timesteps,
|
| 346 |
+
cfg_renorm_min, cfg_renorm_type,
|
| 347 |
+
max_think_token_n, do_sample, text_temperature, seed, image_ratio
|
| 348 |
+
)
|
| 349 |
+
return image, thinking if thinking else ""
|
| 350 |
+
|
| 351 |
+
gr.on(
|
| 352 |
+
triggers=[gen_btn.click, txt_input.submit],
|
| 353 |
+
fn=process_text_to_image,
|
| 354 |
+
inputs=[
|
| 355 |
+
txt_input, show_thinking, cfg_text_scale,
|
| 356 |
+
cfg_interval, timestep_shift,
|
| 357 |
+
num_timesteps, cfg_renorm_min, cfg_renorm_type,
|
| 358 |
+
max_think_token_n, do_sample, text_temperature, seed, image_ratio
|
| 359 |
+
],
|
| 360 |
+
outputs=[img_output, thinking_output]
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
with gr.Tab("🖌️ Image Edit"):
|
| 364 |
+
with gr.Row():
|
| 365 |
+
with gr.Column(scale=1):
|
| 366 |
+
edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
|
| 367 |
+
edit_prompt = gr.Textbox(
|
| 368 |
+
label="Prompt",
|
| 369 |
+
value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
with gr.Column(scale=1):
|
| 373 |
+
edit_image_output = gr.Image(label="Result")
|
| 374 |
+
edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
|
| 375 |
+
|
| 376 |
+
with gr.Row():
|
| 377 |
+
edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
|
| 378 |
+
|
| 379 |
+
# Add hyperparameter controls in an accordion
|
| 380 |
+
with gr.Accordion("Inference Hyperparameters", open=False):
|
| 381 |
+
with gr.Group():
|
| 382 |
+
with gr.Row():
|
| 383 |
+
edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
|
| 384 |
+
label="Seed", info="0 for random seed, positive for reproducible results")
|
| 385 |
+
edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
|
| 386 |
+
label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
|
| 387 |
+
|
| 388 |
+
with gr.Row():
|
| 389 |
+
edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
|
| 390 |
+
label="CFG Image Scale", info="Controls how much the model preserves input image details")
|
| 391 |
+
edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
|
| 392 |
+
label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
|
| 393 |
+
|
| 394 |
+
with gr.Row():
|
| 395 |
+
edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
|
| 396 |
+
value="text_channel", label="CFG Renorm Type",
|
| 397 |
+
info="If the genrated image is blurry, use 'global'")
|
| 398 |
+
edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
|
| 399 |
+
label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
|
| 400 |
+
|
| 401 |
+
with gr.Row():
|
| 402 |
+
edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
|
| 403 |
+
label="Timesteps", info="Total denoising steps")
|
| 404 |
+
edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
|
| 405 |
+
label="Timestep Shift", info="Higher values for layout, lower for details")
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# Thinking parameters in a single row
|
| 409 |
+
edit_thinking_params = gr.Group(visible=False)
|
| 410 |
+
with edit_thinking_params:
|
| 411 |
+
with gr.Row():
|
| 412 |
+
edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
|
| 413 |
+
edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
|
| 414 |
+
label="Max Think Tokens", info="Maximum number of tokens for thinking")
|
| 415 |
+
edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
|
| 416 |
+
label="Temperature", info="Controls randomness in text generation")
|
| 417 |
+
|
| 418 |
+
edit_btn = gr.Button("Submit", variant="primary")
|
| 419 |
+
|
| 420 |
+
# Dynamically show/hide thinking process box for editing
|
| 421 |
+
def update_edit_thinking_visibility(show):
|
| 422 |
+
return gr.update(visible=show), gr.update(visible=show)
|
| 423 |
+
|
| 424 |
+
edit_show_thinking.change(
|
| 425 |
+
fn=update_edit_thinking_visibility,
|
| 426 |
+
inputs=[edit_show_thinking],
|
| 427 |
+
outputs=[edit_thinking_output, edit_thinking_params]
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Process editing with thinking option and hyperparameters
|
| 431 |
+
def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
|
| 432 |
+
cfg_img_scale, cfg_interval,
|
| 433 |
+
timestep_shift, num_timesteps, cfg_renorm_min,
|
| 434 |
+
cfg_renorm_type, max_think_token_n, do_sample,
|
| 435 |
+
text_temperature, seed):
|
| 436 |
+
edited_image, thinking = edit_image(
|
| 437 |
+
image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
|
| 438 |
+
cfg_interval, timestep_shift,
|
| 439 |
+
num_timesteps, cfg_renorm_min, cfg_renorm_type,
|
| 440 |
+
max_think_token_n, do_sample, text_temperature, seed
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
return edited_image, thinking if thinking else ""
|
| 444 |
+
|
| 445 |
+
gr.on(
|
| 446 |
+
triggers=[edit_btn.click, edit_prompt.submit],
|
| 447 |
+
fn=process_edit_image,
|
| 448 |
+
inputs=[
|
| 449 |
+
edit_image_input, edit_prompt, edit_show_thinking,
|
| 450 |
+
edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
|
| 451 |
+
edit_timestep_shift, edit_num_timesteps,
|
| 452 |
+
edit_cfg_renorm_min, edit_cfg_renorm_type,
|
| 453 |
+
edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
|
| 454 |
+
],
|
| 455 |
+
outputs=[edit_image_output, edit_thinking_output]
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
with gr.Tab("🖼️ Image Understanding"):
|
| 459 |
+
with gr.Row():
|
| 460 |
+
with gr.Column(scale=1):
|
| 461 |
+
img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
|
| 462 |
+
understand_prompt = gr.Textbox(
|
| 463 |
+
label="Prompt",
|
| 464 |
+
value="Can someone explain what's funny about this meme??"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
with gr.Column(scale=1):
|
| 468 |
+
txt_output = gr.Textbox(label="Result", lines=20)
|
| 469 |
+
|
| 470 |
+
with gr.Row():
|
| 471 |
+
understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
|
| 472 |
+
|
| 473 |
+
# Add hyperparameter controls in an accordion
|
| 474 |
+
with gr.Accordion("Inference Hyperparameters", open=False):
|
| 475 |
+
with gr.Row():
|
| 476 |
+
understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
|
| 477 |
+
understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
|
| 478 |
+
label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
|
| 479 |
+
understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
|
| 480 |
+
label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
|
| 481 |
+
|
| 482 |
+
img_understand_btn = gr.Button("Submit", variant="primary")
|
| 483 |
+
|
| 484 |
+
# Process understanding with thinking option and hyperparameters
|
| 485 |
+
def process_understanding(image, prompt, show_thinking, do_sample,
|
| 486 |
+
text_temperature, max_new_tokens):
|
| 487 |
+
result = image_understanding(
|
| 488 |
+
image, prompt, show_thinking, do_sample,
|
| 489 |
+
text_temperature, max_new_tokens
|
| 490 |
+
)
|
| 491 |
+
return result
|
| 492 |
+
|
| 493 |
+
gr.on(
|
| 494 |
+
triggers=[img_understand_btn.click, understand_prompt.submit],
|
| 495 |
+
fn=process_understanding,
|
| 496 |
+
inputs=[
|
| 497 |
+
img_input, understand_prompt, understand_show_thinking,
|
| 498 |
+
understand_do_sample, understand_text_temperature, understand_max_new_tokens
|
| 499 |
+
],
|
| 500 |
+
outputs=txt_output
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
gr.Markdown("""
|
| 504 |
+
<div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
|
| 505 |
+
<a href="https://bagel-ai.org/">
|
| 506 |
+
<img
|
| 507 |
+
src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
|
| 508 |
+
alt="BAGEL Website"
|
| 509 |
+
/>
|
| 510 |
+
</a>
|
| 511 |
+
<a href="https://arxiv.org/abs/2505.14683">
|
| 512 |
+
<img
|
| 513 |
+
src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
|
| 514 |
+
alt="BAGEL Paper on arXiv"
|
| 515 |
+
/>
|
| 516 |
+
</a>
|
| 517 |
+
<a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
|
| 518 |
+
<img
|
| 519 |
+
src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
|
| 520 |
+
alt="BAGEL on Hugging Face"
|
| 521 |
+
/>
|
| 522 |
+
</a>
|
| 523 |
+
<a href="https://demo.bagel-ai.org/">
|
| 524 |
+
<img
|
| 525 |
+
src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
|
| 526 |
+
alt="BAGEL Demo"
|
| 527 |
+
/>
|
| 528 |
+
</a>
|
| 529 |
+
<a href="https://discord.gg/Z836xxzy">
|
| 530 |
+
<img
|
| 531 |
+
src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
|
| 532 |
+
alt="BAGEL Discord"
|
| 533 |
+
/>
|
| 534 |
+
</a>
|
| 535 |
+
<a href="mailto:[email protected]">
|
| 536 |
+
<img
|
| 537 |
+
src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
|
| 538 |
+
alt="BAGEL Email"
|
| 539 |
+
/>
|
| 540 |
+
</a>
|
| 541 |
+
</div>
|
| 542 |
+
""")
|
| 543 |
+
|
| 544 |
+
UI_TRANSLATIONS = {
|
| 545 |
+
"📝 Text to Image":"📝 文生图",
|
| 546 |
+
"Prompt":"提示词",
|
| 547 |
+
"Thinking":"思考模式",
|
| 548 |
+
"Inference Hyperparameters":"推理参数",
|
| 549 |
+
"Seed":"随机种子",
|
| 550 |
+
"0 for random seed, positive for reproducible results":"0为随机种子,正数表示可重复结果",
|
| 551 |
+
"Image Ratio":"图片比例",
|
| 552 |
+
"The longer size is fixed to 1024":"长边固定为1024",
|
| 553 |
+
"CFG Text Scale":"文本CFG强度",
|
| 554 |
+
"Controls how strongly the model follows the text prompt (4.0-8.0)":"控制模型是否遵循文本提示(4.0-8.0)",
|
| 555 |
+
"CFG Interval":"CFG应用间隔",
|
| 556 |
+
"Start of CFG application interval (end is fixed at 1.0)":"CFG应用间隔的开始(结束固定为1.0)",
|
| 557 |
+
"CFG Renorm Type":"CFG 重归一化类型",
|
| 558 |
+
"If the genrated image is blurry, use 'global'":"如果生成的图像模糊,请使用'global'",
|
| 559 |
+
"CFG Renorm Min":"CFG 重归一化最小值",
|
| 560 |
+
"1.0 disables CFG-Renorm":"1.0 禁用 CFG 重归一化",
|
| 561 |
+
"Timesteps":"时间步数",
|
| 562 |
+
"Total denoising steps":"总去噪步数",
|
| 563 |
+
"Timestep Shift":"时间步偏移",
|
| 564 |
+
"Higher values for layout, lower for details":"值更大更倾向于调整布局,值更小更倾向于调整细节",
|
| 565 |
+
"Sampling":"采样",
|
| 566 |
+
"Enable sampling for text generation":"为文本生成启用采样",
|
| 567 |
+
"Max Think Tokens":"最大思考token数",
|
| 568 |
+
"Maximum number of tokens for thinking":"思考的最大token数",
|
| 569 |
+
"Temperature":"温度系数",
|
| 570 |
+
"Controls randomness in text generation":"控制文本生成的随机性",
|
| 571 |
+
"Thinking Process":"思考过程",
|
| 572 |
+
"Generated Image":"生成图像",
|
| 573 |
+
"Generate":"开始生成",
|
| 574 |
+
"🖌️ Image Edit":"🖌️ 图像编辑",
|
| 575 |
+
"Input Image":"图像输入",
|
| 576 |
+
"Result":"结果",
|
| 577 |
+
"Controls how strongly the model follows the text prompt":"控制模型是否遵循文本提示的强度",
|
| 578 |
+
"CFG Image Scale":"图像CFG强度",
|
| 579 |
+
"Controls how much the model preserves input image details":"控制模型保留输入图像细节的强度",
|
| 580 |
+
"Submit":"开始生成",
|
| 581 |
+
"🖼️ Image Understanding":"🖼️ 图像理解",
|
| 582 |
+
"Controls randomness in text generation (0=deterministic, 1=creative)":"控制文本生成的随机性(0=确定,1=creative)",
|
| 583 |
+
"Max New Tokens":"最大新token数",
|
| 584 |
+
"Maximum length of generated text, including potential thinking":"生成文本的最大长度,包括可能的思考",
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
def apply_localization(block):
|
| 588 |
+
def process_component(component):
|
| 589 |
+
if not component:
|
| 590 |
+
return
|
| 591 |
+
|
| 592 |
+
for attr in ['label', 'info', 'placeholder']:
|
| 593 |
+
if hasattr(component, attr):
|
| 594 |
+
text = getattr(component, attr)
|
| 595 |
+
if text in UI_TRANSLATIONS:
|
| 596 |
+
setattr(component, attr, UI_TRANSLATIONS[text])
|
| 597 |
+
|
| 598 |
+
if hasattr(component, 'children'):
|
| 599 |
+
for child in component.children:
|
| 600 |
+
process_component(child)
|
| 601 |
+
|
| 602 |
+
process_component(block)
|
| 603 |
+
return block
|
| 604 |
+
|
| 605 |
+
if __name__ == "__main__":
|
| 606 |
+
if args.zh:
|
| 607 |
+
demo = apply_localization(demo)
|
| 608 |
+
demo.launch(
|
| 609 |
+
server_name=args.server_name,
|
| 610 |
+
server_port=args.server_port,
|
| 611 |
+
share=args.share,
|
| 612 |
+
inbrowser=True,
|
| 613 |
+
)
|
bug.log
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
现在我明白问题了!模型目录中有 model_bf16.safetensors 但是代码在寻找 model.safetensors。让我检查一下 fsdp_utils.py 中的加载逻辑:
|
| 2 |
+
|
| 3 |
+
Read fsdp_utils.py, lines 155 to 165
|
| 4 |
+
|
| 5 |
+
找到问题了!代码在寻找 model.safetensors,但实际文件名是 model_bf16.safetensors。我需要修改加载逻辑来适配实际的文件名。
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
cd /scratch/by2593/Bagel-Zebra-CoT-origin && PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
data/configs/example_smm_semantic.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
block_dataset:
|
| 2 |
+
dataset_names:
|
| 3 |
+
- block_dataset
|
| 4 |
+
jsonl_path_list: ["/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl"]
|
| 5 |
+
num_used_data: None
|
| 6 |
+
image_prefix_dir: "/scratch/by2593/project/SMM/semantic_blocks_part1"
|
| 7 |
+
image_transform_args:
|
| 8 |
+
image_stride: 16
|
| 9 |
+
max_image_size: 512 # VAE使用stride=16, 512/16=32 patches
|
| 10 |
+
min_image_size: 512
|
| 11 |
+
vit_image_transform_args:
|
| 12 |
+
image_stride: 14
|
| 13 |
+
max_image_size: 512 # ViT使用stride=14, 512/14=36 patches (匹配模型能力)
|
| 14 |
+
min_image_size: 512
|
| 15 |
+
weight: 1.0
|
| 16 |
+
is_mandatory: true
|
| 17 |
+
|
| 18 |
+
# unified_edit:
|
| 19 |
+
# dataset_names:
|
| 20 |
+
# - seedxedit_multi
|
| 21 |
+
# image_transform_args:
|
| 22 |
+
# image_stride: 16
|
| 23 |
+
# max_image_size: 1024
|
| 24 |
+
# min_image_size: 512
|
| 25 |
+
# vit_image_transform_args:
|
| 26 |
+
# image_stride: 14
|
| 27 |
+
# max_image_size: 518
|
| 28 |
+
# min_image_size: 224
|
| 29 |
+
# is_mandatory: true
|
| 30 |
+
# num_used_data:
|
| 31 |
+
# - 10
|
| 32 |
+
# weight: 1
|
| 33 |
+
|
| 34 |
+
# vlm_sft:
|
| 35 |
+
# dataset_names:
|
| 36 |
+
# - llava_ov
|
| 37 |
+
# image_transform_args:
|
| 38 |
+
# image_stride: 14
|
| 39 |
+
# max_image_size: 980
|
| 40 |
+
# min_image_size: 378
|
| 41 |
+
# max_pixels: 2_007_040
|
| 42 |
+
# frame_sampler_args:
|
| 43 |
+
# max_num_frames: 12
|
| 44 |
+
# min_num_frames: 8
|
| 45 |
+
# is_mandatory: true
|
| 46 |
+
# shuffle_lines: True
|
| 47 |
+
# shuffle_seed: 0
|
| 48 |
+
# num_used_data:
|
| 49 |
+
# - 1000
|
| 50 |
+
# weight: 1
|
data/data_utils.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.nn.attention.flex_attention import or_masks, and_masks
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
|
| 14 |
+
def causal_mask(b, h, q_idx, kv_idx):
|
| 15 |
+
return q_idx >= kv_idx
|
| 16 |
+
|
| 17 |
+
def full_and_noise_mask(b, h, q_idx, kv_idx):
|
| 18 |
+
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
|
| 19 |
+
|
| 20 |
+
def remove_noise_mask(b, h, q_idx, kv_idx):
|
| 21 |
+
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
|
| 22 |
+
|
| 23 |
+
def sample_mask(b, h, q_idx, kv_idx):
|
| 24 |
+
return document_id[q_idx] == document_id[kv_idx]
|
| 25 |
+
|
| 26 |
+
full_and_noise_tmp = []
|
| 27 |
+
noise_tmp = []
|
| 28 |
+
|
| 29 |
+
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
|
| 30 |
+
value = i if model in ['full', 'noise'] else -1
|
| 31 |
+
full_and_noise_tmp.extend([value] * length)
|
| 32 |
+
value_noise = i if model == 'noise' else -1
|
| 33 |
+
noise_tmp.extend([value_noise] * length)
|
| 34 |
+
|
| 35 |
+
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
|
| 36 |
+
noise_seq_id = torch.Tensor(noise_tmp).to(device)
|
| 37 |
+
|
| 38 |
+
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
|
| 39 |
+
|
| 40 |
+
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def patchify(image, patch_size):
|
| 44 |
+
p = patch_size
|
| 45 |
+
c, h, w = image.shape
|
| 46 |
+
assert h % p == 0 and w % p == 0
|
| 47 |
+
image = image.reshape(c, h // p, p, w // p, p)
|
| 48 |
+
image = torch.einsum("chpwq->hwpqc", image)
|
| 49 |
+
image = image.reshape(-1, p**2 * c)
|
| 50 |
+
return image
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
|
| 54 |
+
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
|
| 55 |
+
coords_h = torch.arange(0, num_patches_h)
|
| 56 |
+
coords_w = torch.arange(0, num_patches_w)
|
| 57 |
+
pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
|
| 58 |
+
return pos_ids
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
|
| 62 |
+
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
|
| 63 |
+
boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
|
| 64 |
+
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
|
| 65 |
+
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
|
| 66 |
+
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
| 67 |
+
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
| 68 |
+
pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
|
| 69 |
+
return pos_ids
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
|
| 73 |
+
"""
|
| 74 |
+
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
|
| 75 |
+
a sample, where each sample contains multiple splits with different attn modes.
|
| 76 |
+
nested_attn_modes: whether to use full attn in each split.
|
| 77 |
+
"""
|
| 78 |
+
sample_len = sum(split_lens)
|
| 79 |
+
attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
|
| 80 |
+
|
| 81 |
+
csum = 0
|
| 82 |
+
for s, attn_mode in zip(split_lens, attn_modes):
|
| 83 |
+
assert attn_mode in ['causal', 'full', 'noise']
|
| 84 |
+
if attn_mode == "causal":
|
| 85 |
+
attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
|
| 86 |
+
attention_mask[csum:csum + s, :csum] = 1
|
| 87 |
+
else:
|
| 88 |
+
attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
|
| 89 |
+
attention_mask[csum:csum + s, :csum] = 1
|
| 90 |
+
csum += s
|
| 91 |
+
|
| 92 |
+
csum = 0
|
| 93 |
+
for s, attn_mode in zip(split_lens, attn_modes):
|
| 94 |
+
if attn_mode == "noise":
|
| 95 |
+
attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
|
| 96 |
+
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
|
| 97 |
+
csum += s
|
| 98 |
+
|
| 99 |
+
attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
|
| 100 |
+
~attention_mask, float("-inf")
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
return attention_mask
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def split_integer_exp_decay(S, ng_sample_decay=1.0):
|
| 107 |
+
if ng_sample_decay == 1.0:
|
| 108 |
+
N = random.randint(1, S)
|
| 109 |
+
else:
|
| 110 |
+
base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
|
| 111 |
+
p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
|
| 112 |
+
N = random.choices(list(range(1, S + 1)), p, k=1)[0]
|
| 113 |
+
cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
|
| 114 |
+
result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
|
| 115 |
+
return result, cumsum
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def pil_img2rgb(image):
|
| 119 |
+
if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
|
| 120 |
+
image = image.convert("RGBA")
|
| 121 |
+
white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
|
| 122 |
+
white.paste(image, mask=image.split()[3])
|
| 123 |
+
image = white
|
| 124 |
+
else:
|
| 125 |
+
image = image.convert("RGB")
|
| 126 |
+
|
| 127 |
+
return image
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def add_special_tokens(tokenizer):
|
| 131 |
+
all_special_tokens = []
|
| 132 |
+
for k, v in tokenizer.special_tokens_map.items():
|
| 133 |
+
if isinstance(v, str):
|
| 134 |
+
all_special_tokens.append(v)
|
| 135 |
+
elif isinstance(v, list):
|
| 136 |
+
all_special_tokens += v
|
| 137 |
+
|
| 138 |
+
new_tokens = []
|
| 139 |
+
|
| 140 |
+
if '<|im_start|>' not in all_special_tokens:
|
| 141 |
+
new_tokens.append('<|im_start|>')
|
| 142 |
+
|
| 143 |
+
if '<|im_end|>' not in all_special_tokens:
|
| 144 |
+
new_tokens.append('<|im_end|>')
|
| 145 |
+
|
| 146 |
+
if '<|vision_start|>' not in all_special_tokens:
|
| 147 |
+
new_tokens.append('<|vision_start|>')
|
| 148 |
+
|
| 149 |
+
if '<|vision_end|>' not in all_special_tokens:
|
| 150 |
+
new_tokens.append('<|vision_end|>')
|
| 151 |
+
|
| 152 |
+
num_new_tokens = tokenizer.add_tokens(new_tokens)
|
| 153 |
+
bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
|
| 154 |
+
eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
|
| 155 |
+
start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
|
| 156 |
+
end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
|
| 157 |
+
|
| 158 |
+
new_token_ids = dict(
|
| 159 |
+
bos_token_id=bos_token_id,
|
| 160 |
+
eos_token_id=eos_token_id,
|
| 161 |
+
start_of_image=start_of_image,
|
| 162 |
+
end_of_image=end_of_image,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
return tokenizer, new_token_ids, num_new_tokens
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def len2weight(x, loss_reduction='square'):
|
| 169 |
+
if x == 0:
|
| 170 |
+
return x
|
| 171 |
+
if loss_reduction == 'token':
|
| 172 |
+
return 1
|
| 173 |
+
if loss_reduction == 'sample':
|
| 174 |
+
return 1 / x
|
| 175 |
+
if loss_reduction == 'square':
|
| 176 |
+
return 1 / (x ** 0.5)
|
| 177 |
+
raise NotImplementedError(loss_reduction)
|
data/interleave_datasets/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from .edit_dataset import UnifiedEditIterableDataset
|
| 5 |
+
from .think_trace_dataset import ThinkTraceJSONLIterableDataset
|
| 6 |
+
|
data/parquet_utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import subprocess
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
import pyarrow.fs as pf
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
|
| 16 |
+
num_data_dirs = len(data_dir_list)
|
| 17 |
+
if world_size > 1:
|
| 18 |
+
chunk_size = (num_data_dirs + world_size - 1) // world_size
|
| 19 |
+
start_idx = rank * chunk_size
|
| 20 |
+
end_idx = min(start_idx + chunk_size, num_data_dirs)
|
| 21 |
+
local_data_dir_list = data_dir_list[start_idx:end_idx]
|
| 22 |
+
local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
|
| 23 |
+
else:
|
| 24 |
+
local_data_dir_list = data_dir_list
|
| 25 |
+
local_num_sampled_data_paths = num_sampled_data_paths
|
| 26 |
+
|
| 27 |
+
local_data_paths = []
|
| 28 |
+
for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
|
| 29 |
+
if data_dir.startswith("hdfs://"):
|
| 30 |
+
files = hdfs_ls_cmd(data_dir)
|
| 31 |
+
data_paths_per_dir = [
|
| 32 |
+
file for file in files if file.endswith(".parquet")
|
| 33 |
+
]
|
| 34 |
+
else:
|
| 35 |
+
files = os.listdir(data_dir)
|
| 36 |
+
data_paths_per_dir = [
|
| 37 |
+
os.path.join(data_dir, name)
|
| 38 |
+
for name in files
|
| 39 |
+
if name.endswith(".parquet")
|
| 40 |
+
]
|
| 41 |
+
repeat = num_data_path // len(data_paths_per_dir)
|
| 42 |
+
data_paths_per_dir = data_paths_per_dir * (repeat + 1)
|
| 43 |
+
local_data_paths.extend(data_paths_per_dir[:num_data_path])
|
| 44 |
+
|
| 45 |
+
if world_size > 1:
|
| 46 |
+
gather_list = [None] * world_size
|
| 47 |
+
dist.all_gather_object(gather_list, local_data_paths)
|
| 48 |
+
|
| 49 |
+
combined_chunks = []
|
| 50 |
+
for chunk_list in gather_list:
|
| 51 |
+
if chunk_list is not None:
|
| 52 |
+
combined_chunks.extend(chunk_list)
|
| 53 |
+
else:
|
| 54 |
+
combined_chunks = local_data_paths
|
| 55 |
+
|
| 56 |
+
return combined_chunks
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# NOTE: cumtomize this function for your cluster
|
| 60 |
+
def get_hdfs_host():
|
| 61 |
+
return "hdfs://xxx"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# NOTE: cumtomize this function for your cluster
|
| 65 |
+
def get_hdfs_block_size():
|
| 66 |
+
return 134217728
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# NOTE: cumtomize this function for your cluster
|
| 70 |
+
def get_hdfs_extra_conf():
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def init_arrow_pf_fs(parquet_file_path):
|
| 75 |
+
if parquet_file_path.startswith("hdfs://"):
|
| 76 |
+
fs = pf.HadoopFileSystem(
|
| 77 |
+
host=get_hdfs_host(),
|
| 78 |
+
port=0,
|
| 79 |
+
buffer_size=get_hdfs_block_size(),
|
| 80 |
+
extra_conf=get_hdfs_extra_conf(),
|
| 81 |
+
)
|
| 82 |
+
else:
|
| 83 |
+
fs = pf.LocalFileSystem()
|
| 84 |
+
return fs
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def hdfs_ls_cmd(dir):
|
| 88 |
+
result = subprocess.run(["hdfs", "dfs", "ls", dir], capture_output=True, text=True).stdout
|
| 89 |
+
return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
|
data/t2i_dataset.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import pyarrow.parquet as pq
|
| 7 |
+
import random
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .data_utils import pil_img2rgb
|
| 11 |
+
from .distributed_iterable_dataset import DistributedIterableDataset
|
| 12 |
+
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
|
| 13 |
+
|
| 14 |
+
Image.MAX_IMAGE_PIXELS = 20_000_000
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class T2IIterableDataset(DistributedIterableDataset):
|
| 18 |
+
def __init__(
|
| 19 |
+
self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
|
| 20 |
+
local_rank=0, world_size=1, num_workers=8, data_status=None,
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
data_dir_list: list of data directories contains parquet files
|
| 24 |
+
num_used_data: list of number of sampled data paths for each data directory
|
| 25 |
+
"""
|
| 26 |
+
super().__init__(dataset_name, local_rank, world_size, num_workers)
|
| 27 |
+
self.transform = transform
|
| 28 |
+
self.tokenizer = tokenizer
|
| 29 |
+
self.data_status = data_status
|
| 30 |
+
self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
|
| 31 |
+
self.set_epoch()
|
| 32 |
+
|
| 33 |
+
def get_data_paths(self, data_dir_list, num_used_data):
|
| 34 |
+
return get_parquet_data_paths(data_dir_list, num_used_data)
|
| 35 |
+
|
| 36 |
+
def __iter__(self):
|
| 37 |
+
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
|
| 38 |
+
if self.data_status is not None:
|
| 39 |
+
parquet_start_id = self.data_status[worker_id][0]
|
| 40 |
+
row_group_start_id = self.data_status[worker_id][1]
|
| 41 |
+
row_start_id = self.data_status[worker_id][2] + 1
|
| 42 |
+
else:
|
| 43 |
+
parquet_start_id = 0
|
| 44 |
+
row_group_start_id = 0
|
| 45 |
+
row_start_id = 0
|
| 46 |
+
transform_stride = self.transform.stride
|
| 47 |
+
|
| 48 |
+
print(
|
| 49 |
+
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
|
| 50 |
+
f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
while True:
|
| 54 |
+
data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
|
| 55 |
+
for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
|
| 56 |
+
fs = init_arrow_pf_fs(parquet_file_path)
|
| 57 |
+
with fs.open_input_file(parquet_file_path) as f:
|
| 58 |
+
fr = pq.ParquetFile(f)
|
| 59 |
+
row_group_ids = list(range(fr.num_row_groups))
|
| 60 |
+
row_group_ids_ = row_group_ids[row_group_start_id:]
|
| 61 |
+
|
| 62 |
+
for row_group_id in row_group_ids_:
|
| 63 |
+
df = fr.read_row_group(row_group_id).to_pandas()
|
| 64 |
+
df = df.iloc[row_start_id:]
|
| 65 |
+
|
| 66 |
+
for row_idx, row in df.iterrows():
|
| 67 |
+
num_tokens = 0
|
| 68 |
+
try:
|
| 69 |
+
image_byte = row['image']
|
| 70 |
+
image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
|
| 73 |
+
continue
|
| 74 |
+
image_tensor = self.transform(image)
|
| 75 |
+
height, width = image_tensor.shape[1:]
|
| 76 |
+
num_tokens += width * height // transform_stride ** 2
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
caption_dict = row['captions']
|
| 80 |
+
caption_dict = json.loads(caption_dict)
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
|
| 86 |
+
if len(caps_token) == 0:
|
| 87 |
+
print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
|
| 88 |
+
caption_token = self.tokenizer.encode(' ')
|
| 89 |
+
else:
|
| 90 |
+
caption_token = random.choice(caps_token)
|
| 91 |
+
|
| 92 |
+
sequence_plan, text_ids_list = [], []
|
| 93 |
+
text_ids = caption_token
|
| 94 |
+
num_tokens += len(caption_token)
|
| 95 |
+
text_ids_list.append(text_ids)
|
| 96 |
+
sequence_plan.append({
|
| 97 |
+
'type': 'text',
|
| 98 |
+
'enable_cfg': 1,
|
| 99 |
+
'loss': 0,
|
| 100 |
+
'special_token_loss': 0,
|
| 101 |
+
'special_token_label': None,
|
| 102 |
+
})
|
| 103 |
+
|
| 104 |
+
sequence_plan.append({
|
| 105 |
+
'type': 'vae_image',
|
| 106 |
+
'enable_cfg': 0,
|
| 107 |
+
'loss': 1,
|
| 108 |
+
'special_token_loss': 0,
|
| 109 |
+
'special_token_label': None,
|
| 110 |
+
})
|
| 111 |
+
|
| 112 |
+
sample = dict(
|
| 113 |
+
image_tensor_list=[image_tensor],
|
| 114 |
+
text_ids_list=text_ids_list,
|
| 115 |
+
num_tokens=num_tokens,
|
| 116 |
+
sequence_plan=sequence_plan,
|
| 117 |
+
data_indexes={
|
| 118 |
+
"data_indexes": [parquet_idx, row_group_id, row_idx],
|
| 119 |
+
"worker_id": worker_id,
|
| 120 |
+
"dataset_name": self.dataset_name,
|
| 121 |
+
}
|
| 122 |
+
)
|
| 123 |
+
yield sample
|
| 124 |
+
|
| 125 |
+
row_start_id = 0
|
| 126 |
+
row_group_start_id = 0
|
| 127 |
+
parquet_start_id = 0
|
| 128 |
+
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
|
data/transforms.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from torchvision.transforms import functional as F
|
| 12 |
+
from torchvision.transforms import InterpolationMode
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
|
| 16 |
+
"""Resize the input image so that its longest side and shortest side are within a specified range,
|
| 17 |
+
ensuring that both sides are divisible by a specified stride.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
max_size (int): Maximum size for the longest edge of the image.
|
| 21 |
+
min_size (int): Minimum size for the shortest edge of the image.
|
| 22 |
+
stride (int): Value by which the height and width of the image must be divisible.
|
| 23 |
+
max_pixels (int): Maximum pixels for the full image.
|
| 24 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
| 25 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
| 26 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
|
| 27 |
+
``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
|
| 28 |
+
The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
|
| 29 |
+
antialias (bool, optional): Whether to apply antialiasing (default is True).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
max_size: int,
|
| 35 |
+
min_size: int,
|
| 36 |
+
stride: int,
|
| 37 |
+
max_pixels: int,
|
| 38 |
+
interpolation=InterpolationMode.BICUBIC,
|
| 39 |
+
antialias=True
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.max_size = max_size
|
| 43 |
+
self.min_size = min_size
|
| 44 |
+
self.stride = stride
|
| 45 |
+
self.max_pixels = max_pixels
|
| 46 |
+
self.interpolation = interpolation
|
| 47 |
+
self.antialias = antialias
|
| 48 |
+
|
| 49 |
+
def _make_divisible(self, value, stride):
|
| 50 |
+
"""Ensure the value is divisible by the stride."""
|
| 51 |
+
return max(stride, int(round(value / stride) * stride))
|
| 52 |
+
|
| 53 |
+
def _apply_scale(self, width, height, scale):
|
| 54 |
+
new_width = round(width * scale)
|
| 55 |
+
new_height = round(height * scale)
|
| 56 |
+
new_width = self._make_divisible(new_width, self.stride)
|
| 57 |
+
new_height = self._make_divisible(new_height, self.stride)
|
| 58 |
+
return new_width, new_height
|
| 59 |
+
|
| 60 |
+
def forward(self, img, img_num=1):
|
| 61 |
+
"""
|
| 62 |
+
Args:
|
| 63 |
+
img (PIL Image): Image to be resized.
|
| 64 |
+
img_num (int): Number of images, used to change max_tokens.
|
| 65 |
+
Returns:
|
| 66 |
+
PIL Image or Tensor: Rescaled image with divisible dimensions.
|
| 67 |
+
"""
|
| 68 |
+
if isinstance(img, torch.Tensor):
|
| 69 |
+
height, width = img.shape[-2:]
|
| 70 |
+
else:
|
| 71 |
+
width, height = img.size
|
| 72 |
+
|
| 73 |
+
scale = min(self.max_size / max(width, height), 1.0)
|
| 74 |
+
scale = max(scale, self.min_size / min(width, height))
|
| 75 |
+
new_width, new_height = self._apply_scale(width, height, scale)
|
| 76 |
+
|
| 77 |
+
# Ensure the number of pixels does not exceed max_pixels
|
| 78 |
+
if new_width * new_height > self.max_pixels / img_num:
|
| 79 |
+
scale = self.max_pixels / img_num / (new_width * new_height)
|
| 80 |
+
new_width, new_height = self._apply_scale(new_width, new_height, scale)
|
| 81 |
+
|
| 82 |
+
# Ensure longest edge does not exceed max_size
|
| 83 |
+
if max(new_width, new_height) > self.max_size:
|
| 84 |
+
scale = self.max_size / max(new_width, new_height)
|
| 85 |
+
new_width, new_height = self._apply_scale(new_width, new_height, scale)
|
| 86 |
+
|
| 87 |
+
return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ImageTransform:
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
max_image_size,
|
| 94 |
+
min_image_size,
|
| 95 |
+
image_stride,
|
| 96 |
+
max_pixels=14*14*9*1024,
|
| 97 |
+
image_mean=[0.5, 0.5, 0.5],
|
| 98 |
+
image_std=[0.5, 0.5, 0.5]
|
| 99 |
+
):
|
| 100 |
+
self.stride = image_stride
|
| 101 |
+
|
| 102 |
+
self.resize_transform = MaxLongEdgeMinShortEdgeResize(
|
| 103 |
+
max_size=max_image_size,
|
| 104 |
+
min_size=min_image_size,
|
| 105 |
+
stride=image_stride,
|
| 106 |
+
max_pixels=max_pixels,
|
| 107 |
+
)
|
| 108 |
+
self.to_tensor_transform = transforms.ToTensor()
|
| 109 |
+
self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
|
| 110 |
+
|
| 111 |
+
def __call__(self, img, img_num=1):
|
| 112 |
+
img = self.resize_transform(img, img_num=img_num)
|
| 113 |
+
img = self.to_tensor_transform(img)
|
| 114 |
+
img = self.normalize_transform(img)
|
| 115 |
+
return img
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def decolorization(image):
|
| 119 |
+
gray_image = image.convert('L')
|
| 120 |
+
return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def downscale(image, scale_factor):
|
| 124 |
+
new_width = int(round(image.width * scale_factor))
|
| 125 |
+
new_height = int(round(image.height * scale_factor))
|
| 126 |
+
new_width = max(1, new_width)
|
| 127 |
+
new_height = max(1, new_height)
|
| 128 |
+
return image.resize((new_width, new_height), resample=Image.BICUBIC)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def crop(image, crop_factors):
|
| 132 |
+
target_h, target_w = crop_factors
|
| 133 |
+
img_w, img_h = image.size
|
| 134 |
+
|
| 135 |
+
if target_h > img_h or target_w > img_w:
|
| 136 |
+
raise ValueError("Crop size exceeds image dimensions")
|
| 137 |
+
|
| 138 |
+
x = random.randint(0, img_w - target_w)
|
| 139 |
+
y = random.randint(0, img_h - target_h)
|
| 140 |
+
|
| 141 |
+
return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def motion_blur_opencv(image, kernel_size=15, angle=0):
|
| 145 |
+
# 线性核
|
| 146 |
+
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
|
| 147 |
+
kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
|
| 148 |
+
|
| 149 |
+
# 旋转核
|
| 150 |
+
center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
|
| 151 |
+
M = cv2.getRotationMatrix2D(center, angle, 1)
|
| 152 |
+
rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
|
| 153 |
+
|
| 154 |
+
# 归一化核
|
| 155 |
+
rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
|
| 156 |
+
|
| 157 |
+
img = np.array(image)
|
| 158 |
+
if img.ndim == 2:
|
| 159 |
+
blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
|
| 160 |
+
else:
|
| 161 |
+
# 对于彩色图像,各通道独立卷积
|
| 162 |
+
blurred = np.zeros_like(img)
|
| 163 |
+
for c in range(img.shape[2]):
|
| 164 |
+
blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
|
| 165 |
+
|
| 166 |
+
return Image.fromarray(blurred.astype(np.uint8))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def shuffle_patch(image, num_splits, gap_size=2):
|
| 170 |
+
"""将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
|
| 171 |
+
h_splits, w_splits = num_splits
|
| 172 |
+
img_w, img_h = image.size
|
| 173 |
+
|
| 174 |
+
base_patch_h = img_h // h_splits
|
| 175 |
+
patch_heights = [base_patch_h] * (h_splits - 1)
|
| 176 |
+
patch_heights.append(img_h - sum(patch_heights))
|
| 177 |
+
|
| 178 |
+
base_patch_w = img_w // w_splits
|
| 179 |
+
patch_widths = [base_patch_w] * (w_splits - 1)
|
| 180 |
+
patch_widths.append(img_w - sum(patch_widths))
|
| 181 |
+
|
| 182 |
+
patches = []
|
| 183 |
+
current_y = 0
|
| 184 |
+
for i in range(h_splits):
|
| 185 |
+
current_x = 0
|
| 186 |
+
patch_h = patch_heights[i]
|
| 187 |
+
for j in range(w_splits):
|
| 188 |
+
patch_w = patch_widths[j]
|
| 189 |
+
patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
|
| 190 |
+
patches.append(patch)
|
| 191 |
+
current_x += patch_w
|
| 192 |
+
current_y += patch_h
|
| 193 |
+
|
| 194 |
+
random.shuffle(patches)
|
| 195 |
+
|
| 196 |
+
total_width = sum(patch_widths) + (w_splits - 1) * gap_size
|
| 197 |
+
total_height = sum(patch_heights) + (h_splits - 1) * gap_size
|
| 198 |
+
new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
|
| 199 |
+
|
| 200 |
+
current_y = 0 # 当前行的起始 Y 坐标
|
| 201 |
+
patch_idx = 0 # 当前处理的块索引
|
| 202 |
+
for i in range(h_splits):
|
| 203 |
+
current_x = 0 # 当前列的起始 X 坐标
|
| 204 |
+
patch_h = patch_heights[i] # 当前行块的高度
|
| 205 |
+
for j in range(w_splits):
|
| 206 |
+
# 取出打乱后的块
|
| 207 |
+
patch = patches[patch_idx]
|
| 208 |
+
patch_w = patch_widths[j] # 当前列块的宽度
|
| 209 |
+
# 粘贴块(左上角坐标为 (current_x, current_y))
|
| 210 |
+
new_image.paste(patch, (current_x, current_y))
|
| 211 |
+
# 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙)
|
| 212 |
+
current_x += patch_w + gap_size
|
| 213 |
+
patch_idx += 1
|
| 214 |
+
# 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙)
|
| 215 |
+
current_y += patch_h + gap_size
|
| 216 |
+
|
| 217 |
+
return new_image
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
|
| 221 |
+
"""
|
| 222 |
+
图像分割后随机空白部分patch,用于inpainting任务
|
| 223 |
+
|
| 224 |
+
参数:
|
| 225 |
+
image: PIL.Image 输入图像(RGB模式)
|
| 226 |
+
h_splits: int 行分割数(垂直方向分割块数)
|
| 227 |
+
w_splits: int 列分割数(水平方向分割块数)
|
| 228 |
+
blank_ratio: float 空白patch的比例(0~1)
|
| 229 |
+
blank_color: tuple 空白区域的颜色(RGB,如白色(255,255,255))
|
| 230 |
+
|
| 231 |
+
返回:
|
| 232 |
+
PIL.Image 处理后拼接的图像
|
| 233 |
+
"""
|
| 234 |
+
h_splits, w_splits = num_splits
|
| 235 |
+
img_w, img_h = image.size
|
| 236 |
+
|
| 237 |
+
base_patch_h = img_h // h_splits
|
| 238 |
+
patch_heights = [base_patch_h] * (h_splits - 1)
|
| 239 |
+
patch_heights.append(img_h - sum(patch_heights))
|
| 240 |
+
|
| 241 |
+
base_patch_w = img_w // w_splits
|
| 242 |
+
patch_widths = [base_patch_w] * (w_splits - 1)
|
| 243 |
+
patch_widths.append(img_w - sum(patch_widths))
|
| 244 |
+
|
| 245 |
+
patches = []
|
| 246 |
+
current_y = 0
|
| 247 |
+
for i in range(h_splits):
|
| 248 |
+
current_x = 0
|
| 249 |
+
patch_h = patch_heights[i]
|
| 250 |
+
for j in range(w_splits):
|
| 251 |
+
patch_w = patch_widths[j]
|
| 252 |
+
patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
|
| 253 |
+
patches.append(patch)
|
| 254 |
+
current_x += patch_w
|
| 255 |
+
current_y += patch_h
|
| 256 |
+
|
| 257 |
+
total_patches = h_splits * w_splits
|
| 258 |
+
num_blank = int(total_patches * blank_ratio)
|
| 259 |
+
num_blank = max(0, min(num_blank, total_patches))
|
| 260 |
+
blank_indices = random.sample(range(total_patches), num_blank)
|
| 261 |
+
|
| 262 |
+
processed_patches = []
|
| 263 |
+
for idx, patch in enumerate(patches):
|
| 264 |
+
if idx in blank_indices:
|
| 265 |
+
blank_patch = Image.new("RGB", patch.size, color=blank_color)
|
| 266 |
+
processed_patches.append(blank_patch)
|
| 267 |
+
else:
|
| 268 |
+
processed_patches.append(patch)
|
| 269 |
+
|
| 270 |
+
# 创建结果图像(尺寸与原图一致)
|
| 271 |
+
result_image = Image.new("RGB", (img_w, img_h))
|
| 272 |
+
current_y = 0
|
| 273 |
+
patch_idx = 0
|
| 274 |
+
for i in range(h_splits):
|
| 275 |
+
current_x = 0
|
| 276 |
+
patch_h = patch_heights[i]
|
| 277 |
+
for j in range(w_splits):
|
| 278 |
+
# 取出处理后的patch
|
| 279 |
+
patch = processed_patches[patch_idx]
|
| 280 |
+
patch_w = patch_widths[j]
|
| 281 |
+
# 粘贴到原位置
|
| 282 |
+
result_image.paste(patch, (current_x, current_y))
|
| 283 |
+
current_x += patch_w
|
| 284 |
+
patch_idx += 1
|
| 285 |
+
current_y += patch_h
|
| 286 |
+
|
| 287 |
+
return result_image
|
data/video_utils.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 OpenGVLab
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: MIT
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under MIT, with the full license text
|
| 8 |
+
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import io
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import decord
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
|
| 24 |
+
if sample in ['rand', 'middle']: # uniform sampling
|
| 25 |
+
acc_samples = min(num_frames, vlen)
|
| 26 |
+
# split the video into `acc_samples` intervals, and sample from each interval.
|
| 27 |
+
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
|
| 28 |
+
ranges = []
|
| 29 |
+
for idx, interv in enumerate(intervals[:-1]):
|
| 30 |
+
ranges.append((interv, intervals[idx + 1] - 1))
|
| 31 |
+
if sample == 'rand':
|
| 32 |
+
try:
|
| 33 |
+
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
|
| 34 |
+
except:
|
| 35 |
+
frame_indices = np.random.permutation(vlen)[:acc_samples]
|
| 36 |
+
frame_indices.sort()
|
| 37 |
+
frame_indices = list(frame_indices)
|
| 38 |
+
elif fix_start is not None:
|
| 39 |
+
frame_indices = [x[0] + fix_start for x in ranges]
|
| 40 |
+
elif sample == 'middle':
|
| 41 |
+
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
|
| 42 |
+
else:
|
| 43 |
+
raise NotImplementedError
|
| 44 |
+
|
| 45 |
+
if len(frame_indices) < num_frames: # padded with last frame
|
| 46 |
+
padded_frame_indices = [frame_indices[-1]] * num_frames
|
| 47 |
+
padded_frame_indices[:len(frame_indices)] = frame_indices
|
| 48 |
+
frame_indices = padded_frame_indices
|
| 49 |
+
elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
|
| 50 |
+
output_fps = float(sample[3:])
|
| 51 |
+
duration = float(vlen) / input_fps
|
| 52 |
+
delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
|
| 53 |
+
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
|
| 54 |
+
frame_indices = np.around(frame_seconds * input_fps).astype(int)
|
| 55 |
+
frame_indices = [e for e in frame_indices if e < vlen]
|
| 56 |
+
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
|
| 57 |
+
frame_indices = frame_indices[:max_num_frames]
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError
|
| 60 |
+
return frame_indices
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
|
| 64 |
+
video_reader = decord.VideoReader(video_path, num_threads=1)
|
| 65 |
+
vlen = len(video_reader)
|
| 66 |
+
fps = video_reader.get_avg_fps()
|
| 67 |
+
duration = vlen / float(fps)
|
| 68 |
+
if clip:
|
| 69 |
+
start, end = clip
|
| 70 |
+
duration = end - start
|
| 71 |
+
vlen = int(duration * fps)
|
| 72 |
+
start_index = int(start * fps)
|
| 73 |
+
|
| 74 |
+
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
|
| 75 |
+
|
| 76 |
+
frame_indices = get_frame_indices(
|
| 77 |
+
t_num_frames, vlen, sample=sample, fix_start=fix_start,
|
| 78 |
+
input_fps=fps
|
| 79 |
+
)
|
| 80 |
+
if clip:
|
| 81 |
+
frame_indices = [f + start_index for f in frame_indices]
|
| 82 |
+
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
|
| 83 |
+
frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
|
| 84 |
+
return frames
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def extract_frame_number(filename):
|
| 88 |
+
# Extract the numeric part from the filename using regular expressions
|
| 89 |
+
match = re.search(r'_(\d+).jpg$', filename)
|
| 90 |
+
return int(match.group(1)) if match else -1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def sort_frames(frame_paths):
|
| 94 |
+
# Extract filenames from each path and sort by their numeric part
|
| 95 |
+
return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
|
| 99 |
+
image_list = sort_frames(list(os.listdir(video_path)))
|
| 100 |
+
frames = []
|
| 101 |
+
for image in image_list:
|
| 102 |
+
fp = os.path.join(video_path, image)
|
| 103 |
+
frame = Image.open(fp).convert('RGB')
|
| 104 |
+
frames.append(frame)
|
| 105 |
+
vlen = len(frames)
|
| 106 |
+
|
| 107 |
+
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
|
| 108 |
+
|
| 109 |
+
if vlen > t_num_frames:
|
| 110 |
+
frame_indices = get_frame_indices(
|
| 111 |
+
t_num_frames, vlen, sample=sample, fix_start=fix_start
|
| 112 |
+
)
|
| 113 |
+
frames = [frames[i] for i in frame_indices]
|
| 114 |
+
return frames
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class FrameSampler:
|
| 118 |
+
def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
|
| 119 |
+
self.max_num_frames = max_num_frames
|
| 120 |
+
self.min_num_frames = min_num_frames
|
| 121 |
+
self.sample = sample
|
| 122 |
+
|
| 123 |
+
def __call__(self, file_name):
|
| 124 |
+
fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
|
| 125 |
+
frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
|
| 126 |
+
return frames
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def decode_video_byte(video_bytes):
|
| 130 |
+
video_stream = io.BytesIO(video_bytes)
|
| 131 |
+
vr = decord.VideoReader(video_stream)
|
| 132 |
+
return vr
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
|
| 136 |
+
if isinstance(mp4_p, str):
|
| 137 |
+
vr = decord.VideoReader(mp4_p, num_threads=1)
|
| 138 |
+
elif isinstance(mp4_p, decord.video_reader.VideoReader):
|
| 139 |
+
vr = mp4_p
|
| 140 |
+
video_fps = vr.get_avg_fps() # 获取视频的帧率
|
| 141 |
+
video_duration = len(vr) / video_fps
|
| 142 |
+
if n_frames is not None:
|
| 143 |
+
if random_sample:
|
| 144 |
+
frame_indices = sorted(random.sample(range(len(vr)), n_frames))
|
| 145 |
+
else:
|
| 146 |
+
frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
|
| 147 |
+
else:
|
| 148 |
+
frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
|
| 149 |
+
frames = vr.get_batch(frame_indices).asnumpy() # 转换为 numpy 数组
|
| 150 |
+
frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
|
| 151 |
+
if not return_frame_indices:
|
| 152 |
+
return frames, video_duration
|
| 153 |
+
else:
|
| 154 |
+
return frames, video_duration, frame_indices
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
|
| 158 |
+
if isinstance(mp4_p, str):
|
| 159 |
+
vr = decord.VideoReader(mp4_p, num_threads=1)
|
| 160 |
+
elif isinstance(mp4_p, decord.video_reader.VideoReader):
|
| 161 |
+
vr = mp4_p
|
| 162 |
+
# sample the frames in frame_indices
|
| 163 |
+
frames = vr.get_batch(frame_indices).asnumpy() # 转换为 numpy 数组
|
| 164 |
+
frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
|
| 165 |
+
return frames
|
data/vlm_dataset.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import traceback
|
| 7 |
+
from PIL import Image, ImageFile, PngImagePlugin
|
| 8 |
+
|
| 9 |
+
from .data_utils import pil_img2rgb
|
| 10 |
+
from .distributed_iterable_dataset import DistributedIterableDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
Image.MAX_IMAGE_PIXELS = 200000000
|
| 14 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 15 |
+
MaximumDecompressedSize = 1024
|
| 16 |
+
MegaByte = 2 ** 20
|
| 17 |
+
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SftJSONLIterableDataset(DistributedIterableDataset):
|
| 21 |
+
def __init__(
|
| 22 |
+
self, dataset_name, transform, tokenizer, frame_sampler,
|
| 23 |
+
jsonl_path_list, data_dir_list, num_used_data,
|
| 24 |
+
local_rank=0, world_size=1, num_workers=8, data_status=None,
|
| 25 |
+
shuffle_lines=False, shuffle_seed=0,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
jsonl_path_list: list of jsonl file paths
|
| 29 |
+
data_dir_list: list of image directories containing the images of each jsonl file
|
| 30 |
+
num_used_data: list of number of sampled data points for each jsonl
|
| 31 |
+
"""
|
| 32 |
+
super().__init__(dataset_name, local_rank, world_size, num_workers)
|
| 33 |
+
self.transform = transform
|
| 34 |
+
self.tokenizer = tokenizer
|
| 35 |
+
self.frame_sampler = frame_sampler
|
| 36 |
+
self.data_status = data_status
|
| 37 |
+
self.data_paths = self.get_data_paths(
|
| 38 |
+
jsonl_path_list,
|
| 39 |
+
data_dir_list,
|
| 40 |
+
num_used_data,
|
| 41 |
+
shuffle_lines,
|
| 42 |
+
shuffle_seed,
|
| 43 |
+
)
|
| 44 |
+
self.set_epoch()
|
| 45 |
+
|
| 46 |
+
def get_data_paths(
|
| 47 |
+
self,
|
| 48 |
+
jsonl_path_list,
|
| 49 |
+
data_dir_list,
|
| 50 |
+
num_used_data,
|
| 51 |
+
shuffle_lines,
|
| 52 |
+
shuffle_seed,
|
| 53 |
+
):
|
| 54 |
+
data_paths = []
|
| 55 |
+
for jsonl_path, image_dir, num_data_point in zip(
|
| 56 |
+
jsonl_path_list, data_dir_list, num_used_data
|
| 57 |
+
):
|
| 58 |
+
with open(jsonl_path, 'r') as f:
|
| 59 |
+
raw_data = f.readlines()
|
| 60 |
+
if shuffle_lines:
|
| 61 |
+
self.rng.seed(shuffle_seed)
|
| 62 |
+
self.rng.shuffle(raw_data)
|
| 63 |
+
raw_data = raw_data[:num_data_point]
|
| 64 |
+
data_paths.extend([(json_data, image_dir) for json_data in raw_data])
|
| 65 |
+
return data_paths
|
| 66 |
+
|
| 67 |
+
def change_format(self, data, num_images):
|
| 68 |
+
elements = []
|
| 69 |
+
for conversation in data['conversations']:
|
| 70 |
+
if conversation['from'] == 'human':
|
| 71 |
+
if '<image>' not in conversation['value']:
|
| 72 |
+
elements.append({
|
| 73 |
+
'type': 'text',
|
| 74 |
+
'has_loss': 0,
|
| 75 |
+
'text': conversation['value'],
|
| 76 |
+
})
|
| 77 |
+
else:
|
| 78 |
+
text_list = conversation['value'].split('<image>')
|
| 79 |
+
for idx, text in enumerate(text_list):
|
| 80 |
+
if text.strip() != '':
|
| 81 |
+
elements.append({
|
| 82 |
+
'type': 'text',
|
| 83 |
+
'has_loss': 0,
|
| 84 |
+
'text': text.strip(),
|
| 85 |
+
})
|
| 86 |
+
if (idx != len(text_list) - 1) and (idx < num_images):
|
| 87 |
+
elements.append({'type': 'image',})
|
| 88 |
+
elif conversation['from'] == 'gpt':
|
| 89 |
+
elements.append({
|
| 90 |
+
'type': 'text',
|
| 91 |
+
'has_loss': 1,
|
| 92 |
+
'text': conversation['value'],
|
| 93 |
+
})
|
| 94 |
+
return elements
|
| 95 |
+
|
| 96 |
+
def __iter__(self):
|
| 97 |
+
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
|
| 98 |
+
if self.data_status is not None:
|
| 99 |
+
row_start_id = self.data_status[worker_id] + 1
|
| 100 |
+
else:
|
| 101 |
+
row_start_id = 0
|
| 102 |
+
transform_stride = self.transform.stride
|
| 103 |
+
|
| 104 |
+
print(
|
| 105 |
+
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
|
| 106 |
+
f"resuming data at row#{row_start_id}"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
while True:
|
| 110 |
+
data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
|
| 111 |
+
for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
|
| 112 |
+
num_tokens = 0
|
| 113 |
+
image_tensor_list = []
|
| 114 |
+
text_ids_list = []
|
| 115 |
+
sequence_plan = []
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
data_item = json.loads(data)
|
| 119 |
+
raw_images = None
|
| 120 |
+
if 'image' in data_item:
|
| 121 |
+
if type(data_item['image']) == list:
|
| 122 |
+
raw_images = [
|
| 123 |
+
pil_img2rgb(Image.open(os.path.join(image_dir, image)))
|
| 124 |
+
for image in data_item['image']
|
| 125 |
+
]
|
| 126 |
+
else:
|
| 127 |
+
raw_images = [
|
| 128 |
+
pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
|
| 129 |
+
]
|
| 130 |
+
elif 'video' in data_item:
|
| 131 |
+
raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
|
| 132 |
+
special_tokens = '<image>' * len(raw_images)
|
| 133 |
+
for item in data_item['conversations']:
|
| 134 |
+
if '<video>' in item['value']:
|
| 135 |
+
item['value'] = item['value'].replace('<video>', special_tokens)
|
| 136 |
+
break
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError("Cannot find <video> in the conversation!")
|
| 139 |
+
except:
|
| 140 |
+
traceback.print_exc()
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
if raw_images:
|
| 144 |
+
for raw_image in raw_images:
|
| 145 |
+
image_tensor = self.transform(raw_image, img_num=len(raw_images))
|
| 146 |
+
image_tensor_list.append(image_tensor)
|
| 147 |
+
height, width = image_tensor.shape[1:]
|
| 148 |
+
num_tokens += width * height // transform_stride ** 2
|
| 149 |
+
|
| 150 |
+
elements = self.change_format(data_item, len(image_tensor_list))
|
| 151 |
+
|
| 152 |
+
for item in elements:
|
| 153 |
+
if item['type'] == 'text':
|
| 154 |
+
text_data = item['text']
|
| 155 |
+
text_ids = self.tokenizer.encode(text_data)
|
| 156 |
+
if len(text_ids) > 0:
|
| 157 |
+
text_ids_list.append(text_ids)
|
| 158 |
+
num_tokens += len(text_ids)
|
| 159 |
+
current_plan = {
|
| 160 |
+
'type': 'text',
|
| 161 |
+
'enable_cfg': 0,
|
| 162 |
+
'loss': item['has_loss'],
|
| 163 |
+
'special_token_loss': 0,
|
| 164 |
+
'special_token_label': None,
|
| 165 |
+
}
|
| 166 |
+
sequence_plan.append(current_plan)
|
| 167 |
+
elif item['type'] == 'image':
|
| 168 |
+
current_plan = {
|
| 169 |
+
'type': 'vit_image',
|
| 170 |
+
'enable_cfg': 0,
|
| 171 |
+
'loss': 0,
|
| 172 |
+
'special_token_loss': 0,
|
| 173 |
+
'special_token_label': None,
|
| 174 |
+
}
|
| 175 |
+
sequence_plan.append(current_plan)
|
| 176 |
+
|
| 177 |
+
has_loss = [item['loss'] for item in sequence_plan]
|
| 178 |
+
if sum(has_loss) == 0:
|
| 179 |
+
print(f'No loss defined, skipped.')
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
yield dict(
|
| 183 |
+
image_tensor_list=image_tensor_list,
|
| 184 |
+
text_ids_list=text_ids_list,
|
| 185 |
+
sequence_plan=sequence_plan,
|
| 186 |
+
num_tokens=num_tokens,
|
| 187 |
+
data_indexes={
|
| 188 |
+
"data_indexes": row_idx,
|
| 189 |
+
"worker_id": worker_id,
|
| 190 |
+
"dataset_name": self.dataset_name,
|
| 191 |
+
}
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
row_start_id = 0
|
| 195 |
+
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
|
download_model.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import snapshot_download
|
| 2 |
+
|
| 3 |
+
HF_HOME = "/mnt/wsfuse/kaiyuyue/cache/huggingface"
|
| 4 |
+
repo_id = "multimodal-reasoning-lab/Bagel-Zebra-CoT"
|
| 5 |
+
|
| 6 |
+
snapshot_download(
|
| 7 |
+
cache_dir=HF_HOME,
|
| 8 |
+
repo_id=repo_id,
|
| 9 |
+
local_dir_use_symlinks=False,
|
| 10 |
+
resume_download=True,
|
| 11 |
+
allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
|
| 12 |
+
)
|
inference.ipynb
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"tags": []
|
| 8 |
+
},
|
| 9 |
+
"outputs": [],
|
| 10 |
+
"source": [
|
| 11 |
+
"# Copyright 2025 Bytedance Ltd. and/or its affiliates.\n",
|
| 12 |
+
"# SPDX-License-Identifier: Apache-2.0"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"metadata": {
|
| 19 |
+
"tags": []
|
| 20 |
+
},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"%load_ext autoreload\n",
|
| 24 |
+
"%autoreload 2"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"metadata": {
|
| 31 |
+
"tags": []
|
| 32 |
+
},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"import os\n",
|
| 36 |
+
"from copy import deepcopy\n",
|
| 37 |
+
"from typing import (\n",
|
| 38 |
+
" Any,\n",
|
| 39 |
+
" AsyncIterable,\n",
|
| 40 |
+
" Callable,\n",
|
| 41 |
+
" Dict,\n",
|
| 42 |
+
" Generator,\n",
|
| 43 |
+
" List,\n",
|
| 44 |
+
" NamedTuple,\n",
|
| 45 |
+
" Optional,\n",
|
| 46 |
+
" Tuple,\n",
|
| 47 |
+
" Union,\n",
|
| 48 |
+
")\n",
|
| 49 |
+
"import requests\n",
|
| 50 |
+
"from io import BytesIO\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"from PIL import Image\n",
|
| 53 |
+
"import torch\n",
|
| 54 |
+
"from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"from data.transforms import ImageTransform\n",
|
| 57 |
+
"from data.data_utils import pil_img2rgb, add_special_tokens\n",
|
| 58 |
+
"from modeling.bagel import (\n",
|
| 59 |
+
" BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel\n",
|
| 60 |
+
")\n",
|
| 61 |
+
"from modeling.qwen2 import Qwen2Tokenizer\n",
|
| 62 |
+
"from modeling.bagel.qwen2_navit import NaiveCache\n",
|
| 63 |
+
"from modeling.autoencoder import load_ae\n",
|
| 64 |
+
"from safetensors.torch import load_file"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## Model Initialization"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {
|
| 78 |
+
"tags": []
|
| 79 |
+
},
|
| 80 |
+
"outputs": [],
|
| 81 |
+
"source": [
|
| 82 |
+
"model_path = \"/path/to/BAGEL-7B-MoT/weights\" # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# LLM config preparing\n",
|
| 85 |
+
"llm_config = Qwen2Config.from_json_file(os.path.join(model_path, \"llm_config.json\"))\n",
|
| 86 |
+
"llm_config.qk_norm = True\n",
|
| 87 |
+
"llm_config.tie_word_embeddings = False\n",
|
| 88 |
+
"llm_config.layer_module = \"Qwen2MoTDecoderLayer\"\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"# ViT config preparing\n",
|
| 91 |
+
"vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, \"vit_config.json\"))\n",
|
| 92 |
+
"vit_config.rope = False\n",
|
| 93 |
+
"vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"# VAE loading\n",
|
| 96 |
+
"vae_model, vae_config = load_ae(local_path=os.path.join(model_path, \"ae.safetensors\"))\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"# Bagel config preparing\n",
|
| 99 |
+
"config = BagelConfig(\n",
|
| 100 |
+
" visual_gen=True,\n",
|
| 101 |
+
" visual_und=True,\n",
|
| 102 |
+
" llm_config=llm_config, \n",
|
| 103 |
+
" vit_config=vit_config,\n",
|
| 104 |
+
" vae_config=vae_config,\n",
|
| 105 |
+
" vit_max_num_patch_per_side=70,\n",
|
| 106 |
+
" connector_act='gelu_pytorch_tanh',\n",
|
| 107 |
+
" latent_patch_size=2,\n",
|
| 108 |
+
" max_latent_size=64,\n",
|
| 109 |
+
")\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"with init_empty_weights():\n",
|
| 112 |
+
" language_model = Qwen2ForCausalLM(llm_config)\n",
|
| 113 |
+
" vit_model = SiglipVisionModel(vit_config)\n",
|
| 114 |
+
" model = Bagel(language_model, vit_model, config)\n",
|
| 115 |
+
" model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"# Tokenizer Preparing\n",
|
| 118 |
+
"tokenizer = Qwen2Tokenizer.from_pretrained(model_path)\n",
|
| 119 |
+
"tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"# Image Transform Preparing\n",
|
| 122 |
+
"vae_transform = ImageTransform(1024, 512, 16)\n",
|
| 123 |
+
"vit_transform = ImageTransform(980, 224, 14)"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "markdown",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"source": [
|
| 130 |
+
"## Model Loading and Multi GPU Infernece Preparing"
|
| 131 |
+
]
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"cell_type": "code",
|
| 135 |
+
"execution_count": null,
|
| 136 |
+
"metadata": {
|
| 137 |
+
"tags": []
|
| 138 |
+
},
|
| 139 |
+
"outputs": [],
|
| 140 |
+
"source": [
|
| 141 |
+
"max_mem_per_gpu = \"80GiB\" # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"device_map = infer_auto_device_map(\n",
|
| 144 |
+
" model,\n",
|
| 145 |
+
" max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},\n",
|
| 146 |
+
" no_split_module_classes=[\"Bagel\", \"Qwen2MoTDecoderLayer\"],\n",
|
| 147 |
+
")\n",
|
| 148 |
+
"print(device_map)\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"same_device_modules = [\n",
|
| 151 |
+
" 'language_model.model.embed_tokens',\n",
|
| 152 |
+
" 'time_embedder',\n",
|
| 153 |
+
" 'latent_pos_embed',\n",
|
| 154 |
+
" 'vae2llm',\n",
|
| 155 |
+
" 'llm2vae',\n",
|
| 156 |
+
" 'connector',\n",
|
| 157 |
+
" 'vit_pos_embed'\n",
|
| 158 |
+
"]\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"if torch.cuda.device_count() == 1:\n",
|
| 161 |
+
" first_device = device_map.get(same_device_modules[0], \"cuda:0\")\n",
|
| 162 |
+
" for k in same_device_modules:\n",
|
| 163 |
+
" if k in device_map:\n",
|
| 164 |
+
" device_map[k] = first_device\n",
|
| 165 |
+
" else:\n",
|
| 166 |
+
" device_map[k] = \"cuda:0\"\n",
|
| 167 |
+
"else:\n",
|
| 168 |
+
" first_device = device_map.get(same_device_modules[0])\n",
|
| 169 |
+
" for k in same_device_modules:\n",
|
| 170 |
+
" if k in device_map:\n",
|
| 171 |
+
" device_map[k] = first_device\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8\n",
|
| 174 |
+
"model = load_checkpoint_and_dispatch(\n",
|
| 175 |
+
" model,\n",
|
| 176 |
+
" checkpoint=os.path.join(model_path, \"ema.safetensors\"),\n",
|
| 177 |
+
" device_map=device_map,\n",
|
| 178 |
+
" offload_buffers=True,\n",
|
| 179 |
+
" dtype=torch.bfloat16,\n",
|
| 180 |
+
" force_hooks=True,\n",
|
| 181 |
+
" offload_folder=\"/tmp/offload\"\n",
|
| 182 |
+
")\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"model = model.eval()\n",
|
| 185 |
+
"print('Model loaded')"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": []
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "markdown",
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"source": [
|
| 199 |
+
"## Inferencer Preparing "
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"metadata": {
|
| 206 |
+
"tags": []
|
| 207 |
+
},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"from inferencer import InterleaveInferencer\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"inferencer = InterleaveInferencer(\n",
|
| 213 |
+
" model=model, \n",
|
| 214 |
+
" vae_model=vae_model, \n",
|
| 215 |
+
" tokenizer=tokenizer, \n",
|
| 216 |
+
" vae_transform=vae_transform, \n",
|
| 217 |
+
" vit_transform=vit_transform, \n",
|
| 218 |
+
" new_token_ids=new_token_ids\n",
|
| 219 |
+
")"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "code",
|
| 224 |
+
"execution_count": null,
|
| 225 |
+
"metadata": {
|
| 226 |
+
"tags": []
|
| 227 |
+
},
|
| 228 |
+
"outputs": [],
|
| 229 |
+
"source": [
|
| 230 |
+
"import random\n",
|
| 231 |
+
"import numpy as np\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"seed = 42\n",
|
| 234 |
+
"random.seed(seed)\n",
|
| 235 |
+
"np.random.seed(seed)\n",
|
| 236 |
+
"torch.manual_seed(seed)\n",
|
| 237 |
+
"if torch.cuda.is_available():\n",
|
| 238 |
+
" torch.cuda.manual_seed(seed)\n",
|
| 239 |
+
" torch.cuda.manual_seed_all(seed)\n",
|
| 240 |
+
"torch.backends.cudnn.deterministic = True\n",
|
| 241 |
+
"torch.backends.cudnn.benchmark = False"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"cell_type": "markdown",
|
| 246 |
+
"metadata": {},
|
| 247 |
+
"source": [
|
| 248 |
+
"**About Inference Hyperparameters:**\n",
|
| 249 |
+
"- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.\n",
|
| 250 |
+
"- **`cfg_image_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.\n",
|
| 251 |
+
"- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.\n",
|
| 252 |
+
"- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).\n",
|
| 253 |
+
"- **`num_timesteps`:** Total denoising steps. Typical: `50`.\n",
|
| 254 |
+
"- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. Typical: `0`.\n",
|
| 255 |
+
"- **`cfg_renorm_type`:** CFG-Renorm method: \n",
|
| 256 |
+
" - `global`: Normalize over all tokens and channels (default for T2I).\n",
|
| 257 |
+
" - `channel`: Normalize across channels for each token.\n",
|
| 258 |
+
" - `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).\n",
|
| 259 |
+
"- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min` or decrease `cfg_scale`.**\n"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "markdown",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"source": [
|
| 266 |
+
"## Image Generation"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"cell_type": "code",
|
| 271 |
+
"execution_count": null,
|
| 272 |
+
"metadata": {
|
| 273 |
+
"tags": []
|
| 274 |
+
},
|
| 275 |
+
"outputs": [],
|
| 276 |
+
"source": [
|
| 277 |
+
"inference_hyper=dict(\n",
|
| 278 |
+
" cfg_text_scale=4.0,\n",
|
| 279 |
+
" cfg_img_scale=1.0,\n",
|
| 280 |
+
" cfg_interval=[0.4, 1.0],\n",
|
| 281 |
+
" timestep_shift=3.0,\n",
|
| 282 |
+
" num_timesteps=50,\n",
|
| 283 |
+
" cfg_renorm_min=0.0,\n",
|
| 284 |
+
" cfg_renorm_type=\"global\",\n",
|
| 285 |
+
")"
|
| 286 |
+
]
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
"cell_type": "code",
|
| 290 |
+
"execution_count": null,
|
| 291 |
+
"metadata": {
|
| 292 |
+
"tags": []
|
| 293 |
+
},
|
| 294 |
+
"outputs": [],
|
| 295 |
+
"source": [
|
| 296 |
+
"prompt = \"A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.\"\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"print(prompt)\n",
|
| 299 |
+
"print('-' * 10)\n",
|
| 300 |
+
"output_dict = inferencer(text=prompt, **inference_hyper)\n",
|
| 301 |
+
"display(output_dict['image'])"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"cell_type": "markdown",
|
| 306 |
+
"metadata": {
|
| 307 |
+
"tags": []
|
| 308 |
+
},
|
| 309 |
+
"source": [
|
| 310 |
+
"## Image Generation with Think"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "code",
|
| 315 |
+
"execution_count": null,
|
| 316 |
+
"metadata": {
|
| 317 |
+
"tags": []
|
| 318 |
+
},
|
| 319 |
+
"outputs": [],
|
| 320 |
+
"source": [
|
| 321 |
+
"inference_hyper=dict(\n",
|
| 322 |
+
" max_think_token_n=1000,\n",
|
| 323 |
+
" do_sample=False,\n",
|
| 324 |
+
" # text_temperature=0.3,\n",
|
| 325 |
+
" cfg_text_scale=4.0,\n",
|
| 326 |
+
" cfg_img_scale=1.0,\n",
|
| 327 |
+
" cfg_interval=[0.4, 1.0],\n",
|
| 328 |
+
" timestep_shift=3.0,\n",
|
| 329 |
+
" num_timesteps=50,\n",
|
| 330 |
+
" cfg_renorm_min=0.0,\n",
|
| 331 |
+
" cfg_renorm_type=\"global\",\n",
|
| 332 |
+
")"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"cell_type": "code",
|
| 337 |
+
"execution_count": null,
|
| 338 |
+
"metadata": {
|
| 339 |
+
"tags": []
|
| 340 |
+
},
|
| 341 |
+
"outputs": [],
|
| 342 |
+
"source": [
|
| 343 |
+
"prompt = 'a car made of small cars'\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"print(prompt)\n",
|
| 346 |
+
"print('-' * 10)\n",
|
| 347 |
+
"output_dict = inferencer(text=prompt, think=True, **inference_hyper)\n",
|
| 348 |
+
"print(output_dict['text'])\n",
|
| 349 |
+
"display(output_dict['image'])"
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "code",
|
| 354 |
+
"execution_count": null,
|
| 355 |
+
"metadata": {},
|
| 356 |
+
"outputs": [],
|
| 357 |
+
"source": []
|
| 358 |
+
},
|
| 359 |
+
{
|
| 360 |
+
"cell_type": "markdown",
|
| 361 |
+
"metadata": {},
|
| 362 |
+
"source": [
|
| 363 |
+
"## Editing"
|
| 364 |
+
]
|
| 365 |
+
},
|
| 366 |
+
{
|
| 367 |
+
"cell_type": "code",
|
| 368 |
+
"execution_count": null,
|
| 369 |
+
"metadata": {
|
| 370 |
+
"tags": []
|
| 371 |
+
},
|
| 372 |
+
"outputs": [],
|
| 373 |
+
"source": [
|
| 374 |
+
"inference_hyper=dict(\n",
|
| 375 |
+
" cfg_text_scale=4.0,\n",
|
| 376 |
+
" cfg_img_scale=2.0,\n",
|
| 377 |
+
" cfg_interval=[0.0, 1.0],\n",
|
| 378 |
+
" timestep_shift=3.0,\n",
|
| 379 |
+
" num_timesteps=50,\n",
|
| 380 |
+
" cfg_renorm_min=0.0,\n",
|
| 381 |
+
" cfg_renorm_type=\"text_channel\",\n",
|
| 382 |
+
")"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": null,
|
| 388 |
+
"metadata": {
|
| 389 |
+
"tags": []
|
| 390 |
+
},
|
| 391 |
+
"outputs": [],
|
| 392 |
+
"source": [
|
| 393 |
+
"image = Image.open('test_images/women.jpg')\n",
|
| 394 |
+
"prompt = 'She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.'\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"display(image)\n",
|
| 397 |
+
"print(prompt)\n",
|
| 398 |
+
"print('-'*10)\n",
|
| 399 |
+
"output_dict = inferencer(image=image, text=prompt, **inference_hyper)\n",
|
| 400 |
+
"display(output_dict['image'])"
|
| 401 |
+
]
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"cell_type": "code",
|
| 405 |
+
"execution_count": null,
|
| 406 |
+
"metadata": {},
|
| 407 |
+
"outputs": [],
|
| 408 |
+
"source": []
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"cell_type": "markdown",
|
| 412 |
+
"metadata": {},
|
| 413 |
+
"source": [
|
| 414 |
+
"## Edit with Think"
|
| 415 |
+
]
|
| 416 |
+
},
|
| 417 |
+
{
|
| 418 |
+
"cell_type": "code",
|
| 419 |
+
"execution_count": null,
|
| 420 |
+
"metadata": {
|
| 421 |
+
"tags": []
|
| 422 |
+
},
|
| 423 |
+
"outputs": [],
|
| 424 |
+
"source": [
|
| 425 |
+
"inference_hyper=dict(\n",
|
| 426 |
+
" max_think_token_n=1000,\n",
|
| 427 |
+
" do_sample=False,\n",
|
| 428 |
+
" # text_temperature=0.3,\n",
|
| 429 |
+
" cfg_text_scale=4.0,\n",
|
| 430 |
+
" cfg_img_scale=2.0,\n",
|
| 431 |
+
" cfg_interval=[0.0, 1.0],\n",
|
| 432 |
+
" timestep_shift=3.0,\n",
|
| 433 |
+
" num_timesteps=50,\n",
|
| 434 |
+
" cfg_renorm_min=0.0,\n",
|
| 435 |
+
" cfg_renorm_type=\"text_channel\",\n",
|
| 436 |
+
")"
|
| 437 |
+
]
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"cell_type": "code",
|
| 441 |
+
"execution_count": null,
|
| 442 |
+
"metadata": {
|
| 443 |
+
"tags": []
|
| 444 |
+
},
|
| 445 |
+
"outputs": [],
|
| 446 |
+
"source": [
|
| 447 |
+
"image = Image.open('test_images/octupusy.jpg')\n",
|
| 448 |
+
"prompt = 'Could you display the sculpture that takes after this design?'\n",
|
| 449 |
+
"\n",
|
| 450 |
+
"display(image)\n",
|
| 451 |
+
"print('-'*10)\n",
|
| 452 |
+
"output_dict = inferencer(image=image, text=prompt, think=True, **inference_hyper)\n",
|
| 453 |
+
"print(output_dict['text'])\n",
|
| 454 |
+
"display(output_dict['image'])"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"cell_type": "code",
|
| 459 |
+
"execution_count": null,
|
| 460 |
+
"metadata": {},
|
| 461 |
+
"outputs": [],
|
| 462 |
+
"source": []
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"cell_type": "markdown",
|
| 466 |
+
"metadata": {},
|
| 467 |
+
"source": [
|
| 468 |
+
"## Understanding"
|
| 469 |
+
]
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"cell_type": "code",
|
| 473 |
+
"execution_count": null,
|
| 474 |
+
"metadata": {
|
| 475 |
+
"tags": []
|
| 476 |
+
},
|
| 477 |
+
"outputs": [],
|
| 478 |
+
"source": [
|
| 479 |
+
"inference_hyper=dict(\n",
|
| 480 |
+
" max_think_token_n=1000,\n",
|
| 481 |
+
" do_sample=False,\n",
|
| 482 |
+
" # text_temperature=0.3,\n",
|
| 483 |
+
")"
|
| 484 |
+
]
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"cell_type": "code",
|
| 488 |
+
"execution_count": null,
|
| 489 |
+
"metadata": {
|
| 490 |
+
"tags": []
|
| 491 |
+
},
|
| 492 |
+
"outputs": [],
|
| 493 |
+
"source": [
|
| 494 |
+
"image = Image.open('test_images/meme.jpg')\n",
|
| 495 |
+
"prompt = \"Can someone explain what’s funny about this meme??\"\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"display(image)\n",
|
| 498 |
+
"print(prompt)\n",
|
| 499 |
+
"print('-'*10)\n",
|
| 500 |
+
"output_dict = inferencer(image=image, text=prompt, understanding_output=True, **inference_hyper)\n",
|
| 501 |
+
"print(output_dict['text'])"
|
| 502 |
+
]
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"cell_type": "code",
|
| 506 |
+
"execution_count": null,
|
| 507 |
+
"metadata": {},
|
| 508 |
+
"outputs": [],
|
| 509 |
+
"source": []
|
| 510 |
+
}
|
| 511 |
+
],
|
| 512 |
+
"metadata": {
|
| 513 |
+
"fileId": "1bfaa82d-51b0-4c13-9e4c-295ba28bcd8a",
|
| 514 |
+
"filePath": "/mnt/bn/seed-aws-va/chaorui/code/cdt-hf/notebooks/chat.ipynb",
|
| 515 |
+
"kernelspec": {
|
| 516 |
+
"display_name": "Python 3 (ipykernel)",
|
| 517 |
+
"language": "python",
|
| 518 |
+
"name": "python3"
|
| 519 |
+
},
|
| 520 |
+
"language_info": {
|
| 521 |
+
"codemirror_mode": {
|
| 522 |
+
"name": "ipython",
|
| 523 |
+
"version": 3
|
| 524 |
+
},
|
| 525 |
+
"file_extension": ".py",
|
| 526 |
+
"mimetype": "text/x-python",
|
| 527 |
+
"name": "python",
|
| 528 |
+
"nbconvert_exporter": "python",
|
| 529 |
+
"pygments_lexer": "ipython3",
|
| 530 |
+
"version": "3.11.2"
|
| 531 |
+
}
|
| 532 |
+
},
|
| 533 |
+
"nbformat": 4,
|
| 534 |
+
"nbformat_minor": 4
|
| 535 |
+
}
|
inferencer.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from typing import List, Dict, Optional, Union, Any
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from data.data_utils import pil_img2rgb
|
| 11 |
+
from modeling.bagel.qwen2_navit import NaiveCache
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
VLM_THINK_SYSTEM_PROMPT = '''Generation Instructions: You should first think about the reasoning process in the mind and then provide the user with the answer.
|
| 16 |
+
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
|
| 17 |
+
|
| 18 |
+
GEN_THINK_SYSTEM_PROMPT = '''Generation Instructions: You should first think about the planning process in the mind and then generate the image.
|
| 19 |
+
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class InterleaveInferencer:
|
| 23 |
+
def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
|
| 24 |
+
self.model = model
|
| 25 |
+
self.vae_model = vae_model
|
| 26 |
+
self.tokenizer = tokenizer
|
| 27 |
+
self.vae_transform = vae_transform
|
| 28 |
+
self.vit_transform = vit_transform
|
| 29 |
+
self.new_token_ids = new_token_ids
|
| 30 |
+
|
| 31 |
+
def init_gen_context(self):
|
| 32 |
+
gen_context = {
|
| 33 |
+
'kv_lens': [0],
|
| 34 |
+
'ropes': [0],
|
| 35 |
+
'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
|
| 36 |
+
}
|
| 37 |
+
return gen_context
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def update_context_text(self, text, gen_context):
|
| 41 |
+
# used for interleave data, currently only support 1 data inference,
|
| 42 |
+
|
| 43 |
+
past_key_values = gen_context['past_key_values']
|
| 44 |
+
kv_lens = gen_context['kv_lens']
|
| 45 |
+
ropes = gen_context['ropes']
|
| 46 |
+
generation_input, kv_lens, ropes = self.model.prepare_prompts(
|
| 47 |
+
curr_kvlens=kv_lens,
|
| 48 |
+
curr_rope=ropes,
|
| 49 |
+
prompts=[text],
|
| 50 |
+
tokenizer=self.tokenizer,
|
| 51 |
+
new_token_ids=self.new_token_ids,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
|
| 55 |
+
gen_context['kv_lens'] = kv_lens
|
| 56 |
+
gen_context['ropes'] = ropes
|
| 57 |
+
gen_context['past_key_values'] = past_key_values
|
| 58 |
+
|
| 59 |
+
return gen_context
|
| 60 |
+
|
| 61 |
+
@torch.no_grad()
|
| 62 |
+
def update_context_image(self, image, gen_context, vae=True, vit=True):
|
| 63 |
+
# used for interleave data, currently only support 1 data inference,
|
| 64 |
+
|
| 65 |
+
assert vae or vit
|
| 66 |
+
past_key_values = gen_context['past_key_values']
|
| 67 |
+
kv_lens = gen_context['kv_lens']
|
| 68 |
+
ropes = gen_context['ropes']
|
| 69 |
+
|
| 70 |
+
if vae:
|
| 71 |
+
## update vae
|
| 72 |
+
generation_input, kv_lens, ropes = self.model.prepare_vae_images(
|
| 73 |
+
curr_kvlens=kv_lens,
|
| 74 |
+
curr_rope=ropes,
|
| 75 |
+
images=[image],
|
| 76 |
+
transforms=self.vae_transform,
|
| 77 |
+
new_token_ids=self.new_token_ids,
|
| 78 |
+
)
|
| 79 |
+
past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
|
| 80 |
+
|
| 81 |
+
if vit:
|
| 82 |
+
## update vit
|
| 83 |
+
generation_input, kv_lens, ropes = self.model.prepare_vit_images(
|
| 84 |
+
curr_kvlens=kv_lens,
|
| 85 |
+
curr_rope=ropes,
|
| 86 |
+
images=[image],
|
| 87 |
+
transforms=self.vit_transform,
|
| 88 |
+
new_token_ids=self.new_token_ids,
|
| 89 |
+
)
|
| 90 |
+
past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
|
| 91 |
+
|
| 92 |
+
gen_context['kv_lens'] = kv_lens
|
| 93 |
+
gen_context['ropes'] = ropes
|
| 94 |
+
gen_context['past_key_values'] = past_key_values
|
| 95 |
+
|
| 96 |
+
return gen_context
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def gen_image(
|
| 100 |
+
self,
|
| 101 |
+
image_shape,
|
| 102 |
+
gen_context,
|
| 103 |
+
cfg_text_scale=4.0,
|
| 104 |
+
cfg_img_scale=1.5,
|
| 105 |
+
|
| 106 |
+
cfg_text_precontext=None,
|
| 107 |
+
cfg_img_precontext=None,
|
| 108 |
+
cfg_interval=(0.4, 1.0),
|
| 109 |
+
cfg_renorm_min=0.0,
|
| 110 |
+
cfg_renorm_type="global",
|
| 111 |
+
|
| 112 |
+
num_timesteps=50,
|
| 113 |
+
timestep_shift=3.0
|
| 114 |
+
):
|
| 115 |
+
# print(cfg_renorm_type)
|
| 116 |
+
past_key_values = gen_context['past_key_values']
|
| 117 |
+
kv_lens = gen_context['kv_lens']
|
| 118 |
+
ropes = gen_context['ropes']
|
| 119 |
+
generation_input = self.model.prepare_vae_latent(
|
| 120 |
+
curr_kvlens=kv_lens,
|
| 121 |
+
curr_rope=ropes,
|
| 122 |
+
image_sizes=[image_shape],
|
| 123 |
+
new_token_ids=self.new_token_ids,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# text cfg
|
| 127 |
+
cfg_text_past_key_values = cfg_text_precontext['past_key_values']
|
| 128 |
+
kv_lens_cfg = cfg_text_precontext['kv_lens']
|
| 129 |
+
ropes_cfg = cfg_text_precontext['ropes']
|
| 130 |
+
generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
|
| 131 |
+
curr_kvlens=kv_lens_cfg,
|
| 132 |
+
curr_rope=ropes_cfg,
|
| 133 |
+
image_sizes=[image_shape],
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# img cfg
|
| 137 |
+
cfg_img_past_key_values = cfg_img_precontext['past_key_values']
|
| 138 |
+
kv_lens_cfg = cfg_img_precontext['kv_lens']
|
| 139 |
+
ropes_cfg = cfg_img_precontext['ropes']
|
| 140 |
+
generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
|
| 141 |
+
curr_kvlens=kv_lens_cfg,
|
| 142 |
+
curr_rope=ropes_cfg,
|
| 143 |
+
image_sizes=[image_shape],
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
unpacked_latent = self.model.generate_image(
|
| 147 |
+
past_key_values=past_key_values,
|
| 148 |
+
cfg_text_past_key_values=cfg_text_past_key_values,
|
| 149 |
+
cfg_img_past_key_values=cfg_img_past_key_values,
|
| 150 |
+
num_timesteps=num_timesteps,
|
| 151 |
+
cfg_text_scale=cfg_text_scale,
|
| 152 |
+
cfg_img_scale=cfg_img_scale,
|
| 153 |
+
cfg_interval=cfg_interval,
|
| 154 |
+
cfg_renorm_min=cfg_renorm_min,
|
| 155 |
+
cfg_renorm_type=cfg_renorm_type,
|
| 156 |
+
timestep_shift=timestep_shift,
|
| 157 |
+
**generation_input,
|
| 158 |
+
cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
|
| 159 |
+
cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
|
| 160 |
+
cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
|
| 161 |
+
cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
|
| 162 |
+
cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
|
| 163 |
+
cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
|
| 164 |
+
cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
|
| 165 |
+
cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
image = self.decode_image(unpacked_latent[0], image_shape)
|
| 169 |
+
return image
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def decode_image(self, latent, image_shape):
|
| 173 |
+
H, W = image_shape
|
| 174 |
+
h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
|
| 175 |
+
|
| 176 |
+
latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
|
| 177 |
+
latent = torch.einsum("nhwpqc->nchpwq", latent)
|
| 178 |
+
latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
|
| 179 |
+
image = self.vae_model.decode(latent)
|
| 180 |
+
image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
|
| 181 |
+
image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
|
| 182 |
+
|
| 183 |
+
return image
|
| 184 |
+
|
| 185 |
+
@torch.no_grad()
|
| 186 |
+
def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
|
| 187 |
+
gen_context = deepcopy(gen_context)
|
| 188 |
+
past_key_values = gen_context['past_key_values']
|
| 189 |
+
kv_lens = gen_context['kv_lens']
|
| 190 |
+
ropes = gen_context['ropes']
|
| 191 |
+
|
| 192 |
+
generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
|
| 193 |
+
unpacked_latent = self.model.generate_text(
|
| 194 |
+
past_key_values=past_key_values,
|
| 195 |
+
max_length=max_length,
|
| 196 |
+
do_sample=do_sample,
|
| 197 |
+
temperature=temperature,
|
| 198 |
+
end_token_id=self.new_token_ids['eos_token_id'],
|
| 199 |
+
# end_token_id=151652,
|
| 200 |
+
**generation_input,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
output = self.tokenizer.decode(unpacked_latent[:,0])
|
| 204 |
+
return output
|
| 205 |
+
|
| 206 |
+
@torch.no_grad()
|
| 207 |
+
def interleave_inference(
|
| 208 |
+
self,
|
| 209 |
+
input_lists: List[Union[str, Image.Image]],
|
| 210 |
+
understanding_output=False,
|
| 211 |
+
system_prompt=None,
|
| 212 |
+
max_think_token_n=1000,
|
| 213 |
+
do_sample=False,
|
| 214 |
+
text_temperature=0.3,
|
| 215 |
+
cfg_text_scale=3.0,
|
| 216 |
+
cfg_img_scale=1.5,
|
| 217 |
+
cfg_interval=[0.4, 1.0],
|
| 218 |
+
timestep_shift=3.0,
|
| 219 |
+
num_timesteps=50,
|
| 220 |
+
cfg_renorm_min=0.0,
|
| 221 |
+
cfg_renorm_type="global",
|
| 222 |
+
image_shapes=(1024, 1024),
|
| 223 |
+
) -> List[Union[str, Image.Image]]:
|
| 224 |
+
|
| 225 |
+
output_list = []
|
| 226 |
+
gen_context = self.init_gen_context()
|
| 227 |
+
cfg_text_context = deepcopy(gen_context)
|
| 228 |
+
cfg_img_context = deepcopy(gen_context)
|
| 229 |
+
|
| 230 |
+
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
|
| 231 |
+
if system_prompt:
|
| 232 |
+
gen_context = self.update_context_text(system_prompt, gen_context)
|
| 233 |
+
cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
|
| 234 |
+
|
| 235 |
+
for input_term in input_lists:
|
| 236 |
+
if isinstance(input_term, str):
|
| 237 |
+
cfg_text_context = deepcopy(gen_context)
|
| 238 |
+
gen_context = self.update_context_text(input_term, gen_context)
|
| 239 |
+
cfg_img_context = self.update_context_text(input_term, cfg_img_context)
|
| 240 |
+
|
| 241 |
+
elif isinstance(input_term, Image.Image):
|
| 242 |
+
input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
|
| 243 |
+
gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
|
| 244 |
+
|
| 245 |
+
image_shapes = input_term.size[::-1]
|
| 246 |
+
cfg_text_context = deepcopy(gen_context)
|
| 247 |
+
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError(f"Unsupported input type: {type(input_term)}")
|
| 250 |
+
|
| 251 |
+
if understanding_output:
|
| 252 |
+
gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
|
| 253 |
+
output_list.append(gen_text)
|
| 254 |
+
|
| 255 |
+
else:
|
| 256 |
+
img = self.gen_image(
|
| 257 |
+
image_shapes,
|
| 258 |
+
gen_context,
|
| 259 |
+
cfg_text_precontext=cfg_text_context,
|
| 260 |
+
cfg_img_precontext=cfg_img_context,
|
| 261 |
+
|
| 262 |
+
cfg_text_scale=cfg_text_scale,
|
| 263 |
+
cfg_img_scale=cfg_img_scale,
|
| 264 |
+
cfg_interval=cfg_interval,
|
| 265 |
+
timestep_shift=timestep_shift,
|
| 266 |
+
num_timesteps=num_timesteps,
|
| 267 |
+
cfg_renorm_min=cfg_renorm_min,
|
| 268 |
+
cfg_renorm_type=cfg_renorm_type,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
output_list.append(img)
|
| 272 |
+
|
| 273 |
+
return output_list
|
| 274 |
+
|
| 275 |
+
def __call__(
|
| 276 |
+
self,
|
| 277 |
+
image: Optional[Image.Image] = None,
|
| 278 |
+
text: Optional[str] = None,
|
| 279 |
+
**kargs
|
| 280 |
+
) -> Dict[str, Any]:
|
| 281 |
+
output_dict = {'image': None, 'text': None}
|
| 282 |
+
|
| 283 |
+
if image is None and text is None:
|
| 284 |
+
print('Please provide at least one input: either an image or text.')
|
| 285 |
+
return output_dict
|
| 286 |
+
|
| 287 |
+
input_list = []
|
| 288 |
+
if image is not None:
|
| 289 |
+
input_list.append(image)
|
| 290 |
+
if text is not None:
|
| 291 |
+
input_list.append(text)
|
| 292 |
+
|
| 293 |
+
output_list = self.interleave_inference(input_list, **kargs)
|
| 294 |
+
|
| 295 |
+
for i in output_list:
|
| 296 |
+
if isinstance(i, Image.Image):
|
| 297 |
+
output_dict['image'] = i
|
| 298 |
+
elif isinstance(i, str):
|
| 299 |
+
output_dict['text'] = i
|
| 300 |
+
return output_dict
|
infz_bf16.py
ADDED
|
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from typing import (
|
| 7 |
+
Any,
|
| 8 |
+
AsyncIterable,
|
| 9 |
+
Callable,
|
| 10 |
+
Dict,
|
| 11 |
+
Generator,
|
| 12 |
+
List,
|
| 13 |
+
NamedTuple,
|
| 14 |
+
Optional,
|
| 15 |
+
Tuple,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
+
import requests
|
| 19 |
+
from io import BytesIO
|
| 20 |
+
|
| 21 |
+
from PIL import Image
|
| 22 |
+
import torch
|
| 23 |
+
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
|
| 24 |
+
|
| 25 |
+
from data.transforms import ImageTransform
|
| 26 |
+
from data.data_utils import pil_img2rgb, add_special_tokens
|
| 27 |
+
from modeling.bagel import (
|
| 28 |
+
BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
|
| 29 |
+
)
|
| 30 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 31 |
+
from modeling.bagel.qwen2_navit import NaiveCache
|
| 32 |
+
from modeling.autoencoder import load_ae
|
| 33 |
+
|
| 34 |
+
# Set paths for your trained checkpoint
|
| 35 |
+
# checkpoint_dir = "/scratch/by2593/merged_checkpoint_final"
|
| 36 |
+
origin_checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"
|
| 37 |
+
checkpoint_dir = "/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_questionimage/0000150"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2_test/000010'
|
| 41 |
+
checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2/000150'
|
| 42 |
+
checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_v1_final/0000500'
|
| 43 |
+
checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"
|
| 44 |
+
|
| 45 |
+
checkpoint_file = "model.safetensors"
|
| 46 |
+
# checkpoint_file = "model_bf16.safetensors"
|
| 47 |
+
|
| 48 |
+
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
|
| 49 |
+
checkpoint_path = "/scratch/by2593/Bagel-Zebra-CoT-origin/results/checkpoints_smm_semantic_part1_v1_origin/0000050/model.safetensors"
|
| 50 |
+
|
| 51 |
+
print(f"Available GPUs: {torch.cuda.device_count()}")
|
| 52 |
+
print(f"GPU memory per device:")
|
| 53 |
+
for i in range(torch.cuda.device_count()):
|
| 54 |
+
props = torch.cuda.get_device_properties(i)
|
| 55 |
+
print(f" GPU {i}: {props.name}, {props.total_memory / 1e9:.1f} GB")
|
| 56 |
+
|
| 57 |
+
# LLM config preparing (use base model configs)
|
| 58 |
+
llm_config = Qwen2Config.from_json_file(os.path.join(checkpoint_dir, "llm_config.json"))
|
| 59 |
+
llm_config.qk_norm = True
|
| 60 |
+
llm_config.tie_word_embeddings = False
|
| 61 |
+
llm_config.layer_module = "Qwen2MoTDecoderLayer"
|
| 62 |
+
|
| 63 |
+
# ViT config preparing (use base model configs)
|
| 64 |
+
vit_config = SiglipVisionConfig.from_json_file(os.path.join(checkpoint_dir, "vit_config.json"))
|
| 65 |
+
vit_config.rope = False
|
| 66 |
+
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
|
| 67 |
+
|
| 68 |
+
# VAE loading (use base model VAE)
|
| 69 |
+
vae_model, vae_config = load_ae(local_path=os.path.join(origin_checkpoint_dir, "ae.safetensors"))
|
| 70 |
+
|
| 71 |
+
# Bagel config preparing
|
| 72 |
+
config = BagelConfig(
|
| 73 |
+
visual_gen=True,
|
| 74 |
+
visual_und=True,
|
| 75 |
+
llm_config=llm_config,
|
| 76 |
+
vit_config=vit_config,
|
| 77 |
+
vae_config=vae_config,
|
| 78 |
+
vit_max_num_patch_per_side=70,
|
| 79 |
+
connector_act='gelu_pytorch_tanh',
|
| 80 |
+
latent_patch_size=2,
|
| 81 |
+
max_latent_size=64,## 默认64,改为实际的latent尺寸
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Import the position embedding function first
|
| 85 |
+
from modeling.bagel.modeling_utils import get_2d_sincos_pos_embed
|
| 86 |
+
|
| 87 |
+
# Create model with empty weights
|
| 88 |
+
with init_empty_weights():
|
| 89 |
+
language_model = Qwen2ForCausalLM(llm_config)
|
| 90 |
+
vit_model = SiglipVisionModel(vit_config)
|
| 91 |
+
model = Bagel(language_model, vit_model, config)
|
| 92 |
+
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
|
| 93 |
+
|
| 94 |
+
# Initialize position embeddings with proper values BEFORE loading checkpoint
|
| 95 |
+
print("Initializing position embeddings before loading...")
|
| 96 |
+
|
| 97 |
+
# Initialize latent_pos_embed if it exists
|
| 98 |
+
if hasattr(model, 'latent_pos_embed'):
|
| 99 |
+
print("Initializing latent_pos_embed...")
|
| 100 |
+
pos_embed = get_2d_sincos_pos_embed(model.latent_pos_embed.hidden_size, model.latent_pos_embed.max_num_patch_per_side)
|
| 101 |
+
# Create parameter with actual values, not meta
|
| 102 |
+
model.latent_pos_embed.pos_embed = torch.nn.Parameter(
|
| 103 |
+
torch.from_numpy(pos_embed).float(), requires_grad=False
|
| 104 |
+
)
|
| 105 |
+
print(f"latent_pos_embed initialized with shape {model.latent_pos_embed.pos_embed.shape}")
|
| 106 |
+
|
| 107 |
+
# Initialize vit_pos_embed if it exists
|
| 108 |
+
if hasattr(model, 'vit_pos_embed'):
|
| 109 |
+
print("Initializing vit_pos_embed...")
|
| 110 |
+
pos_embed = get_2d_sincos_pos_embed(model.vit_pos_embed.hidden_size, model.vit_pos_embed.max_num_patch_per_side)
|
| 111 |
+
# Create parameter with actual values, not meta
|
| 112 |
+
model.vit_pos_embed.pos_embed = torch.nn.Parameter(
|
| 113 |
+
torch.from_numpy(pos_embed).float(), requires_grad=False
|
| 114 |
+
)
|
| 115 |
+
print(f"vit_pos_embed initialized with shape {model.vit_pos_embed.pos_embed.shape}")
|
| 116 |
+
|
| 117 |
+
print("Position embeddings initialized successfully")
|
| 118 |
+
|
| 119 |
+
# Tokenizer Preparing (use base model tokenizer)
|
| 120 |
+
tokenizer = Qwen2Tokenizer.from_pretrained(checkpoint_dir)
|
| 121 |
+
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
|
| 122 |
+
|
| 123 |
+
# Image Transform Preparing
|
| 124 |
+
vae_transform = ImageTransform(1024, 512, 16)
|
| 125 |
+
vit_transform = ImageTransform(980, 512, 14)
|
| 126 |
+
|
| 127 |
+
# Device mapping for 8x80GB GPUs - use bf16 directly
|
| 128 |
+
max_mem_per_gpu = "80GiB"
|
| 129 |
+
|
| 130 |
+
print("Setting up device mapping...")
|
| 131 |
+
device_map = infer_auto_device_map(
|
| 132 |
+
model,
|
| 133 |
+
max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
|
| 134 |
+
no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
|
| 135 |
+
dtype=torch.bfloat16, # Use bf16 for device mapping
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
print("Device map:", device_map)
|
| 139 |
+
|
| 140 |
+
# Handle same-device modules
|
| 141 |
+
same_device_modules = [
|
| 142 |
+
'language_model.model.embed_tokens',
|
| 143 |
+
'time_embedder',
|
| 144 |
+
'latent_pos_embed',
|
| 145 |
+
'vae2llm',
|
| 146 |
+
'llm2vae',
|
| 147 |
+
'connector',
|
| 148 |
+
'vit_pos_embed'
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
if torch.cuda.device_count() == 1:
|
| 152 |
+
first_device = device_map.get(same_device_modules[0], "cuda:0")
|
| 153 |
+
for k in same_device_modules:
|
| 154 |
+
if k in device_map:
|
| 155 |
+
device_map[k] = first_device
|
| 156 |
+
else:
|
| 157 |
+
device_map[k] = "cuda:0"
|
| 158 |
+
else:
|
| 159 |
+
first_device = device_map.get(same_device_modules[0])
|
| 160 |
+
if first_device is not None:
|
| 161 |
+
for k in same_device_modules:
|
| 162 |
+
if k in device_map:
|
| 163 |
+
device_map[k] = first_device
|
| 164 |
+
|
| 165 |
+
print("Final device map:", device_map)
|
| 166 |
+
|
| 167 |
+
# Load checkpoint directly in bf16
|
| 168 |
+
print(f"Loading checkpoint directly in bfloat16: {checkpoint_path}")
|
| 169 |
+
print("Loading model from safetensors file...")
|
| 170 |
+
|
| 171 |
+
# Load model directly in bf16
|
| 172 |
+
model = load_checkpoint_and_dispatch(
|
| 173 |
+
model,
|
| 174 |
+
checkpoint=checkpoint_path,
|
| 175 |
+
device_map=device_map,
|
| 176 |
+
offload_buffers=False,
|
| 177 |
+
dtype=torch.bfloat16, # Load directly as bf16
|
| 178 |
+
force_hooks=True,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
model = model.eval()
|
| 182 |
+
|
| 183 |
+
print('Model loaded directly in bfloat16!')
|
| 184 |
+
print(f"Model dtype: {next(model.parameters()).dtype}")
|
| 185 |
+
|
| 186 |
+
# Position embeddings were already initialized before model loading
|
| 187 |
+
print("Position embeddings were pre-initialized before loading checkpoint")
|
| 188 |
+
|
| 189 |
+
print("Model loading completed successfully!")
|
| 190 |
+
|
| 191 |
+
# Check memory usage
|
| 192 |
+
print("GPU memory usage after loading:")
|
| 193 |
+
for i in range(torch.cuda.device_count()):
|
| 194 |
+
if torch.cuda.memory_allocated(i) > 0:
|
| 195 |
+
allocated = torch.cuda.memory_allocated(i) / 1e9
|
| 196 |
+
cached = torch.cuda.memory_reserved(i) / 1e9
|
| 197 |
+
print(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
|
| 198 |
+
|
| 199 |
+
# Rest of inference code
|
| 200 |
+
from inferencer import InterleaveInferencer
|
| 201 |
+
|
| 202 |
+
inferencer = InterleaveInferencer(
|
| 203 |
+
model=model,
|
| 204 |
+
vae_model=vae_model,
|
| 205 |
+
tokenizer=tokenizer,
|
| 206 |
+
vae_transform=vae_transform,
|
| 207 |
+
vit_transform=vit_transform,
|
| 208 |
+
new_token_ids=new_token_ids
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
import random
|
| 212 |
+
import numpy as np
|
| 213 |
+
|
| 214 |
+
seed = 42
|
| 215 |
+
random.seed(seed)
|
| 216 |
+
np.random.seed(seed)
|
| 217 |
+
torch.manual_seed(seed)
|
| 218 |
+
if torch.cuda.is_available():
|
| 219 |
+
torch.cuda.manual_seed(seed)
|
| 220 |
+
torch.cuda.manual_seed_all(seed)
|
| 221 |
+
torch.backends.cudnn.deterministic = True
|
| 222 |
+
torch.backends.cudnn.benchmark = False
|
| 223 |
+
|
| 224 |
+
inference_hyper=dict(
|
| 225 |
+
do_sample=False,
|
| 226 |
+
text_temperature=0.0,
|
| 227 |
+
cfg_text_scale=4.0,
|
| 228 |
+
cfg_img_scale=2.0,
|
| 229 |
+
cfg_interval=[0.0, 1.0],
|
| 230 |
+
timestep_shift=3.0,
|
| 231 |
+
num_timesteps=50,
|
| 232 |
+
cfg_renorm_min=0.0,
|
| 233 |
+
cfg_renorm_type="text_channel",
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
INTERLEAVED_SYSTEM_PROMPT = '''You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and use visual aids to enhance your problem-solving.'''
|
| 237 |
+
INTERLEAVED_SYSTEM_PROMPT = ''
|
| 238 |
+
|
| 239 |
+
# Original example (004 case) - commented out
|
| 240 |
+
# prompt = '''My goal is to generate a visual guide for constructing a specific shape using a set of blocks. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input includes 2 images of multiple blocks that will be used <image_start>[problem_image_1]<image_end><image_start>[problem_image_2]<image_end> and an image of the final desired shape<image_start>[problem_image_3]<image_end>. I need to imagine and generate images of intermediate steps, leading up to the final construction. Step 0 has been completed: a red arch block has been placed on top of the ground. The image after step 0 is provided<image_start>[problem_image_4]<image_end>. Now I need to generate the image for step 1, considering spatial relationships and stability.'''
|
| 241 |
+
|
| 242 |
+
# Use the new example data (145 case)
|
| 243 |
+
prompt = '''Based on the construction task shown below, follow the instructions to complete the build. Given the final desired shape of blocks shown in the first image<image_start>[problem_image_1]<image_end> which is viewed from a Front45 angle, perform a series of specified manipulations. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input also includes 3 images of multiple blocks that will be used.<image_start>[problem_image_2]<image_end><image_start>[problem_image_3]<image_end><image_start>[problem_image_4]<image_end> Step 0 has been completed: a orange cylinder block has been placed on top of the ground. The image after step 0 is provided.<image_start>[problem_image_5]<image_end>'''
|
| 244 |
+
|
| 245 |
+
# Load images from the new example paths (145 case)
|
| 246 |
+
image = []
|
| 247 |
+
base_path = '/scratch/by2593/project/SMM'
|
| 248 |
+
image_paths = [
|
| 249 |
+
f'{base_path}/semantic_blocks_part1/145/final_state/145_final_1.png', # problem_image_1 - final desired shape
|
| 250 |
+
f'{base_path}/SMM_data/each_block_views_diffposes/cylinder_orange.png', # problem_image_2 - orange cylinder
|
| 251 |
+
f'{base_path}/SMM_data/each_block_views_diffposes/cuboid3_yellow.png', # problem_image_3 - yellow cuboid3
|
| 252 |
+
f'{base_path}/SMM_data/each_block_views_diffposes/triangle_orange.png', # problem_image_4 - orange triangle
|
| 253 |
+
f'{base_path}/semantic_blocks_part1/145/steps/view_1/145_step0_1.png', # problem_image_5 - image after step 0
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
print("Loading input images:")
|
| 257 |
+
for i, img_path in enumerate(image_paths):
|
| 258 |
+
try:
|
| 259 |
+
img = Image.open(img_path).convert('RGB')
|
| 260 |
+
image.append(img)
|
| 261 |
+
print(f" ✓ Loaded problem_image_{i+1}: {img_path}")
|
| 262 |
+
print(f" Image size: {img.size}")
|
| 263 |
+
except Exception as e:
|
| 264 |
+
print(f" ✗ Failed to load {img_path}: {e}")
|
| 265 |
+
# Create a placeholder image if file not found
|
| 266 |
+
img = Image.new('RGB', (512, 512), color='gray')
|
| 267 |
+
image.append(img)
|
| 268 |
+
print(f" ⚠ Using placeholder for problem_image_{i+1}")
|
| 269 |
+
|
| 270 |
+
print(prompt)
|
| 271 |
+
print('-'*50)
|
| 272 |
+
|
| 273 |
+
# Create output folder with timestamp
|
| 274 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 275 |
+
output_folder = f"reasoning_output_example_145_{timestamp}"
|
| 276 |
+
images_folder = os.path.join(output_folder, "images")
|
| 277 |
+
os.makedirs(images_folder, exist_ok=True)
|
| 278 |
+
|
| 279 |
+
print(f"Output will be saved to: {output_folder}")
|
| 280 |
+
|
| 281 |
+
# Save the original problem images if they exist
|
| 282 |
+
problem_image_paths = []
|
| 283 |
+
if image is not None:
|
| 284 |
+
if isinstance(image, list):
|
| 285 |
+
# Handle multiple images
|
| 286 |
+
for i, img in enumerate(image):
|
| 287 |
+
problem_image_path = os.path.join(images_folder, f"problem_image_{i+1}.png")
|
| 288 |
+
relative_path = os.path.join("images", f"problem_image_{i+1}.png")
|
| 289 |
+
img.save(problem_image_path)
|
| 290 |
+
problem_image_paths.append(relative_path)
|
| 291 |
+
print(f"Problem image {i+1} saved at '{problem_image_path}'")
|
| 292 |
+
else:
|
| 293 |
+
# Handle single image
|
| 294 |
+
problem_image_path = os.path.join(images_folder, "problem_image.png")
|
| 295 |
+
relative_path = os.path.join("images", "problem_image.png")
|
| 296 |
+
image.save(problem_image_path)
|
| 297 |
+
problem_image_paths.append(relative_path)
|
| 298 |
+
print(f"Problem image saved at '{problem_image_path}'")
|
| 299 |
+
|
| 300 |
+
reasoning_text = []
|
| 301 |
+
reasoning_images = []
|
| 302 |
+
generated_image_paths = [] # Store relative paths to generated reasoning images
|
| 303 |
+
|
| 304 |
+
# Create input with multiple images properly flattened
|
| 305 |
+
if image is not None:
|
| 306 |
+
if isinstance(image, list):
|
| 307 |
+
current_input = [prompt] + image # Flatten the list of images
|
| 308 |
+
else:
|
| 309 |
+
current_input = [prompt, image]
|
| 310 |
+
else:
|
| 311 |
+
current_input = [prompt]
|
| 312 |
+
|
| 313 |
+
# Loop until no more vision_start tokens
|
| 314 |
+
iteration = 0
|
| 315 |
+
while True:
|
| 316 |
+
# Get understanding output
|
| 317 |
+
print(f"iteration: {iteration}")
|
| 318 |
+
output = inferencer.interleave_inference(current_input, understanding_output=True, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
|
| 319 |
+
|
| 320 |
+
# Check for stopping conditions
|
| 321 |
+
has_final_answer = 'Final Answer:' in output[0] or '<answer>' in output[0]
|
| 322 |
+
|
| 323 |
+
# Stop if we have a final answer OR if there's no vision token (no more images to generate)
|
| 324 |
+
# should_stop = has_final_answer or not has_vision_token
|
| 325 |
+
should_stop = has_final_answer
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if should_stop:
|
| 329 |
+
if output[0].strip():
|
| 330 |
+
extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
|
| 331 |
+
reasoning_text.append(extracted_text)
|
| 332 |
+
print(f"{extracted_text}")
|
| 333 |
+
current_input = current_input + [extracted_text]
|
| 334 |
+
break
|
| 335 |
+
|
| 336 |
+
extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
|
| 337 |
+
reasoning_text.append(extracted_text)
|
| 338 |
+
print(f"{extracted_text}")
|
| 339 |
+
|
| 340 |
+
# Generate image based on current reasoning
|
| 341 |
+
current_input_with_reasoning = current_input + [extracted_text]
|
| 342 |
+
output = inferencer.interleave_inference(current_input_with_reasoning, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
|
| 343 |
+
image_output = output[0]
|
| 344 |
+
|
| 345 |
+
# Save and collect the generated image
|
| 346 |
+
reasoning_images.append(image_output)
|
| 347 |
+
image_filename = f'reasoning_image_{iteration + 1}.png'
|
| 348 |
+
image_path = os.path.join(images_folder, image_filename)
|
| 349 |
+
relative_image_path = os.path.join("images", image_filename) # Relative path for JSON
|
| 350 |
+
|
| 351 |
+
image_output.save(image_path)
|
| 352 |
+
generated_image_paths.append(relative_image_path)
|
| 353 |
+
print(f"Image saved at '{image_path}'")
|
| 354 |
+
|
| 355 |
+
# Update input for next iteration
|
| 356 |
+
current_input = current_input_with_reasoning + [image_output]
|
| 357 |
+
|
| 358 |
+
iteration += 1
|
| 359 |
+
print('-'*50)
|
| 360 |
+
|
| 361 |
+
# Save reasoning data to JSON
|
| 362 |
+
reasoning_data = {
|
| 363 |
+
"timestamp": timestamp,
|
| 364 |
+
"prompt": prompt,
|
| 365 |
+
"system_prompt": INTERLEAVED_SYSTEM_PROMPT,
|
| 366 |
+
"problem_image_paths": problem_image_paths if problem_image_paths else None,
|
| 367 |
+
"response": [
|
| 368 |
+
{
|
| 369 |
+
"step": i + 1,
|
| 370 |
+
"text": text,
|
| 371 |
+
"image_path": generated_image_paths[i] if i < len(generated_image_paths) else None
|
| 372 |
+
}
|
| 373 |
+
for i, text in enumerate(reasoning_text)
|
| 374 |
+
],
|
| 375 |
+
"total_steps": len(reasoning_text),
|
| 376 |
+
"total_images": len(generated_image_paths)
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
# Save JSON file
|
| 380 |
+
json_path = os.path.join(output_folder, "reasoning_data.json")
|
| 381 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 382 |
+
json.dump(reasoning_data, f, indent=2, ensure_ascii=False)
|
| 383 |
+
|
| 384 |
+
print(f"\nReasoning complete!")
|
| 385 |
+
print(f"Output folder: {output_folder}")
|
| 386 |
+
print(f"JSON metadata: {json_path}")
|
| 387 |
+
print(f"Generated {len(generated_image_paths)} images and {len(reasoning_text)} text steps")
|
| 388 |
+
|
| 389 |
+
# python infz_bf16.py
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# import os
|
| 393 |
+
# import json
|
| 394 |
+
# from datetime import datetime
|
| 395 |
+
# from copy import deepcopy
|
| 396 |
+
# from typing import (
|
| 397 |
+
# Any,
|
| 398 |
+
# AsyncIterable,
|
| 399 |
+
# Callable,
|
| 400 |
+
# Dict,
|
| 401 |
+
# Generator,
|
| 402 |
+
# List,
|
| 403 |
+
# NamedTuple,
|
| 404 |
+
# Optional,
|
| 405 |
+
# Tuple,
|
| 406 |
+
# Union,
|
| 407 |
+
# )
|
| 408 |
+
# import requests
|
| 409 |
+
# from io import BytesIO
|
| 410 |
+
|
| 411 |
+
# from PIL import Image
|
| 412 |
+
# import torch
|
| 413 |
+
# from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
|
| 414 |
+
|
| 415 |
+
# from data.transforms import ImageTransform
|
| 416 |
+
# from data.data_utils import pil_img2rgb, add_special_tokens
|
| 417 |
+
# from modeling.bagel import (
|
| 418 |
+
# BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
|
| 419 |
+
# )
|
| 420 |
+
# from modeling.qwen2 import Qwen2Tokenizer
|
| 421 |
+
# from modeling.bagel.qwen2_navit import NaiveCache
|
| 422 |
+
# from modeling.autoencoder import load_ae
|
| 423 |
+
|
| 424 |
+
# # Set paths for your trained checkpoint
|
| 425 |
+
# checkpoint_dir = "path/to/your/HF_HOME/models/Bagel-Zebra-CoT"
|
| 426 |
+
# checkpoint_file = "model_bf16.safetensors"
|
| 427 |
+
# checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# print(f"Available GPUs: {torch.cuda.device_count()}")
|
| 431 |
+
# print(f"GPU memory per device:")
|
| 432 |
+
# for i in range(torch.cuda.device_count()):
|
| 433 |
+
# props = torch.cuda.get_device_properties(i)
|
| 434 |
+
# print(f" GPU {i}: {props.name}, {props.total_memory / 1e9:.1f} GB")
|
| 435 |
+
|
| 436 |
+
# # LLM config preparing (use base model configs)
|
| 437 |
+
# llm_config = Qwen2Config.from_json_file(os.path.join(checkpoint_dir, "llm_config.json"))
|
| 438 |
+
# llm_config.qk_norm = True
|
| 439 |
+
# llm_config.tie_word_embeddings = False
|
| 440 |
+
# llm_config.layer_module = "Qwen2MoTDecoderLayer"
|
| 441 |
+
|
| 442 |
+
# # ViT config preparing (use base model configs)
|
| 443 |
+
# vit_config = SiglipVisionConfig.from_json_file(os.path.join(checkpoint_dir, "vit_config.json"))
|
| 444 |
+
# vit_config.rope = False
|
| 445 |
+
# vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
|
| 446 |
+
|
| 447 |
+
# # VAE loading (use base model VAE)
|
| 448 |
+
# vae_model, vae_config = load_ae(local_path=os.path.join(checkpoint_dir, "ae.safetensors"))
|
| 449 |
+
|
| 450 |
+
# # Bagel config preparing
|
| 451 |
+
# config = BagelConfig(
|
| 452 |
+
# visual_gen=True,
|
| 453 |
+
# visual_und=True,
|
| 454 |
+
# llm_config=llm_config,
|
| 455 |
+
# vit_config=vit_config,
|
| 456 |
+
# vae_config=vae_config,
|
| 457 |
+
# vit_max_num_patch_per_side=70,
|
| 458 |
+
# connector_act='gelu_pytorch_tanh',
|
| 459 |
+
# latent_patch_size=2,
|
| 460 |
+
# max_latent_size=64,
|
| 461 |
+
# )
|
| 462 |
+
|
| 463 |
+
# # Create model with empty weights - IMPORTANT: Use float32 initially to match checkpoint
|
| 464 |
+
# with init_empty_weights():
|
| 465 |
+
# language_model = Qwen2ForCausalLM(llm_config)
|
| 466 |
+
# vit_model = SiglipVisionModel(vit_config)
|
| 467 |
+
# model = Bagel(language_model, vit_model, config)
|
| 468 |
+
# model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
|
| 469 |
+
|
| 470 |
+
# # Tokenizer Preparing (use base model tokenizer)
|
| 471 |
+
# tokenizer = Qwen2Tokenizer.from_pretrained(checkpoint_dir)
|
| 472 |
+
# tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
|
| 473 |
+
|
| 474 |
+
# # Image Transform Preparing
|
| 475 |
+
# vae_transform = ImageTransform(1024, 512, 16)
|
| 476 |
+
# vit_transform = ImageTransform(980, 512, 14)
|
| 477 |
+
|
| 478 |
+
# # Device mapping for 8x80GB GPUs - use bf16 directly
|
| 479 |
+
# max_mem_per_gpu = "80GiB"
|
| 480 |
+
|
| 481 |
+
# print("Setting up device mapping...")
|
| 482 |
+
# device_map = infer_auto_device_map(
|
| 483 |
+
# model,
|
| 484 |
+
# max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
|
| 485 |
+
# no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
|
| 486 |
+
# dtype=torch.bfloat16, # Use bf16 for device mapping
|
| 487 |
+
# )
|
| 488 |
+
|
| 489 |
+
# print("Device map:", device_map)
|
| 490 |
+
|
| 491 |
+
# # Handle same-device modules
|
| 492 |
+
# same_device_modules = [
|
| 493 |
+
# 'language_model.model.embed_tokens',
|
| 494 |
+
# 'time_embedder',
|
| 495 |
+
# 'latent_pos_embed',
|
| 496 |
+
# 'vae2llm',
|
| 497 |
+
# 'llm2vae',
|
| 498 |
+
# 'connector',
|
| 499 |
+
# 'vit_pos_embed'
|
| 500 |
+
# ]
|
| 501 |
+
|
| 502 |
+
# if torch.cuda.device_count() == 1:
|
| 503 |
+
# first_device = device_map.get(same_device_modules[0], "cuda:0")
|
| 504 |
+
# for k in same_device_modules:
|
| 505 |
+
# if k in device_map:
|
| 506 |
+
# device_map[k] = first_device
|
| 507 |
+
# else:
|
| 508 |
+
# device_map[k] = "cuda:0"
|
| 509 |
+
# else:
|
| 510 |
+
# first_device = device_map.get(same_device_modules[0])
|
| 511 |
+
# if first_device is not None:
|
| 512 |
+
# for k in same_device_modules:
|
| 513 |
+
# if k in device_map:
|
| 514 |
+
# device_map[k] = first_device
|
| 515 |
+
|
| 516 |
+
# print("Final device map:", device_map)
|
| 517 |
+
|
| 518 |
+
# # Load checkpoint directly in bf16
|
| 519 |
+
# print(f"Loading checkpoint directly in bfloat16: {checkpoint_path}")
|
| 520 |
+
# print("Loading model from safetensors file...")
|
| 521 |
+
|
| 522 |
+
# # Load model directly in bf16
|
| 523 |
+
# model = load_checkpoint_and_dispatch(
|
| 524 |
+
# model,
|
| 525 |
+
# checkpoint=checkpoint_path,
|
| 526 |
+
# device_map=device_map,
|
| 527 |
+
# offload_buffers=False,
|
| 528 |
+
# dtype=torch.bfloat16, # Load directly as bf16
|
| 529 |
+
# force_hooks=True,
|
| 530 |
+
# )
|
| 531 |
+
|
| 532 |
+
# model = model.eval()
|
| 533 |
+
|
| 534 |
+
# print('Model loaded directly in bfloat16!')
|
| 535 |
+
# print(f"Model dtype: {next(model.parameters()).dtype}")
|
| 536 |
+
# print("Model loading completed successfully!")
|
| 537 |
+
|
| 538 |
+
# # Check memory usage
|
| 539 |
+
# print("GPU memory usage after loading:")
|
| 540 |
+
# for i in range(torch.cuda.device_count()):
|
| 541 |
+
# if torch.cuda.memory_allocated(i) > 0:
|
| 542 |
+
# allocated = torch.cuda.memory_allocated(i) / 1e9
|
| 543 |
+
# cached = torch.cuda.memory_reserved(i) / 1e9
|
| 544 |
+
# print(f" GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
|
| 545 |
+
|
| 546 |
+
# # Rest of inference code
|
| 547 |
+
# from inferencer import InterleaveInferencer
|
| 548 |
+
|
| 549 |
+
# inferencer = InterleaveInferencer(
|
| 550 |
+
# model=model,
|
| 551 |
+
# vae_model=vae_model,
|
| 552 |
+
# tokenizer=tokenizer,
|
| 553 |
+
# vae_transform=vae_transform,
|
| 554 |
+
# vit_transform=vit_transform,
|
| 555 |
+
# new_token_ids=new_token_ids
|
| 556 |
+
# )
|
| 557 |
+
|
| 558 |
+
# import random
|
| 559 |
+
# import numpy as np
|
| 560 |
+
|
| 561 |
+
# seed = 42
|
| 562 |
+
# random.seed(seed)
|
| 563 |
+
# np.random.seed(seed)
|
| 564 |
+
# torch.manual_seed(seed)
|
| 565 |
+
# if torch.cuda.is_available():
|
| 566 |
+
# torch.cuda.manual_seed(seed)
|
| 567 |
+
# torch.cuda.manual_seed_all(seed)
|
| 568 |
+
# torch.backends.cudnn.deterministic = True
|
| 569 |
+
# torch.backends.cudnn.benchmark = False
|
| 570 |
+
|
| 571 |
+
# inference_hyper=dict(
|
| 572 |
+
# do_sample=True,
|
| 573 |
+
# text_temperature=0.3,
|
| 574 |
+
# cfg_text_scale=4.0,
|
| 575 |
+
# cfg_img_scale=2.0,
|
| 576 |
+
# cfg_interval=[0.0, 1.0],
|
| 577 |
+
# timestep_shift=3.0,
|
| 578 |
+
# num_timesteps=50,
|
| 579 |
+
# cfg_renorm_min=0.0,
|
| 580 |
+
# cfg_renorm_type="text_channel",
|
| 581 |
+
# )
|
| 582 |
+
|
| 583 |
+
# INTERLEAVED_SYSTEM_PROMPT = '''You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and use visual aids to enhance your problem-solving. Provide your final conclusion clearly in the format of "Final Answer: <answer here>"'''
|
| 584 |
+
|
| 585 |
+
# prompt = '''Subtract all cylinders. Add 1 red sphere. How many objects are left?'''
|
| 586 |
+
# image = Image.open('test_images/image.png')
|
| 587 |
+
|
| 588 |
+
# print(prompt)
|
| 589 |
+
# print('-'*50)
|
| 590 |
+
|
| 591 |
+
# # Create output folder with timestamp
|
| 592 |
+
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 593 |
+
# output_folder = f"reasoning_output_{timestamp}"
|
| 594 |
+
# images_folder = os.path.join(output_folder, "images")
|
| 595 |
+
# os.makedirs(images_folder, exist_ok=True)
|
| 596 |
+
|
| 597 |
+
# # Save the original problem images if they exist
|
| 598 |
+
# problem_image_paths = []
|
| 599 |
+
# if image is not None:
|
| 600 |
+
# if isinstance(image, list):
|
| 601 |
+
# # Handle multiple images
|
| 602 |
+
# for i, img in enumerate(image):
|
| 603 |
+
# problem_image_path = os.path.join(images_folder, f"problem_image_{i+1}.png")
|
| 604 |
+
# relative_path = os.path.join("images", f"problem_image_{i+1}.png")
|
| 605 |
+
# img.save(problem_image_path)
|
| 606 |
+
# problem_image_paths.append(relative_path)
|
| 607 |
+
# print(f"Problem image {i+1} saved at '{problem_image_path}'")
|
| 608 |
+
# else:
|
| 609 |
+
# # Handle single image
|
| 610 |
+
# problem_image_path = os.path.join(images_folder, "problem_image.png")
|
| 611 |
+
# relative_path = os.path.join("images", "problem_image.png")
|
| 612 |
+
# image.save(problem_image_path)
|
| 613 |
+
# problem_image_paths.append(relative_path)
|
| 614 |
+
# print(f"Problem image saved at '{problem_image_path}'")
|
| 615 |
+
|
| 616 |
+
# reasoning_text = []
|
| 617 |
+
# reasoning_images = []
|
| 618 |
+
# image_paths = [] # Store relative paths to images
|
| 619 |
+
|
| 620 |
+
# # Create input with multiple images properly flattened
|
| 621 |
+
# if image is not None:
|
| 622 |
+
# if isinstance(image, list):
|
| 623 |
+
# current_input = [prompt] + image # Flatten the list of images
|
| 624 |
+
# else:
|
| 625 |
+
# current_input = [prompt, image]
|
| 626 |
+
# else:
|
| 627 |
+
# current_input = [prompt]
|
| 628 |
+
|
| 629 |
+
# # Loop until no more vision_start tokens
|
| 630 |
+
# iteration = 0
|
| 631 |
+
# while True:
|
| 632 |
+
# # Get understanding output
|
| 633 |
+
# print(f"iteration: {iteration}")
|
| 634 |
+
# output = inferencer.interleave_inference(current_input, understanding_output=True, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
|
| 635 |
+
|
| 636 |
+
# # Check for stopping conditions
|
| 637 |
+
# has_final_answer = 'Final Answer:' in output[0] or '<answer>' in output[0]
|
| 638 |
+
|
| 639 |
+
# # Stop if we have a final answer OR if there's no vision token (no more images to generate)
|
| 640 |
+
# # should_stop = has_final_answer or not has_vision_token
|
| 641 |
+
# should_stop = has_final_answer
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
# if should_stop:
|
| 645 |
+
# if output[0].strip():
|
| 646 |
+
# extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
|
| 647 |
+
# reasoning_text.append(extracted_text)
|
| 648 |
+
# print(f"{extracted_text}")
|
| 649 |
+
# current_input = current_input + [extracted_text]
|
| 650 |
+
# break
|
| 651 |
+
|
| 652 |
+
# extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
|
| 653 |
+
# reasoning_text.append(extracted_text)
|
| 654 |
+
# print(f"{extracted_text}")
|
| 655 |
+
|
| 656 |
+
# # Generate image based on current reasoning
|
| 657 |
+
# current_input_with_reasoning = current_input + [extracted_text]
|
| 658 |
+
# output = inferencer.interleave_inference(current_input_with_reasoning, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
|
| 659 |
+
# image_output = output[0]
|
| 660 |
+
|
| 661 |
+
# # Save and collect the generated image
|
| 662 |
+
# reasoning_images.append(image_output)
|
| 663 |
+
# image_filename = f'reasoning_image_{iteration + 1}.png'
|
| 664 |
+
# image_path = os.path.join(images_folder, image_filename)
|
| 665 |
+
# relative_image_path = os.path.join("images", image_filename) # Relative path for JSON
|
| 666 |
+
|
| 667 |
+
# image_output.save(image_path)
|
| 668 |
+
# image_paths.append(relative_image_path)
|
| 669 |
+
# print(f"Image saved at '{image_path}'")
|
| 670 |
+
|
| 671 |
+
# # Update input for next iteration
|
| 672 |
+
# current_input = current_input_with_reasoning + [image_output]
|
| 673 |
+
|
| 674 |
+
# iteration += 1
|
| 675 |
+
# print('-'*50)
|
| 676 |
+
|
| 677 |
+
# # Save reasoning data to JSON
|
| 678 |
+
# reasoning_data = {
|
| 679 |
+
# "timestamp": timestamp,
|
| 680 |
+
# "prompt": prompt,
|
| 681 |
+
# "system_prompt": INTERLEAVED_SYSTEM_PROMPT,
|
| 682 |
+
# "problem_image_paths": problem_image_paths if problem_image_paths else None,
|
| 683 |
+
# "response": [
|
| 684 |
+
# {
|
| 685 |
+
# "step": i + 1,
|
| 686 |
+
# "text": text,
|
| 687 |
+
# "image_path": image_paths[i] if i < len(image_paths) else None
|
| 688 |
+
# }
|
| 689 |
+
# for i, text in enumerate(reasoning_text)
|
| 690 |
+
# ],
|
| 691 |
+
# "total_steps": len(reasoning_text),
|
| 692 |
+
# "total_images": len(image_paths)
|
| 693 |
+
# }
|
| 694 |
+
|
| 695 |
+
# # Save JSON file
|
| 696 |
+
# json_path = os.path.join(output_folder, "reasoning_data.json")
|
| 697 |
+
# with open(json_path, 'w', encoding='utf-8') as f:
|
| 698 |
+
# json.dump(reasoning_data, f, indent=2, ensure_ascii=False)
|
| 699 |
+
|
| 700 |
+
# print(f"\nReasoning complete!")
|
| 701 |
+
# print(f"Output folder: {output_folder}")
|
| 702 |
+
# print(f"JSON metadata: {json_path}")
|
| 703 |
+
# print(f"Generated {len(image_paths)} images and {len(reasoning_text)} text steps")
|
| 704 |
+
|
modeling/bagel/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from .bagel import BagelConfig, Bagel
|
| 6 |
+
from .qwen2_navit import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
| 7 |
+
from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
'BagelConfig',
|
| 12 |
+
'Bagel',
|
| 13 |
+
'Qwen2Config',
|
| 14 |
+
'Qwen2Model',
|
| 15 |
+
'Qwen2ForCausalLM',
|
| 16 |
+
'SiglipVisionConfig',
|
| 17 |
+
'SiglipVisionModel',
|
| 18 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
decord==0.6.0
|
| 2 |
+
einops==0.8.1
|
| 3 |
+
huggingface_hub==0.29.1
|
| 4 |
+
matplotlib==3.7.0
|
| 5 |
+
numpy==1.24.4
|
| 6 |
+
opencv-python-headless
|
| 7 |
+
pyarrow==11.0.0
|
| 8 |
+
PyYAML==6.0.2
|
| 9 |
+
Requests==2.32.3
|
| 10 |
+
safetensors==0.4.5
|
| 11 |
+
scipy==1.10.1
|
| 12 |
+
sentencepiece==0.1.99
|
| 13 |
+
torch==2.5.1
|
| 14 |
+
torchvision==0.20.1
|
| 15 |
+
transformers==4.49.0
|
| 16 |
+
accelerate>=0.34.0
|
| 17 |
+
wandb
|
| 18 |
+
gradio
|
| 19 |
+
setuptools
|
| 20 |
+
wheel
|
| 21 |
+
ninja
|
| 22 |
+
bitsandbytes
|
| 23 |
+
xlsxwriter
|
| 24 |
+
triton ; sys_platform != 'win32'
|
| 25 |
+
triton-windows ; sys_platform == 'win32'
|