Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- assets/arch.png +3 -0
- assets/bagel-cot-example.png +3 -0
- assets/emerging_curves.png +3 -0
- assets/teaser.webp +3 -0
- assets/zebra_cot_datacard.png +3 -0
- data/__init__.py +2 -0
- data/configs/example.yaml +50 -0
- data/configs/example_smm_random.yaml +50 -0
- data/dataset_base.py +768 -0
- data/dataset_info.py +46 -0
- data/distributed_iterable_dataset.py +58 -0
- data/interleave_datasets/edit_dataset.py +72 -0
- data/interleave_datasets/interleave_t2i_dataset.py +218 -0
- data/interleave_datasets/think_trace_dataset.py +289 -0
- modeling/__init__.py +4 -0
- modeling/autoencoder.py +360 -0
- modeling/bagel/bagel.py +1068 -0
- modeling/bagel/modeling_utils.py +144 -0
- modeling/bagel/qwen2_navit.py +1157 -0
- modeling/bagel/siglip_navit.py +402 -0
- modeling/qwen2/__init__.py +68 -0
- modeling/qwen2/configuration_qwen2.py +179 -0
- modeling/qwen2/modeling_qwen2.py +929 -0
- modeling/qwen2/tokenization_qwen2.py +328 -0
- modeling/qwen2/tokenization_qwen2_fast.py +123 -0
- modeling/siglip/__init__.py +98 -0
- modeling/siglip/configuration_siglip.py +287 -0
- modeling/siglip/convert_siglip_to_hf.py +401 -0
- modeling/siglip/image_processing_siglip.py +230 -0
- modeling/siglip/modeling_siglip.py +1557 -0
- modeling/siglip/processing_siglip.py +131 -0
- modeling/siglip/tokenization_siglip.py +364 -0
- run.err +150 -0
- run.out +871 -0
- scripts/eval/eval_vlm.sh +27 -0
- scripts/eval/run_eval_vlm.sh +19 -0
- scripts/eval/run_gedit.sh +57 -0
- scripts/eval/run_geneval.sh +41 -0
- scripts/eval/run_imgedit.sh +42 -0
- scripts/eval/run_kris.sh +50 -0
- scripts/eval/run_rise.sh +30 -0
- scripts/eval/run_wise.sh +44 -0
- scripts/train.sh +48 -0
- scripts/train_smm.sh +57 -0
- scripts/train_smm_sbatch.sh +85 -0
- test_images/image.png +3 -0
- test_images/meme.jpg +0 -0
- test_images/octupusy.jpg +0 -0
- test_images/women.jpg +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/arch.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/bagel-cot-example.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/emerging_curves.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/zebra_cot_datacard.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
test_images/image.png filter=lfs diff=lfs merge=lfs -text
|
assets/arch.png
ADDED
|
Git LFS Details
|
assets/bagel-cot-example.png
ADDED
|
Git LFS Details
|
assets/emerging_curves.png
ADDED
|
Git LFS Details
|
assets/teaser.webp
ADDED
|
Git LFS Details
|
assets/zebra_cot_datacard.png
ADDED
|
Git LFS Details
|
data/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
data/configs/example.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
think_trace:
|
| 2 |
+
dataset_names:
|
| 3 |
+
- think_trace_dataset
|
| 4 |
+
jsonl_path_list: ["/dev/shm/data/Zebra-CoT/zebra_cot.jsonl"]
|
| 5 |
+
num_used_data: None
|
| 6 |
+
image_prefix_dir: "/dev/shm/data/Zebra-CoT"
|
| 7 |
+
image_transform_args:
|
| 8 |
+
image_stride: 16
|
| 9 |
+
max_image_size: 512
|
| 10 |
+
min_image_size: 512
|
| 11 |
+
vit_image_transform_args:
|
| 12 |
+
image_stride: 14
|
| 13 |
+
max_image_size: 512
|
| 14 |
+
min_image_size: 512
|
| 15 |
+
weight: 1.0
|
| 16 |
+
is_mandatory: true
|
| 17 |
+
|
| 18 |
+
# unified_edit:
|
| 19 |
+
# dataset_names:
|
| 20 |
+
# - seedxedit_multi
|
| 21 |
+
# image_transform_args:
|
| 22 |
+
# image_stride: 16
|
| 23 |
+
# max_image_size: 1024
|
| 24 |
+
# min_image_size: 512
|
| 25 |
+
# vit_image_transform_args:
|
| 26 |
+
# image_stride: 14
|
| 27 |
+
# max_image_size: 518
|
| 28 |
+
# min_image_size: 224
|
| 29 |
+
# is_mandatory: true
|
| 30 |
+
# num_used_data:
|
| 31 |
+
# - 10
|
| 32 |
+
# weight: 1
|
| 33 |
+
|
| 34 |
+
# vlm_sft:
|
| 35 |
+
# dataset_names:
|
| 36 |
+
# - llava_ov
|
| 37 |
+
# image_transform_args:
|
| 38 |
+
# image_stride: 14
|
| 39 |
+
# max_image_size: 980
|
| 40 |
+
# min_image_size: 378
|
| 41 |
+
# max_pixels: 2_007_040
|
| 42 |
+
# frame_sampler_args:
|
| 43 |
+
# max_num_frames: 12
|
| 44 |
+
# min_num_frames: 8
|
| 45 |
+
# is_mandatory: true
|
| 46 |
+
# shuffle_lines: True
|
| 47 |
+
# shuffle_seed: 0
|
| 48 |
+
# num_used_data:
|
| 49 |
+
# - 1000
|
| 50 |
+
# weight: 1
|
data/configs/example_smm_random.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
block_dataset_random:
|
| 2 |
+
dataset_names:
|
| 3 |
+
- block_dataset_random
|
| 4 |
+
jsonl_path_list: ["/scratch/by2593/project/SMM/SMM_data/random_block.jsonl"]
|
| 5 |
+
num_used_data: None
|
| 6 |
+
image_prefix_dir: "/scratch/by2593/project/SMM/random_pipeline/random_blocks"
|
| 7 |
+
image_transform_args:
|
| 8 |
+
image_stride: 16
|
| 9 |
+
max_image_size: 512 # VAE使用stride=16, 512/16=32 patches
|
| 10 |
+
min_image_size: 512
|
| 11 |
+
vit_image_transform_args:
|
| 12 |
+
image_stride: 14
|
| 13 |
+
max_image_size: 512 # ViT使用stride=14, 512/14=36 patches (匹配模型能力)
|
| 14 |
+
min_image_size: 512
|
| 15 |
+
weight: 1.0
|
| 16 |
+
is_mandatory: true
|
| 17 |
+
|
| 18 |
+
# unified_edit:
|
| 19 |
+
# dataset_names:
|
| 20 |
+
# - seedxedit_multi
|
| 21 |
+
# image_transform_args:
|
| 22 |
+
# image_stride: 16
|
| 23 |
+
# max_image_size: 1024
|
| 24 |
+
# min_image_size: 512
|
| 25 |
+
# vit_image_transform_args:
|
| 26 |
+
# image_stride: 14
|
| 27 |
+
# max_image_size: 518
|
| 28 |
+
# min_image_size: 224
|
| 29 |
+
# is_mandatory: true
|
| 30 |
+
# num_used_data:
|
| 31 |
+
# - 10
|
| 32 |
+
# weight: 1
|
| 33 |
+
|
| 34 |
+
# vlm_sft:
|
| 35 |
+
# dataset_names:
|
| 36 |
+
# - llava_ov
|
| 37 |
+
# image_transform_args:
|
| 38 |
+
# image_stride: 14
|
| 39 |
+
# max_image_size: 980
|
| 40 |
+
# min_image_size: 378
|
| 41 |
+
# max_pixels: 2_007_040
|
| 42 |
+
# frame_sampler_args:
|
| 43 |
+
# max_num_frames: 12
|
| 44 |
+
# min_num_frames: 8
|
| 45 |
+
# is_mandatory: true
|
| 46 |
+
# shuffle_lines: True
|
| 47 |
+
# shuffle_seed: 0
|
| 48 |
+
# num_used_data:
|
| 49 |
+
# - 1000
|
| 50 |
+
# weight: 1
|
data/dataset_base.py
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .data_utils import (
|
| 12 |
+
get_flattened_position_ids_interpolate,
|
| 13 |
+
get_flattened_position_ids_extrapolate,
|
| 14 |
+
len2weight,
|
| 15 |
+
patchify,
|
| 16 |
+
prepare_attention_mask_per_sample,
|
| 17 |
+
)
|
| 18 |
+
from .dataset_info import DATASET_INFO, DATASET_REGISTRY
|
| 19 |
+
from .transforms import ImageTransform
|
| 20 |
+
from .video_utils import FrameSampler
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DataConfig:
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
grouped_datasets,
|
| 27 |
+
text_cond_dropout_prob=0.1,
|
| 28 |
+
vit_cond_dropout_prob=0.4,
|
| 29 |
+
vae_cond_dropout_prob=0.1,
|
| 30 |
+
vae_image_downsample=16,
|
| 31 |
+
max_latent_size=32,
|
| 32 |
+
vit_patch_size=14,
|
| 33 |
+
max_num_patch_per_side=70,
|
| 34 |
+
):
|
| 35 |
+
self.grouped_datasets = grouped_datasets
|
| 36 |
+
self.text_cond_dropout_prob = text_cond_dropout_prob
|
| 37 |
+
self.vit_cond_dropout_prob = vit_cond_dropout_prob
|
| 38 |
+
self.vit_patch_size = vit_patch_size
|
| 39 |
+
self.max_num_patch_per_side = max_num_patch_per_side
|
| 40 |
+
self.vae_cond_dropout_prob = vae_cond_dropout_prob
|
| 41 |
+
self.vae_image_downsample = vae_image_downsample
|
| 42 |
+
self.max_latent_size = max_latent_size
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PackedDataset(torch.utils.data.IterableDataset):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
data_config,
|
| 49 |
+
tokenizer,
|
| 50 |
+
special_tokens,
|
| 51 |
+
local_rank,
|
| 52 |
+
world_size,
|
| 53 |
+
num_workers,
|
| 54 |
+
expected_num_tokens=32768,
|
| 55 |
+
max_num_tokens_per_sample=16384,
|
| 56 |
+
max_num_tokens=36864,
|
| 57 |
+
prefer_buffer_before=16384,
|
| 58 |
+
max_buffer_size=50,
|
| 59 |
+
interpolate_pos=False,
|
| 60 |
+
use_flex=False,
|
| 61 |
+
data_status=None,
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.expected_num_tokens = expected_num_tokens
|
| 65 |
+
self.max_num_tokens_per_sample = max_num_tokens_per_sample
|
| 66 |
+
self.prefer_buffer_before = prefer_buffer_before
|
| 67 |
+
self.max_num_tokens = max_num_tokens
|
| 68 |
+
self.max_buffer_size = max_buffer_size
|
| 69 |
+
self.tokenizer = tokenizer
|
| 70 |
+
self.local_rank = local_rank
|
| 71 |
+
self.world_size = world_size
|
| 72 |
+
self.num_workers = num_workers
|
| 73 |
+
self.use_flex = use_flex
|
| 74 |
+
for k, v in special_tokens.items():
|
| 75 |
+
setattr(self, k, v)
|
| 76 |
+
|
| 77 |
+
grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
|
| 78 |
+
data_config.grouped_datasets, data_status
|
| 79 |
+
)
|
| 80 |
+
self.grouped_datasets = grouped_datasets
|
| 81 |
+
self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
|
| 82 |
+
self.is_mandatory = is_mandatory
|
| 83 |
+
self.grouped_weights = grouped_weights
|
| 84 |
+
self.data_config = data_config
|
| 85 |
+
self.interpolate_pos = interpolate_pos
|
| 86 |
+
if self.interpolate_pos:
|
| 87 |
+
self.get_flattened_position_ids = get_flattened_position_ids_interpolate
|
| 88 |
+
else:
|
| 89 |
+
self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
|
| 90 |
+
|
| 91 |
+
def build_datasets(self, datasets_metainfo, data_status):
|
| 92 |
+
datasets = []
|
| 93 |
+
is_mandatory = []
|
| 94 |
+
grouped_weights = []
|
| 95 |
+
for grouped_dataset_name, dataset_args in datasets_metainfo.items():
|
| 96 |
+
is_mandatory.append(dataset_args.pop('is_mandatory', False))
|
| 97 |
+
grouped_weights.append(dataset_args.pop('weight', 0.0))
|
| 98 |
+
|
| 99 |
+
if 'frame_sampler_args' in dataset_args.keys():
|
| 100 |
+
frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
|
| 101 |
+
dataset_args['frame_sampler'] = frame_sampler
|
| 102 |
+
if 'image_transform_args' in dataset_args.keys():
|
| 103 |
+
transform = ImageTransform(**dataset_args.pop('image_transform_args'))
|
| 104 |
+
dataset_args['transform'] = transform
|
| 105 |
+
if 'vit_image_transform_args' in dataset_args.keys():
|
| 106 |
+
vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
|
| 107 |
+
dataset_args['vit_transform'] = vit_transform
|
| 108 |
+
|
| 109 |
+
assert 'dataset_names' in dataset_args.keys()
|
| 110 |
+
dataset_names = dataset_args.pop('dataset_names')
|
| 111 |
+
dataset_args['data_dir_list'] = []
|
| 112 |
+
for item in dataset_names:
|
| 113 |
+
if self.local_rank == 0:
|
| 114 |
+
print(f'Preparing Dataset {grouped_dataset_name}/{item}')
|
| 115 |
+
meta_info = DATASET_INFO[grouped_dataset_name][item]
|
| 116 |
+
dataset_args['data_dir_list'].append(meta_info['data_dir'])
|
| 117 |
+
|
| 118 |
+
if "parquet_info_path" in meta_info.keys():
|
| 119 |
+
if 'parquet_info' not in dataset_args.keys():
|
| 120 |
+
dataset_args['parquet_info'] = {}
|
| 121 |
+
with open(meta_info['parquet_info_path'], 'r') as f:
|
| 122 |
+
parquet_info = json.load(f)
|
| 123 |
+
dataset_args['parquet_info'].update(parquet_info)
|
| 124 |
+
|
| 125 |
+
if 'json_dir' in meta_info.keys():
|
| 126 |
+
# parquet/tar with json
|
| 127 |
+
if 'json_dir_list' not in dataset_args.keys():
|
| 128 |
+
dataset_args['json_dir_list'] = [meta_info['json_dir']]
|
| 129 |
+
else:
|
| 130 |
+
dataset_args['json_dir_list'].append(meta_info['json_dir'])
|
| 131 |
+
|
| 132 |
+
if 'jsonl_path' in meta_info.keys():
|
| 133 |
+
# jsonl with jpeg
|
| 134 |
+
if 'jsonl_path_list' not in dataset_args.keys():
|
| 135 |
+
dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
|
| 136 |
+
else:
|
| 137 |
+
dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
|
| 138 |
+
|
| 139 |
+
if 'image_prefix_dir' in meta_info.keys():
|
| 140 |
+
dataset_args['image_prefix_dir'] = meta_info['image_prefix_dir']
|
| 141 |
+
|
| 142 |
+
resume_data_status = dataset_args.pop('resume_data_status', True)
|
| 143 |
+
if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
|
| 144 |
+
data_status_per_group = data_status[grouped_dataset_name]
|
| 145 |
+
else:
|
| 146 |
+
data_status_per_group = None
|
| 147 |
+
dataset = DATASET_REGISTRY[grouped_dataset_name](
|
| 148 |
+
dataset_name=grouped_dataset_name,
|
| 149 |
+
tokenizer=self.tokenizer,
|
| 150 |
+
local_rank=self.local_rank,
|
| 151 |
+
world_size=self.world_size,
|
| 152 |
+
num_workers=self.num_workers,
|
| 153 |
+
data_status=data_status_per_group,
|
| 154 |
+
**dataset_args
|
| 155 |
+
)
|
| 156 |
+
datasets.append(dataset)
|
| 157 |
+
|
| 158 |
+
return datasets, is_mandatory, grouped_weights
|
| 159 |
+
|
| 160 |
+
def set_epoch(self, seed):
|
| 161 |
+
for dataset in self.grouped_datasets:
|
| 162 |
+
dataset.set_epoch(seed)
|
| 163 |
+
|
| 164 |
+
def set_sequence_status(self):
|
| 165 |
+
sequence_status = dict(
|
| 166 |
+
curr = 0,
|
| 167 |
+
sample_lens = list(),
|
| 168 |
+
packed_position_ids = list(),
|
| 169 |
+
nested_attention_masks = list(),
|
| 170 |
+
split_lens = list(),
|
| 171 |
+
attn_modes = list(),
|
| 172 |
+
packed_text_ids = list(),
|
| 173 |
+
packed_text_indexes = list(),
|
| 174 |
+
packed_label_ids = list(),
|
| 175 |
+
ce_loss_indexes = list(),
|
| 176 |
+
ce_loss_weights = list(),
|
| 177 |
+
vae_image_tensors = list(),
|
| 178 |
+
packed_latent_position_ids = list(),
|
| 179 |
+
vae_latent_shapes = list(),
|
| 180 |
+
packed_vae_token_indexes = list(),
|
| 181 |
+
packed_timesteps = list(),
|
| 182 |
+
mse_loss_indexes = list(),
|
| 183 |
+
packed_vit_tokens = list(),
|
| 184 |
+
vit_token_seqlens = list(),
|
| 185 |
+
packed_vit_position_ids = list(),
|
| 186 |
+
packed_vit_token_indexes = list(),
|
| 187 |
+
)
|
| 188 |
+
return sequence_status
|
| 189 |
+
|
| 190 |
+
def to_tensor(self, sequence_status):
|
| 191 |
+
data = dict(
|
| 192 |
+
sequence_length=sum(sequence_status['sample_lens']),
|
| 193 |
+
sample_lens=sequence_status['sample_lens'],
|
| 194 |
+
packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
|
| 195 |
+
packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
|
| 196 |
+
packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
|
| 197 |
+
)
|
| 198 |
+
if not self.use_flex:
|
| 199 |
+
data['nested_attention_masks'] = sequence_status['nested_attention_masks']
|
| 200 |
+
else:
|
| 201 |
+
sequence_len = data['sequence_length']
|
| 202 |
+
pad_len = self.max_num_tokens - sequence_len
|
| 203 |
+
data['split_lens'] = sequence_status['split_lens'] + [pad_len]
|
| 204 |
+
data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
|
| 205 |
+
data['sample_lens'] += [pad_len]
|
| 206 |
+
|
| 207 |
+
# if the model has a convnet vae (e.g., as visual tokenizer)
|
| 208 |
+
if len(sequence_status['vae_image_tensors']) > 0:
|
| 209 |
+
image_tensors = sequence_status.pop('vae_image_tensors')
|
| 210 |
+
image_sizes = [item.shape for item in image_tensors]
|
| 211 |
+
max_image_size = [max(item) for item in list(zip(*image_sizes))]
|
| 212 |
+
padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
|
| 213 |
+
for i, image_tensor in enumerate(image_tensors):
|
| 214 |
+
padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
|
| 215 |
+
|
| 216 |
+
data['padded_images'] = padded_images
|
| 217 |
+
data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
|
| 218 |
+
data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
|
| 219 |
+
data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
|
| 220 |
+
|
| 221 |
+
# if the model has a vit (e.g., as visual tokenizer)
|
| 222 |
+
if len(sequence_status['packed_vit_tokens']) > 0:
|
| 223 |
+
data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
|
| 224 |
+
data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
|
| 225 |
+
data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
|
| 226 |
+
data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
|
| 227 |
+
|
| 228 |
+
# if the model is required to perform visual generation
|
| 229 |
+
if len(sequence_status['packed_timesteps']) > 0:
|
| 230 |
+
data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
|
| 231 |
+
data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
|
| 232 |
+
|
| 233 |
+
# if the model is required to perform text generation
|
| 234 |
+
if len(sequence_status['packed_label_ids']) > 0:
|
| 235 |
+
data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
|
| 236 |
+
data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
|
| 237 |
+
data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
|
| 238 |
+
|
| 239 |
+
# Debug printing for rank 0
|
| 240 |
+
# if self.local_rank == 0:
|
| 241 |
+
# self.print_debug_info(data, sequence_status)
|
| 242 |
+
|
| 243 |
+
return data
|
| 244 |
+
|
| 245 |
+
def print_debug_info(self, data, sequence_status):
|
| 246 |
+
"""Print detailed debug information in an intuitive table format"""
|
| 247 |
+
print("\n" + "="*120)
|
| 248 |
+
print("DEBUG: Complete Sequence Analysis")
|
| 249 |
+
print("="*120)
|
| 250 |
+
|
| 251 |
+
# Basic info
|
| 252 |
+
print(f"Sequence Length: {data['sequence_length']}")
|
| 253 |
+
print(f"Sample Lengths: {data['sample_lens']}")
|
| 254 |
+
|
| 255 |
+
# Get all data
|
| 256 |
+
packed_text_ids = data['packed_text_ids'].tolist()
|
| 257 |
+
packed_text_indexes = data['packed_text_indexes'].tolist()
|
| 258 |
+
|
| 259 |
+
# Build loss mappings
|
| 260 |
+
ce_loss_indexes = set(data.get('ce_loss_indexes', []).tolist())
|
| 261 |
+
mse_loss_indexes = set(data.get('mse_loss_indexes', []).tolist())
|
| 262 |
+
vit_token_indexes = set(data.get('packed_vit_token_indexes', []).tolist())
|
| 263 |
+
vae_token_indexes = set(data.get('packed_vae_token_indexes', []).tolist())
|
| 264 |
+
|
| 265 |
+
# Build label mapping
|
| 266 |
+
label_mapping = {}
|
| 267 |
+
if 'ce_loss_indexes' in data:
|
| 268 |
+
ce_indexes = data['ce_loss_indexes'].tolist()
|
| 269 |
+
ce_labels = data['packed_label_ids'].tolist()
|
| 270 |
+
for i, pos in enumerate(ce_indexes):
|
| 271 |
+
label_mapping[pos] = ce_labels[i]
|
| 272 |
+
|
| 273 |
+
# Print raw token sequence
|
| 274 |
+
print(f"\n1. Raw Token IDs: {packed_text_ids}")
|
| 275 |
+
|
| 276 |
+
# Print decoded token sequence
|
| 277 |
+
try:
|
| 278 |
+
decoded_text_tokens = []
|
| 279 |
+
for token_id in packed_text_ids:
|
| 280 |
+
decoded = self.tokenizer.decode([token_id])
|
| 281 |
+
decoded_text_tokens.append(decoded)
|
| 282 |
+
print(f"2. Decoded Tokens: {decoded_text_tokens}")
|
| 283 |
+
except Exception as e:
|
| 284 |
+
print(f"2. Error decoding tokens: {e}")
|
| 285 |
+
decoded_text_tokens = ["<ERROR>"] * len(packed_text_ids)
|
| 286 |
+
|
| 287 |
+
# Create comprehensive sequence table
|
| 288 |
+
print(f"\n3. Complete Sequence Table:")
|
| 289 |
+
print("-" * 120)
|
| 290 |
+
print(f"{'Order':<6} | {'Token Type':<12} | {'Token/Content':<30} | {'Loss Type':<10} | {'Label':<30} | {'Notes':<20}")
|
| 291 |
+
print("-" * 120)
|
| 292 |
+
|
| 293 |
+
# Track text token index
|
| 294 |
+
text_token_idx = 0
|
| 295 |
+
|
| 296 |
+
for pos in range(data['sequence_length']):
|
| 297 |
+
# Determine token type and content
|
| 298 |
+
if pos in packed_text_indexes:
|
| 299 |
+
# This is a text token position
|
| 300 |
+
token_id = packed_text_ids[text_token_idx]
|
| 301 |
+
try:
|
| 302 |
+
decoded_token = self.tokenizer.decode([token_id])
|
| 303 |
+
token_content = f"ID:{token_id} '{decoded_token}'"
|
| 304 |
+
except:
|
| 305 |
+
token_content = f"ID:{token_id} '<ERROR>'"
|
| 306 |
+
token_type = "TEXT"
|
| 307 |
+
text_token_idx += 1
|
| 308 |
+
|
| 309 |
+
elif pos in vit_token_indexes:
|
| 310 |
+
token_type = "VIT_IMAGE"
|
| 311 |
+
token_content = "[VIT Image Patch]"
|
| 312 |
+
|
| 313 |
+
elif pos in vae_token_indexes:
|
| 314 |
+
token_type = "VAE_IMAGE"
|
| 315 |
+
token_content = "[VAE Image Latent]"
|
| 316 |
+
|
| 317 |
+
else:
|
| 318 |
+
token_type = "UNKNOWN"
|
| 319 |
+
token_content = "[Unknown Position]"
|
| 320 |
+
|
| 321 |
+
# Determine loss type
|
| 322 |
+
if pos in ce_loss_indexes:
|
| 323 |
+
loss_type = "CE"
|
| 324 |
+
elif pos in mse_loss_indexes:
|
| 325 |
+
loss_type = "MSE"
|
| 326 |
+
else:
|
| 327 |
+
loss_type = "None"
|
| 328 |
+
|
| 329 |
+
# Determine label
|
| 330 |
+
if pos in label_mapping:
|
| 331 |
+
label_id = label_mapping[pos]
|
| 332 |
+
try:
|
| 333 |
+
decoded_label = self.tokenizer.decode([label_id])
|
| 334 |
+
label_content = f"ID:{label_id} '{decoded_label}'"
|
| 335 |
+
except:
|
| 336 |
+
label_content = f"ID:{label_id} '<ERROR>'"
|
| 337 |
+
elif pos in mse_loss_indexes:
|
| 338 |
+
label_content = "[Image Generation Target]"
|
| 339 |
+
else:
|
| 340 |
+
label_content = "N/A"
|
| 341 |
+
|
| 342 |
+
# Additional notes
|
| 343 |
+
notes = ""
|
| 344 |
+
if pos in mse_loss_indexes and 'packed_timesteps' in data:
|
| 345 |
+
timestep_idx = list(mse_loss_indexes).index(pos) if pos in mse_loss_indexes else -1
|
| 346 |
+
if timestep_idx >= 0 and timestep_idx < len(data['packed_timesteps']):
|
| 347 |
+
timestep = data['packed_timesteps'][timestep_idx].item()
|
| 348 |
+
if timestep == float('-inf'):
|
| 349 |
+
notes = "No noise"
|
| 350 |
+
else:
|
| 351 |
+
notes = f"t={timestep:.3f}"
|
| 352 |
+
|
| 353 |
+
print(f"{pos:<6} | {token_type:<12} | {token_content:<30} | {loss_type:<10} | {label_content:<30} | {notes:<20}")
|
| 354 |
+
|
| 355 |
+
print("-" * 120)
|
| 356 |
+
|
| 357 |
+
# Summary statistics
|
| 358 |
+
total_positions = data['sequence_length']
|
| 359 |
+
ce_positions = len(ce_loss_indexes)
|
| 360 |
+
mse_positions = len(mse_loss_indexes)
|
| 361 |
+
vit_positions = len(vit_token_indexes)
|
| 362 |
+
vae_positions = len(vae_token_indexes)
|
| 363 |
+
text_positions = len(packed_text_indexes)
|
| 364 |
+
no_loss_positions = total_positions - ce_positions - mse_positions
|
| 365 |
+
|
| 366 |
+
print(f"\nSummary Statistics:")
|
| 367 |
+
print(f" Total positions: {total_positions}")
|
| 368 |
+
print(f" Text tokens: {text_positions} ({text_positions/total_positions*100:.1f}%)")
|
| 369 |
+
print(f" VIT image tokens: {vit_positions} ({vit_positions/total_positions*100:.1f}%)")
|
| 370 |
+
print(f" VAE image tokens: {vae_positions} ({vae_positions/total_positions*100:.1f}%)")
|
| 371 |
+
print(f" Positions with CE loss: {ce_positions} ({ce_positions/total_positions*100:.1f}%)")
|
| 372 |
+
print(f" Positions with MSE loss: {mse_positions} ({mse_positions/total_positions*100:.1f}%)")
|
| 373 |
+
print(f" Positions with no loss: {no_loss_positions} ({no_loss_positions/total_positions*100:.1f}%)")
|
| 374 |
+
|
| 375 |
+
print("="*120 + "\n")
|
| 376 |
+
|
| 377 |
+
def __iter__(self):
|
| 378 |
+
total_weights = sum(self.grouped_weights)
|
| 379 |
+
assert total_weights > 0.0
|
| 380 |
+
group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
|
| 381 |
+
for i in range(len(self.grouped_weights))]
|
| 382 |
+
sequence_status = self.set_sequence_status()
|
| 383 |
+
batch_data_indexes = []
|
| 384 |
+
|
| 385 |
+
buffer = []
|
| 386 |
+
while True:
|
| 387 |
+
# Ensure at least one sample from each group
|
| 388 |
+
if sequence_status['curr'] == 0:
|
| 389 |
+
for group_index, group_iter in enumerate(self.dataset_iters):
|
| 390 |
+
if self.is_mandatory[group_index]:
|
| 391 |
+
while True:
|
| 392 |
+
sample = next(group_iter)
|
| 393 |
+
# if a sample is too long, skip it
|
| 394 |
+
num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
|
| 395 |
+
if num_tokens < self.max_num_tokens_per_sample:
|
| 396 |
+
sequence_status = self.pack_sequence(sample, sequence_status)
|
| 397 |
+
batch_data_indexes.append(sample['data_indexes'])
|
| 398 |
+
break
|
| 399 |
+
else:
|
| 400 |
+
print(f"skip a sample with length {num_tokens}")
|
| 401 |
+
continue
|
| 402 |
+
|
| 403 |
+
if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
|
| 404 |
+
sample = buffer.pop(0)
|
| 405 |
+
sample_from_buffer = True
|
| 406 |
+
else:
|
| 407 |
+
# sample normally across all groups
|
| 408 |
+
n = random.random()
|
| 409 |
+
group_index = 0
|
| 410 |
+
for i, cumprob in enumerate(group_cumprobs):
|
| 411 |
+
if n < cumprob:
|
| 412 |
+
group_index = i
|
| 413 |
+
break
|
| 414 |
+
sample = next(self.dataset_iters[group_index])
|
| 415 |
+
sample_from_buffer = False
|
| 416 |
+
|
| 417 |
+
# if a sample is too long, skip it
|
| 418 |
+
num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
|
| 419 |
+
if num_tokens > self.max_num_tokens_per_sample:
|
| 420 |
+
print(f"skip a sample with length {num_tokens}")
|
| 421 |
+
continue
|
| 422 |
+
|
| 423 |
+
if sequence_status['curr'] + num_tokens > self.max_num_tokens:
|
| 424 |
+
if len(buffer) < self.max_buffer_size and not sample_from_buffer:
|
| 425 |
+
buffer.append(sample)
|
| 426 |
+
else:
|
| 427 |
+
print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
|
| 428 |
+
data = self.to_tensor(sequence_status)
|
| 429 |
+
data['batch_data_indexes'] = batch_data_indexes
|
| 430 |
+
yield data
|
| 431 |
+
sequence_status = self.set_sequence_status()
|
| 432 |
+
batch_data_indexes = []
|
| 433 |
+
continue
|
| 434 |
+
|
| 435 |
+
sequence_status = self.pack_sequence(sample, sequence_status)
|
| 436 |
+
batch_data_indexes.append(sample['data_indexes'])
|
| 437 |
+
|
| 438 |
+
if sequence_status['curr'] >= self.expected_num_tokens:
|
| 439 |
+
data = self.to_tensor(sequence_status)
|
| 440 |
+
data['batch_data_indexes'] = batch_data_indexes
|
| 441 |
+
yield data
|
| 442 |
+
sequence_status = self.set_sequence_status()
|
| 443 |
+
batch_data_indexes = []
|
| 444 |
+
|
| 445 |
+
def pack_sequence(self, sample, sequence_status):
|
| 446 |
+
image_tensor_list = sample['image_tensor_list']
|
| 447 |
+
text_ids_list = sample['text_ids_list']
|
| 448 |
+
sequence_plan = sample['sequence_plan']
|
| 449 |
+
|
| 450 |
+
split_lens, attn_modes = list(), list()
|
| 451 |
+
curr = sequence_status['curr']
|
| 452 |
+
curr_rope_id = 0
|
| 453 |
+
sample_lens = 0
|
| 454 |
+
|
| 455 |
+
for item in sequence_plan:
|
| 456 |
+
split_start = item.get('split_start', True)
|
| 457 |
+
if split_start:
|
| 458 |
+
curr_split_len = 0
|
| 459 |
+
|
| 460 |
+
if item['type'] == 'text':
|
| 461 |
+
text_ids = text_ids_list.pop(0)
|
| 462 |
+
if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
|
| 463 |
+
continue
|
| 464 |
+
|
| 465 |
+
shifted_text_ids = [self.bos_token_id] + text_ids
|
| 466 |
+
sequence_status['packed_text_ids'].extend(shifted_text_ids)
|
| 467 |
+
sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
|
| 468 |
+
if item['loss'] == 1:
|
| 469 |
+
sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
|
| 470 |
+
sequence_status['ce_loss_weights'].extend(
|
| 471 |
+
[len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
|
| 472 |
+
)
|
| 473 |
+
sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
|
| 474 |
+
curr += len(shifted_text_ids)
|
| 475 |
+
curr_split_len += len(shifted_text_ids)
|
| 476 |
+
|
| 477 |
+
# add a <|im_end|> token
|
| 478 |
+
sequence_status['packed_text_ids'].append(self.eos_token_id)
|
| 479 |
+
sequence_status['packed_text_indexes'].append(curr)
|
| 480 |
+
if item['special_token_loss'] == 1: # <|im_end|> may have loss
|
| 481 |
+
sequence_status['ce_loss_indexes'].append(curr)
|
| 482 |
+
sequence_status['ce_loss_weights'].append(1.0)
|
| 483 |
+
sequence_status['packed_label_ids'].append(item['special_token_label'])
|
| 484 |
+
curr += 1
|
| 485 |
+
curr_split_len += 1
|
| 486 |
+
|
| 487 |
+
# update sequence status
|
| 488 |
+
attn_modes.append("causal")
|
| 489 |
+
sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
|
| 490 |
+
curr_rope_id += curr_split_len
|
| 491 |
+
|
| 492 |
+
elif item['type'] == 'vit_image':
|
| 493 |
+
image_tensor = image_tensor_list.pop(0)
|
| 494 |
+
if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
|
| 495 |
+
curr_rope_id += 1
|
| 496 |
+
continue
|
| 497 |
+
|
| 498 |
+
# add a <|startofimage|> token
|
| 499 |
+
sequence_status['packed_text_ids'].append(self.start_of_image)
|
| 500 |
+
sequence_status['packed_text_indexes'].append(curr)
|
| 501 |
+
curr += 1
|
| 502 |
+
curr_split_len += 1
|
| 503 |
+
|
| 504 |
+
# preprocess image
|
| 505 |
+
vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
|
| 506 |
+
num_img_tokens = vit_tokens.shape[0]
|
| 507 |
+
sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
|
| 508 |
+
curr += num_img_tokens
|
| 509 |
+
curr_split_len += num_img_tokens
|
| 510 |
+
|
| 511 |
+
sequence_status['packed_vit_tokens'].append(vit_tokens)
|
| 512 |
+
sequence_status['vit_token_seqlens'].append(num_img_tokens)
|
| 513 |
+
sequence_status['packed_vit_position_ids'].append(
|
| 514 |
+
self.get_flattened_position_ids(
|
| 515 |
+
image_tensor.size(1), image_tensor.size(2),
|
| 516 |
+
self.data_config.vit_patch_size,
|
| 517 |
+
max_num_patches_per_side=self.data_config.max_num_patch_per_side
|
| 518 |
+
)
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# add a <|endofimage|> token
|
| 522 |
+
sequence_status['packed_text_ids'].append(self.end_of_image)
|
| 523 |
+
sequence_status['packed_text_indexes'].append(curr)
|
| 524 |
+
if item['special_token_loss'] == 1: # <|endofimage|> may have loss
|
| 525 |
+
sequence_status['ce_loss_indexes'].append(curr)
|
| 526 |
+
sequence_status['ce_loss_weights'].append(1.0)
|
| 527 |
+
sequence_status['packed_label_ids'].append(item['special_token_label'])
|
| 528 |
+
curr += 1
|
| 529 |
+
curr_split_len += 1
|
| 530 |
+
|
| 531 |
+
# update sequence status
|
| 532 |
+
attn_modes.append("full")
|
| 533 |
+
sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
|
| 534 |
+
curr_rope_id += 1
|
| 535 |
+
|
| 536 |
+
elif item['type'] == 'vae_image':
|
| 537 |
+
image_tensor = image_tensor_list.pop(0)
|
| 538 |
+
if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
|
| 539 |
+
# FIXME fix vae dropout in video2video setting.
|
| 540 |
+
curr_rope_id += 1
|
| 541 |
+
continue
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# add a <|startofimage|> token
|
| 547 |
+
sequence_status['packed_text_ids'].append(self.start_of_image)
|
| 548 |
+
sequence_status['packed_text_indexes'].append(curr)
|
| 549 |
+
|
| 550 |
+
if item['special_token_loss'] == 1:
|
| 551 |
+
sequence_status['ce_loss_indexes'].append(curr)
|
| 552 |
+
sequence_status['ce_loss_weights'].append(1.0)
|
| 553 |
+
sequence_status['packed_label_ids'].append(item['special_token_label'])
|
| 554 |
+
|
| 555 |
+
curr += 1
|
| 556 |
+
curr_split_len += 1
|
| 557 |
+
|
| 558 |
+
# preprocess image
|
| 559 |
+
sequence_status['vae_image_tensors'].append(image_tensor)
|
| 560 |
+
sequence_status['packed_latent_position_ids'].append(
|
| 561 |
+
self.get_flattened_position_ids(
|
| 562 |
+
image_tensor.size(1), image_tensor.size(2),
|
| 563 |
+
self.data_config.vae_image_downsample,
|
| 564 |
+
max_num_patches_per_side=self.data_config.max_latent_size
|
| 565 |
+
)
|
| 566 |
+
)
|
| 567 |
+
H, W = image_tensor.shape[1:]
|
| 568 |
+
h = H // self.data_config.vae_image_downsample
|
| 569 |
+
w = W // self.data_config.vae_image_downsample
|
| 570 |
+
sequence_status['vae_latent_shapes'].append((h, w))
|
| 571 |
+
|
| 572 |
+
num_img_tokens = w * h
|
| 573 |
+
sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
|
| 574 |
+
if item['loss'] == 1:
|
| 575 |
+
sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
|
| 576 |
+
if split_start:
|
| 577 |
+
timestep = np.random.randn()
|
| 578 |
+
else:
|
| 579 |
+
timestep = float('-inf')
|
| 580 |
+
|
| 581 |
+
sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
|
| 582 |
+
curr += num_img_tokens
|
| 583 |
+
curr_split_len += num_img_tokens
|
| 584 |
+
|
| 585 |
+
# add a <|endofimage|> token
|
| 586 |
+
sequence_status['packed_text_ids'].append(self.end_of_image)
|
| 587 |
+
sequence_status['packed_text_indexes'].append(curr)
|
| 588 |
+
# <|endofimage|> may have loss
|
| 589 |
+
if item['special_token_loss'] == 1:
|
| 590 |
+
sequence_status['ce_loss_indexes'].append(curr)
|
| 591 |
+
sequence_status['ce_loss_weights'].append(1.0)
|
| 592 |
+
sequence_status['packed_label_ids'].append(item['special_token_label'])
|
| 593 |
+
curr += 1
|
| 594 |
+
curr_split_len += 1
|
| 595 |
+
|
| 596 |
+
# update sequence status
|
| 597 |
+
if split_start:
|
| 598 |
+
if item['loss'] == 1 and 'frame_delta' not in item.keys():
|
| 599 |
+
attn_modes.append("noise")
|
| 600 |
+
else:
|
| 601 |
+
attn_modes.append("full")
|
| 602 |
+
sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
|
| 603 |
+
if 'frame_delta' in item.keys():
|
| 604 |
+
curr_rope_id += item['frame_delta']
|
| 605 |
+
elif item['loss'] == 0:
|
| 606 |
+
curr_rope_id += 1
|
| 607 |
+
|
| 608 |
+
if item.get('split_end', True):
|
| 609 |
+
split_lens.append(curr_split_len)
|
| 610 |
+
sample_lens += curr_split_len
|
| 611 |
+
|
| 612 |
+
sequence_status['curr'] = curr
|
| 613 |
+
sequence_status['sample_lens'].append(sample_lens)
|
| 614 |
+
# prepare attention mask
|
| 615 |
+
if not self.use_flex:
|
| 616 |
+
sequence_status['nested_attention_masks'].append(
|
| 617 |
+
prepare_attention_mask_per_sample(split_lens, attn_modes)
|
| 618 |
+
)
|
| 619 |
+
else:
|
| 620 |
+
sequence_status['split_lens'].extend(split_lens)
|
| 621 |
+
sequence_status['attn_modes'].extend(attn_modes)
|
| 622 |
+
|
| 623 |
+
return sequence_status
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class SimpleCustomBatch:
|
| 627 |
+
def __init__(self, batch):
|
| 628 |
+
data = batch[0]
|
| 629 |
+
self.batch_data_indexes = data['batch_data_indexes']
|
| 630 |
+
self.sequence_length = data["sequence_length"]
|
| 631 |
+
self.sample_lens = data["sample_lens"]
|
| 632 |
+
self.packed_text_ids = data["packed_text_ids"]
|
| 633 |
+
self.packed_text_indexes = data["packed_text_indexes"]
|
| 634 |
+
self.packed_position_ids = data["packed_position_ids"]
|
| 635 |
+
|
| 636 |
+
self.use_flex = "nested_attention_masks" not in data.keys()
|
| 637 |
+
|
| 638 |
+
if self.use_flex:
|
| 639 |
+
self.split_lens = data["split_lens"]
|
| 640 |
+
self.attn_modes = data["attn_modes"]
|
| 641 |
+
else:
|
| 642 |
+
self.nested_attention_masks = data["nested_attention_masks"]
|
| 643 |
+
|
| 644 |
+
if "padded_images" in data.keys():
|
| 645 |
+
self.padded_images = data["padded_images"]
|
| 646 |
+
self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
|
| 647 |
+
self.packed_latent_position_ids = data["packed_latent_position_ids"]
|
| 648 |
+
self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
|
| 649 |
+
|
| 650 |
+
if "packed_vit_tokens" in data.keys():
|
| 651 |
+
self.packed_vit_tokens = data["packed_vit_tokens"]
|
| 652 |
+
self.packed_vit_position_ids = data["packed_vit_position_ids"]
|
| 653 |
+
self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
|
| 654 |
+
self.vit_token_seqlens = data["vit_token_seqlens"]
|
| 655 |
+
|
| 656 |
+
if "packed_timesteps" in data.keys():
|
| 657 |
+
self.packed_timesteps = data["packed_timesteps"]
|
| 658 |
+
self.mse_loss_indexes = data["mse_loss_indexes"]
|
| 659 |
+
|
| 660 |
+
if "packed_label_ids" in data.keys():
|
| 661 |
+
self.packed_label_ids = data["packed_label_ids"]
|
| 662 |
+
self.ce_loss_indexes = data["ce_loss_indexes"]
|
| 663 |
+
self.ce_loss_weights = data["ce_loss_weights"]
|
| 664 |
+
|
| 665 |
+
def pin_memory(self):
|
| 666 |
+
self.packed_text_ids = self.packed_text_ids.pin_memory()
|
| 667 |
+
self.packed_text_indexes = self.packed_text_indexes.pin_memory()
|
| 668 |
+
self.packed_position_ids = self.packed_position_ids.pin_memory()
|
| 669 |
+
|
| 670 |
+
if not self.use_flex:
|
| 671 |
+
self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
|
| 672 |
+
|
| 673 |
+
if hasattr(self, 'padded_images'):
|
| 674 |
+
self.padded_images = self.padded_images.pin_memory()
|
| 675 |
+
self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
|
| 676 |
+
self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
|
| 677 |
+
|
| 678 |
+
if hasattr(self, 'packed_timesteps'):
|
| 679 |
+
self.packed_timesteps = self.packed_timesteps.pin_memory()
|
| 680 |
+
self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
|
| 681 |
+
|
| 682 |
+
if hasattr(self, 'packed_vit_tokens'):
|
| 683 |
+
self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
|
| 684 |
+
self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
|
| 685 |
+
self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
|
| 686 |
+
self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
|
| 687 |
+
|
| 688 |
+
if hasattr(self, 'packed_label_ids'):
|
| 689 |
+
self.packed_label_ids = self.packed_label_ids.pin_memory()
|
| 690 |
+
self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
|
| 691 |
+
self.ce_loss_weights = self.ce_loss_weights.pin_memory()
|
| 692 |
+
|
| 693 |
+
return self
|
| 694 |
+
|
| 695 |
+
def cuda(self, device):
|
| 696 |
+
self.packed_text_ids = self.packed_text_ids.to(device)
|
| 697 |
+
self.packed_text_indexes = self.packed_text_indexes.to(device)
|
| 698 |
+
self.packed_position_ids = self.packed_position_ids.to(device)
|
| 699 |
+
|
| 700 |
+
if not self.use_flex:
|
| 701 |
+
self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
|
| 702 |
+
|
| 703 |
+
if hasattr(self, 'padded_images'):
|
| 704 |
+
self.padded_images = self.padded_images.to(device)
|
| 705 |
+
self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
|
| 706 |
+
self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
|
| 707 |
+
|
| 708 |
+
if hasattr(self, 'packed_timesteps'):
|
| 709 |
+
self.packed_timesteps = self.packed_timesteps.to(device)
|
| 710 |
+
self.mse_loss_indexes = self.mse_loss_indexes.to(device)
|
| 711 |
+
|
| 712 |
+
if hasattr(self, 'packed_vit_tokens'):
|
| 713 |
+
self.packed_vit_tokens = self.packed_vit_tokens.to(device)
|
| 714 |
+
self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
|
| 715 |
+
self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
|
| 716 |
+
self.vit_token_seqlens = self.vit_token_seqlens.to(device)
|
| 717 |
+
|
| 718 |
+
if hasattr(self, 'packed_label_ids'):
|
| 719 |
+
self.packed_label_ids = self.packed_label_ids.to(device)
|
| 720 |
+
self.ce_loss_indexes = self.ce_loss_indexes.to(device)
|
| 721 |
+
self.ce_loss_weights = self.ce_loss_weights.to(device)
|
| 722 |
+
|
| 723 |
+
return self
|
| 724 |
+
|
| 725 |
+
def to_dict(self):
|
| 726 |
+
data = dict(
|
| 727 |
+
sequence_length = self.sequence_length,
|
| 728 |
+
sample_lens = self.sample_lens,
|
| 729 |
+
packed_text_ids = self.packed_text_ids,
|
| 730 |
+
packed_text_indexes = self.packed_text_indexes,
|
| 731 |
+
packed_position_ids = self.packed_position_ids,
|
| 732 |
+
batch_data_indexes = self.batch_data_indexes,
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
if not self.use_flex:
|
| 736 |
+
data['nested_attention_masks'] = self.nested_attention_masks
|
| 737 |
+
else:
|
| 738 |
+
data['split_lens'] = self.split_lens
|
| 739 |
+
data['attn_modes'] = self.attn_modes
|
| 740 |
+
|
| 741 |
+
if hasattr(self, 'padded_images'):
|
| 742 |
+
data['padded_images'] = self.padded_images
|
| 743 |
+
data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
|
| 744 |
+
data['packed_latent_position_ids'] = self.packed_latent_position_ids
|
| 745 |
+
data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
|
| 746 |
+
|
| 747 |
+
if hasattr(self, 'packed_vit_tokens'):
|
| 748 |
+
data['packed_vit_tokens'] = self.packed_vit_tokens
|
| 749 |
+
data['packed_vit_position_ids'] = self.packed_vit_position_ids
|
| 750 |
+
data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
|
| 751 |
+
data['vit_token_seqlens'] = self.vit_token_seqlens
|
| 752 |
+
|
| 753 |
+
if hasattr(self, 'packed_timesteps'):
|
| 754 |
+
data['packed_timesteps'] = self.packed_timesteps
|
| 755 |
+
data['mse_loss_indexes'] = self.mse_loss_indexes
|
| 756 |
+
|
| 757 |
+
if hasattr(self, 'packed_label_ids'):
|
| 758 |
+
data['packed_label_ids'] = self.packed_label_ids
|
| 759 |
+
data['ce_loss_indexes'] = self.ce_loss_indexes
|
| 760 |
+
data['ce_loss_weights'] = self.ce_loss_weights
|
| 761 |
+
|
| 762 |
+
return data
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def collate_wrapper():
|
| 766 |
+
def collate_fn(batch):
|
| 767 |
+
return SimpleCustomBatch(batch)
|
| 768 |
+
return collate_fn
|
data/dataset_info.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from .interleave_datasets import UnifiedEditIterableDataset
|
| 5 |
+
from .t2i_dataset import T2IIterableDataset
|
| 6 |
+
from .vlm_dataset import SftJSONLIterableDataset
|
| 7 |
+
from .interleave_datasets.think_trace_dataset import ThinkTraceJSONLIterableDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DATASET_REGISTRY = {
|
| 11 |
+
't2i_pretrain': T2IIterableDataset,
|
| 12 |
+
'vlm_sft': SftJSONLIterableDataset,
|
| 13 |
+
'unified_edit': UnifiedEditIterableDataset,
|
| 14 |
+
'think_trace': ThinkTraceJSONLIterableDataset,
|
| 15 |
+
'block_dataset': ThinkTraceJSONLIterableDataset,
|
| 16 |
+
'block_dataset_random': ThinkTraceJSONLIterableDataset,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
DATASET_INFO = {
|
| 21 |
+
'think_trace': {
|
| 22 |
+
'think_trace_dataset': {
|
| 23 |
+
'data_dir': '/scratch/by2593/project/SpaCU/interleaved-co3dv2/data',
|
| 24 |
+
'jsonl_path': '/scratch/by2593/project/SpaCU/interleaved-co3dv2/data/merged_train.jsonl',
|
| 25 |
+
'image_prefix_dir': '/scratch/by2593/project/SpaCU/restored_data2', # Base path for relative image paths
|
| 26 |
+
# 'num_total_samples': 100,
|
| 27 |
+
},
|
| 28 |
+
},
|
| 29 |
+
'block_dataset': {
|
| 30 |
+
'block_dataset': {
|
| 31 |
+
'data_dir': "/scratch/by2593/project/SMM/semantic_blocks_part1",
|
| 32 |
+
# 'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1_v2_reordered.jsonl',
|
| 33 |
+
'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl',
|
| 34 |
+
'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', # Base path for relative image paths
|
| 35 |
+
# 'num_total_samples': 100,
|
| 36 |
+
},
|
| 37 |
+
},
|
| 38 |
+
'block_dataset_random': {
|
| 39 |
+
'block_dataset_random': {
|
| 40 |
+
'data_dir': "/scratch/by2593/project/SMM/random_pipeline/random_blocks",
|
| 41 |
+
'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/random_block.jsonl',
|
| 42 |
+
'image_prefix_dir': '/scratch/by2593/project/SMM/random_pipeline/random_blocks', # Base path for relative image paths
|
| 43 |
+
# 'num_total_samples': 100,
|
| 44 |
+
},
|
| 45 |
+
},
|
| 46 |
+
}
|
data/distributed_iterable_dataset.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DistributedIterableDataset(torch.utils.data.IterableDataset):
|
| 9 |
+
def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
|
| 10 |
+
self.dataset_name = dataset_name
|
| 11 |
+
self.local_rank = local_rank
|
| 12 |
+
self.world_size = world_size
|
| 13 |
+
self.num_workers = num_workers
|
| 14 |
+
self.rng = random.Random()
|
| 15 |
+
self.data_paths = None
|
| 16 |
+
|
| 17 |
+
def get_data_paths(self, *args, **kwargs):
|
| 18 |
+
raise NotImplementedError
|
| 19 |
+
|
| 20 |
+
def set_epoch(self, seed=42):
|
| 21 |
+
if self.data_paths is None:
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
if isinstance(self.data_paths[0], tuple):
|
| 25 |
+
data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
|
| 26 |
+
elif isinstance(self.data_paths[0], str):
|
| 27 |
+
data_paths = sorted(self.data_paths)
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
|
| 30 |
+
|
| 31 |
+
self.rng.seed(seed)
|
| 32 |
+
self.rng.shuffle(data_paths)
|
| 33 |
+
|
| 34 |
+
num_files_per_rank = len(data_paths) // self.world_size
|
| 35 |
+
local_start = self.local_rank * num_files_per_rank
|
| 36 |
+
local_end = (self.local_rank + 1) * num_files_per_rank
|
| 37 |
+
self.num_files_per_rank = num_files_per_rank
|
| 38 |
+
self.data_paths_per_rank = data_paths[local_start:local_end]
|
| 39 |
+
|
| 40 |
+
def get_data_paths_per_worker(self):
|
| 41 |
+
if self.data_paths is None:
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
info = torch.utils.data.get_worker_info()
|
| 45 |
+
if info is None:
|
| 46 |
+
# Single worker: Use all files assigned to the rank
|
| 47 |
+
return self.data_paths_per_rank, 0
|
| 48 |
+
|
| 49 |
+
worker_id = info.id
|
| 50 |
+
num_files_per_worker = self.num_files_per_rank // info.num_workers
|
| 51 |
+
start = num_files_per_worker * worker_id
|
| 52 |
+
end = num_files_per_worker * (worker_id + 1)
|
| 53 |
+
data_paths_per_worker = self.data_paths_per_rank[start:end]
|
| 54 |
+
|
| 55 |
+
return data_paths_per_worker[::-1], worker_id
|
| 56 |
+
|
| 57 |
+
def __iter__(self):
|
| 58 |
+
raise NotImplementedError
|
data/interleave_datasets/edit_dataset.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import io
|
| 5 |
+
import random
|
| 6 |
+
from PIL import Image, ImageFile, PngImagePlugin
|
| 7 |
+
|
| 8 |
+
from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
|
| 9 |
+
from ..data_utils import pil_img2rgb
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Image.MAX_IMAGE_PIXELS = 200000000
|
| 13 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 14 |
+
MaximumDecompressedSize = 1024
|
| 15 |
+
MegaByte = 2 ** 20
|
| 16 |
+
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
|
| 20 |
+
|
| 21 |
+
def parse_row(self, row):
|
| 22 |
+
image_num = len(row["image_list"])
|
| 23 |
+
# randomly choose start and end, return [0, 1] when only two images
|
| 24 |
+
start_idx = random.choice(range(image_num - 1))
|
| 25 |
+
max_end = min(start_idx + 3, image_num)
|
| 26 |
+
end_idx = random.choice(range(start_idx + 1, max_end))
|
| 27 |
+
|
| 28 |
+
data = self._init_data()
|
| 29 |
+
data = self._add_image(
|
| 30 |
+
data,
|
| 31 |
+
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
|
| 32 |
+
need_loss=False,
|
| 33 |
+
need_vae=True,
|
| 34 |
+
need_vit=True,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if end_idx - start_idx > 1 and random.random() < 0.5: # concat multiple insturction
|
| 38 |
+
if end_idx == image_num - 1:
|
| 39 |
+
end_idx -= 1
|
| 40 |
+
|
| 41 |
+
instruction = ""
|
| 42 |
+
for idx in range(start_idx + 1, end_idx + 1):
|
| 43 |
+
instruction += random.choice(row["instruction_list"][idx-1]) + ". "
|
| 44 |
+
data = self._add_text(data, instruction.rstrip(), need_loss=False)
|
| 45 |
+
data = self._add_image(
|
| 46 |
+
data,
|
| 47 |
+
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
|
| 48 |
+
need_loss=True,
|
| 49 |
+
need_vae=False,
|
| 50 |
+
need_vit=False,
|
| 51 |
+
)
|
| 52 |
+
else:
|
| 53 |
+
for idx in range(start_idx + 1, end_idx + 1):
|
| 54 |
+
instruction = random.choice(row["instruction_list"][idx-1])
|
| 55 |
+
data = self._add_text(data, instruction, need_loss=False)
|
| 56 |
+
if idx != end_idx:
|
| 57 |
+
data = self._add_image(
|
| 58 |
+
data,
|
| 59 |
+
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
|
| 60 |
+
need_loss=True,
|
| 61 |
+
need_vae=True,
|
| 62 |
+
need_vit=True,
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
data = self._add_image(
|
| 66 |
+
data,
|
| 67 |
+
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
|
| 68 |
+
need_loss=True,
|
| 69 |
+
need_vae=False,
|
| 70 |
+
need_vit=False,
|
| 71 |
+
)
|
| 72 |
+
return data
|
data/interleave_datasets/interleave_t2i_dataset.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import pyarrow.parquet as pq
|
| 5 |
+
|
| 6 |
+
from ..distributed_iterable_dataset import DistributedIterableDataset
|
| 7 |
+
from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class InterleavedBaseIterableDataset(DistributedIterableDataset):
|
| 11 |
+
|
| 12 |
+
def _init_data(self):
|
| 13 |
+
data = {
|
| 14 |
+
'sequence_plan': [],
|
| 15 |
+
'text_ids_list': [],
|
| 16 |
+
'image_tensor_list': [],
|
| 17 |
+
'num_tokens': 0,
|
| 18 |
+
}
|
| 19 |
+
return data
|
| 20 |
+
|
| 21 |
+
def _add_text(self, data, text, need_loss, enable_cfg=True, next_token_label=None):
|
| 22 |
+
text_ids = self.tokenizer.encode(text)
|
| 23 |
+
data['num_tokens'] += len(text_ids)
|
| 24 |
+
data['text_ids_list'].append(text_ids)
|
| 25 |
+
|
| 26 |
+
# If next_token_label is provided, the im_end token should predict it
|
| 27 |
+
special_token_loss = 1 if next_token_label is not None else 0
|
| 28 |
+
|
| 29 |
+
data['sequence_plan'].append(
|
| 30 |
+
{
|
| 31 |
+
'type': 'text',
|
| 32 |
+
'enable_cfg': int(enable_cfg),
|
| 33 |
+
'loss': int(need_loss),
|
| 34 |
+
'special_token_loss': special_token_loss,
|
| 35 |
+
'special_token_label': next_token_label,
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
return data
|
| 39 |
+
|
| 40 |
+
def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True, special_token_label=None):
|
| 41 |
+
assert need_loss or need_vae or need_vit
|
| 42 |
+
|
| 43 |
+
if need_loss:
|
| 44 |
+
# For loss images, don't add special_token_loss on the start token
|
| 45 |
+
# The previous text token should predict the vision_start token
|
| 46 |
+
data['sequence_plan'].append(
|
| 47 |
+
{
|
| 48 |
+
'type': 'vae_image',
|
| 49 |
+
'enable_cfg': 0,
|
| 50 |
+
'loss': 1,
|
| 51 |
+
'special_token_loss': 0, # No loss on start token itself
|
| 52 |
+
'special_token_label': None,
|
| 53 |
+
}
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
image_tensor = self.transform(image)
|
| 57 |
+
height, width = image_tensor.shape[1:]
|
| 58 |
+
data['num_tokens'] += width * height // self.transform.stride ** 2
|
| 59 |
+
data['image_tensor_list'].append(image_tensor)
|
| 60 |
+
|
| 61 |
+
if need_vae:
|
| 62 |
+
data['sequence_plan'].append(
|
| 63 |
+
{
|
| 64 |
+
'type': 'vae_image',
|
| 65 |
+
'enable_cfg': int(enable_cfg),
|
| 66 |
+
'loss': 0,
|
| 67 |
+
'special_token_loss': 0,
|
| 68 |
+
'special_token_label': None,
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
image_tensor = self.transform(image)
|
| 73 |
+
height, width = image_tensor.shape[1:]
|
| 74 |
+
data['num_tokens'] += width * height // self.transform.stride ** 2
|
| 75 |
+
data['image_tensor_list'].append(image_tensor.clone())
|
| 76 |
+
|
| 77 |
+
if need_vit:
|
| 78 |
+
data['sequence_plan'].append(
|
| 79 |
+
{
|
| 80 |
+
'type': 'vit_image',
|
| 81 |
+
'enable_cfg': int(enable_cfg),
|
| 82 |
+
'loss': 0,
|
| 83 |
+
'special_token_loss': 0,
|
| 84 |
+
'special_token_label': None,
|
| 85 |
+
},
|
| 86 |
+
)
|
| 87 |
+
vit_image_tensor = self.vit_transform(image)
|
| 88 |
+
height, width = vit_image_tensor.shape[1:]
|
| 89 |
+
data['num_tokens'] += width * height // self.vit_transform.stride ** 2
|
| 90 |
+
data['image_tensor_list'].append(vit_image_tensor)
|
| 91 |
+
|
| 92 |
+
return data
|
| 93 |
+
|
| 94 |
+
def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
|
| 95 |
+
assert int(need_loss) + int(need_vae) == 1
|
| 96 |
+
|
| 97 |
+
if need_loss:
|
| 98 |
+
for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
|
| 99 |
+
current_sequence_plan = {
|
| 100 |
+
'type': 'vae_image',
|
| 101 |
+
'enable_cfg': 0,
|
| 102 |
+
'loss': 1,
|
| 103 |
+
'special_token_loss': 0,
|
| 104 |
+
'special_token_label': None,
|
| 105 |
+
'split_start': idx == 0,
|
| 106 |
+
'split_end': idx == len(frames) - 1,
|
| 107 |
+
}
|
| 108 |
+
if idx < len(frame_indexes) - 1:
|
| 109 |
+
current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
|
| 110 |
+
data['sequence_plan'].append(current_sequence_plan)
|
| 111 |
+
image_tensor = self.transform(image)
|
| 112 |
+
height, width = image_tensor.shape[1:]
|
| 113 |
+
data['image_tensor_list'].append(image_tensor)
|
| 114 |
+
data['num_tokens'] += width * height // self.transform.stride ** 2
|
| 115 |
+
|
| 116 |
+
elif need_vae:
|
| 117 |
+
for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
|
| 118 |
+
current_sequence_plan = {
|
| 119 |
+
'type': 'vae_image',
|
| 120 |
+
'enable_cfg': int(enable_cfg),
|
| 121 |
+
'loss': 0,
|
| 122 |
+
'special_token_loss': 0,
|
| 123 |
+
'special_token_label': None,
|
| 124 |
+
'split_start': idx == 0,
|
| 125 |
+
'split_end': idx == len(frames) - 1,
|
| 126 |
+
}
|
| 127 |
+
if idx < len(frame_indexes) - 1:
|
| 128 |
+
current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
|
| 129 |
+
data['sequence_plan'].append(current_sequence_plan)
|
| 130 |
+
image_tensor = self.transform(image)
|
| 131 |
+
height, width = image_tensor.shape[1:]
|
| 132 |
+
data['image_tensor_list'].append(image_tensor)
|
| 133 |
+
data['num_tokens'] += width * height // self.transform.stride ** 2
|
| 134 |
+
|
| 135 |
+
return data
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ParquetStandardIterableDataset(DistributedIterableDataset):
|
| 139 |
+
|
| 140 |
+
def __init__(
|
| 141 |
+
self, dataset_name, transform, tokenizer, vit_transform,
|
| 142 |
+
data_dir_list, num_used_data, parquet_info,
|
| 143 |
+
local_rank=0, world_size=1, num_workers=8, data_status=None,
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
data_dir_list: list of data directories contains parquet files
|
| 147 |
+
num_used_data: list of number of sampled data paths for each data directory
|
| 148 |
+
vit_transform: input transform for vit model.
|
| 149 |
+
"""
|
| 150 |
+
super().__init__(dataset_name, local_rank, world_size, num_workers)
|
| 151 |
+
self.transform = transform
|
| 152 |
+
self.vit_transform = vit_transform
|
| 153 |
+
self.tokenizer = tokenizer
|
| 154 |
+
self.data_status = data_status
|
| 155 |
+
self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
|
| 156 |
+
self.set_epoch()
|
| 157 |
+
|
| 158 |
+
def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
|
| 159 |
+
row_groups = []
|
| 160 |
+
for data_dir, num_data_path in zip(data_dir_list, num_used_data):
|
| 161 |
+
data_paths = get_parquet_data_paths([data_dir], [num_data_path])
|
| 162 |
+
for data_path in data_paths:
|
| 163 |
+
if data_path in parquet_info.keys():
|
| 164 |
+
num_row_groups = parquet_info[data_path]['num_row_groups']
|
| 165 |
+
for rg_idx in range(num_row_groups):
|
| 166 |
+
row_groups.append((data_path, rg_idx))
|
| 167 |
+
return row_groups
|
| 168 |
+
|
| 169 |
+
def parse_row(self, row):
|
| 170 |
+
raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
def __iter__(self):
|
| 173 |
+
file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
|
| 174 |
+
if self.data_status is not None:
|
| 175 |
+
global_row_group_start_id = self.data_status[worker_id][0]
|
| 176 |
+
row_start_id = self.data_status[worker_id][1] + 1
|
| 177 |
+
else:
|
| 178 |
+
global_row_group_start_id = 0
|
| 179 |
+
row_start_id = 0
|
| 180 |
+
|
| 181 |
+
print(
|
| 182 |
+
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
|
| 183 |
+
f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
while True:
|
| 187 |
+
file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
|
| 188 |
+
for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
|
| 189 |
+
file_paths_per_worker_, start=global_row_group_start_id
|
| 190 |
+
):
|
| 191 |
+
fs = init_arrow_pf_fs(parquet_file_path)
|
| 192 |
+
with fs.open_input_file(parquet_file_path) as f:
|
| 193 |
+
try:
|
| 194 |
+
fr = pq.ParquetFile(f)
|
| 195 |
+
df = fr.read_row_group(row_group_id).to_pandas()
|
| 196 |
+
df = df.iloc[row_start_id:]
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
for row_idx, row in df.iterrows():
|
| 202 |
+
try:
|
| 203 |
+
data = self.parse_row(row)
|
| 204 |
+
if len(data) == 0:
|
| 205 |
+
continue
|
| 206 |
+
data['data_indexes'] = {
|
| 207 |
+
"data_indexes": [global_row_group_idx, row_idx],
|
| 208 |
+
"worker_id": worker_id,
|
| 209 |
+
"dataset_name": self.dataset_name,
|
| 210 |
+
}
|
| 211 |
+
except Exception as e:
|
| 212 |
+
print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
|
| 213 |
+
continue
|
| 214 |
+
yield data
|
| 215 |
+
|
| 216 |
+
row_start_id = 0
|
| 217 |
+
global_row_group_start_id = 0
|
| 218 |
+
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
|
data/interleave_datasets/think_trace_dataset.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import traceback
|
| 5 |
+
from PIL import Image, ImageFile, PngImagePlugin
|
| 6 |
+
|
| 7 |
+
from .interleave_t2i_dataset import InterleavedBaseIterableDataset
|
| 8 |
+
from ..data_utils import pil_img2rgb
|
| 9 |
+
from ..distributed_iterable_dataset import DistributedIterableDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Image.MAX_IMAGE_PIXELS = 200000000
|
| 13 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 14 |
+
MaximumDecompressedSize = 1024
|
| 15 |
+
MegaByte = 2 ** 20
|
| 16 |
+
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ThinkTraceJSONLIterableDataset(InterleavedBaseIterableDataset, DistributedIterableDataset):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
dataset_name,
|
| 23 |
+
transform,
|
| 24 |
+
tokenizer,
|
| 25 |
+
vit_transform,
|
| 26 |
+
jsonl_path_list,
|
| 27 |
+
data_dir_list,
|
| 28 |
+
num_used_data,
|
| 29 |
+
local_rank=0,
|
| 30 |
+
world_size=1,
|
| 31 |
+
num_workers=8,
|
| 32 |
+
data_status=None,
|
| 33 |
+
shuffle_lines=True,
|
| 34 |
+
shuffle_seed=0,
|
| 35 |
+
image_prefix_dir=None,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Dataset for think-trace style JSONL files with interleaved text and images.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
dataset_name: Name of the dataset
|
| 42 |
+
transform: Transform for VAE images
|
| 43 |
+
tokenizer: Text tokenizer
|
| 44 |
+
vit_transform: Transform for VIT images
|
| 45 |
+
jsonl_path_list: List of JSONL file paths
|
| 46 |
+
data_dir_list: List of base directories (should match jsonl_path_list)
|
| 47 |
+
num_used_data: List of number of samples to use from each JSONL. If a value is None or non-positive, all data from that JSONL will be used.
|
| 48 |
+
image_prefix_dir: Absolute path to prepend to relative image paths
|
| 49 |
+
Other args: Standard distributed dataset args
|
| 50 |
+
"""
|
| 51 |
+
DistributedIterableDataset.__init__(self, dataset_name, local_rank, world_size, num_workers)
|
| 52 |
+
self.transform = transform
|
| 53 |
+
self.vit_transform = vit_transform
|
| 54 |
+
self.tokenizer = tokenizer
|
| 55 |
+
self.data_status = data_status
|
| 56 |
+
self.image_prefix_dir = image_prefix_dir or ""
|
| 57 |
+
|
| 58 |
+
self.start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
|
| 59 |
+
self.end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
|
| 60 |
+
self.im_start = tokenizer.convert_tokens_to_ids('<|im_start|>')
|
| 61 |
+
|
| 62 |
+
self.data_paths = self.get_data_paths(
|
| 63 |
+
jsonl_path_list,
|
| 64 |
+
num_used_data,
|
| 65 |
+
shuffle_lines,
|
| 66 |
+
shuffle_seed,
|
| 67 |
+
)
|
| 68 |
+
self.set_epoch()
|
| 69 |
+
|
| 70 |
+
def get_data_paths(self, jsonl_path_list, num_used_data, shuffle_lines, shuffle_seed):
|
| 71 |
+
data_paths = []
|
| 72 |
+
if not isinstance(num_used_data, list):
|
| 73 |
+
num_used_data = [num_used_data] * len(jsonl_path_list)
|
| 74 |
+
|
| 75 |
+
for jsonl_path, num_data_point in zip(jsonl_path_list, num_used_data):
|
| 76 |
+
with open(jsonl_path, 'r') as f:
|
| 77 |
+
raw_data = f.readlines()
|
| 78 |
+
if shuffle_lines:
|
| 79 |
+
self.rng.seed(shuffle_seed)
|
| 80 |
+
self.rng.shuffle(raw_data)
|
| 81 |
+
|
| 82 |
+
# Convert 'None' string to None type
|
| 83 |
+
if num_data_point == 'None':
|
| 84 |
+
num_data_point = None
|
| 85 |
+
|
| 86 |
+
if num_data_point is not None and int(num_data_point) > 0:
|
| 87 |
+
raw_data = raw_data[:int(num_data_point)]
|
| 88 |
+
|
| 89 |
+
data_paths.extend(raw_data)
|
| 90 |
+
return data_paths
|
| 91 |
+
|
| 92 |
+
def extract_image_references(self, text):
|
| 93 |
+
"""Extract image references from text like <image_start>[problem_image_1]<image_end>"""
|
| 94 |
+
pattern = r'<image_start>\[([^\]]+)\]<image_end>'
|
| 95 |
+
matches = re.findall(pattern, text)
|
| 96 |
+
return matches
|
| 97 |
+
|
| 98 |
+
def replace_image_references(self, text):
|
| 99 |
+
"""Replace image references with placeholder tokens for processing"""
|
| 100 |
+
pattern = r'<image_start>\[([^\]]+)\]<image_end>'
|
| 101 |
+
# Replace with a special placeholder that we'll process later
|
| 102 |
+
return re.sub(pattern, '<IMAGE_PLACEHOLDER>', text)
|
| 103 |
+
|
| 104 |
+
def remove_thought_patterns(self, text):
|
| 105 |
+
"""Remove THOUGHT x: patterns from text"""
|
| 106 |
+
# Remove patterns like "THOUGHT 1:", "THOUGHT 2:", etc.
|
| 107 |
+
pattern = r'THOUGHT\s*\d+:\s*'
|
| 108 |
+
return re.sub(pattern, '', text)
|
| 109 |
+
|
| 110 |
+
def load_image_safely(self, data_item, image_key):
|
| 111 |
+
"""Load image with null checking and path resolution"""
|
| 112 |
+
if image_key not in data_item or data_item[image_key] is None:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
image_path = data_item[image_key]
|
| 116 |
+
full_path = os.path.join(self.image_prefix_dir, image_path)
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
return pil_img2rgb(Image.open(full_path))
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"Failed to load image {full_path}: {e}")
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
def parse_row(self, json_line):
|
| 125 |
+
"""Parse a single JSON line into the required format"""
|
| 126 |
+
try:
|
| 127 |
+
data_item = json.loads(json_line.strip())
|
| 128 |
+
except:
|
| 129 |
+
traceback.print_exc()
|
| 130 |
+
return {}
|
| 131 |
+
|
| 132 |
+
# Extract the main fields
|
| 133 |
+
prompt = "You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and generate visual aids to enhance your problem-solving. You should first think about the reasoning and planning process in the mind before generating visual aids. Wrap your text reasoning with <think></think> tokens, and wrap your final conclusion with <answer></answer> tokens. Provide your final conclusion clearly in the format of '<answer>Final Answer: <answer here></answer>'"
|
| 134 |
+
question = data_item.get('Question', '')
|
| 135 |
+
question = f'Question: {question}'
|
| 136 |
+
reasoning_trace = data_item.get('Text Reasoning Trace', '')
|
| 137 |
+
reasoning_trace = f'{reasoning_trace}'
|
| 138 |
+
final_answer = data_item.get('Final Answer', '')
|
| 139 |
+
final_answer = f'<answer>Final Answer: {final_answer}</answer>'
|
| 140 |
+
|
| 141 |
+
if not question or not reasoning_trace or not final_answer:
|
| 142 |
+
return {}
|
| 143 |
+
|
| 144 |
+
# Build the sequence
|
| 145 |
+
data = self._init_data()
|
| 146 |
+
|
| 147 |
+
# 0. Add prompt
|
| 148 |
+
data = self._add_text(data, prompt, need_loss=False, enable_cfg=True)
|
| 149 |
+
|
| 150 |
+
# 1. Add question (with image parsing)
|
| 151 |
+
question_image_refs = self.extract_image_references(question)
|
| 152 |
+
if question_image_refs:
|
| 153 |
+
clean_question = self.replace_image_references(question)
|
| 154 |
+
question_text_parts = clean_question.split('<IMAGE_PLACEHOLDER>')
|
| 155 |
+
|
| 156 |
+
if len(question_text_parts) != len(question_image_refs) + 1:
|
| 157 |
+
print(f"Mismatch in question: text parts {len(question_text_parts)}, images {len(question_image_refs)}")
|
| 158 |
+
return {}
|
| 159 |
+
|
| 160 |
+
question_images = []
|
| 161 |
+
for image_ref in question_image_refs:
|
| 162 |
+
image = self.load_image_safely(data_item, image_ref)
|
| 163 |
+
if image is None:
|
| 164 |
+
print(f"Skipping sample due to missing image in question: {image_ref}")
|
| 165 |
+
return {}
|
| 166 |
+
question_images.append(image)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
for i, text_part in enumerate(question_text_parts):
|
| 170 |
+
if text_part.strip():
|
| 171 |
+
# Question text has no loss, so no need for vision start prediction
|
| 172 |
+
data = self._add_text(data, text_part.strip(), need_loss=False, enable_cfg=True)
|
| 173 |
+
if i < len(question_images):
|
| 174 |
+
data = self._add_image(
|
| 175 |
+
data, question_images[i],
|
| 176 |
+
need_loss=False, # No loss for question images
|
| 177 |
+
need_vae=False, # VAE conditioning
|
| 178 |
+
need_vit=True, # VIT understanding
|
| 179 |
+
enable_cfg=True,
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
# Original behavior if no images in question
|
| 183 |
+
data = self._add_text(data, question, need_loss=False, enable_cfg=True)
|
| 184 |
+
|
| 185 |
+
# 2. Interleave text parts and images from reasoning trace
|
| 186 |
+
image_refs = self.extract_image_references(reasoning_trace)
|
| 187 |
+
|
| 188 |
+
loaded_images = []
|
| 189 |
+
for image_ref in image_refs:
|
| 190 |
+
image = self.load_image_safely(data_item, image_ref)
|
| 191 |
+
if image is not None:
|
| 192 |
+
loaded_images.append(image)
|
| 193 |
+
else:
|
| 194 |
+
# If image fails to load, skip this sample
|
| 195 |
+
print(f"Skipping sample due to missing image: {image_ref}")
|
| 196 |
+
return {}
|
| 197 |
+
|
| 198 |
+
# Clean reasoning trace by removing image references for text processing
|
| 199 |
+
clean_reasoning_trace = self.replace_image_references(reasoning_trace)
|
| 200 |
+
|
| 201 |
+
# Remove THOUGHT patterns from the reasoning trace
|
| 202 |
+
clean_reasoning_trace = self.remove_thought_patterns(clean_reasoning_trace)
|
| 203 |
+
|
| 204 |
+
# Append final answer to the reasoning trace
|
| 205 |
+
# clean_reasoning_trace += f"\n\nFinal Answer: {final_answer}"
|
| 206 |
+
|
| 207 |
+
# Split reasoning trace by image placeholders to interleave text and images
|
| 208 |
+
text_parts = clean_reasoning_trace.split('<IMAGE_PLACEHOLDER>')
|
| 209 |
+
|
| 210 |
+
if len(text_parts) != len(loaded_images) + 1:
|
| 211 |
+
print(f"Mismatch between text parts ({len(text_parts)}) and images ({len(loaded_images)})")
|
| 212 |
+
return {}
|
| 213 |
+
|
| 214 |
+
# 4. Interleave text parts and images from reasoning trace
|
| 215 |
+
for i, text_part in enumerate(text_parts):
|
| 216 |
+
# Add text part if not empty
|
| 217 |
+
if text_part.strip():
|
| 218 |
+
# Wrap reasoning text with <think></think> tokens
|
| 219 |
+
wrapped_text = f"<think>{text_part.strip()}</think>"
|
| 220 |
+
|
| 221 |
+
# Determine what the im_end token should predict
|
| 222 |
+
if i < len(loaded_images):
|
| 223 |
+
# If this text part is followed by an image, predict vision_start
|
| 224 |
+
next_token_label = self.start_of_image
|
| 225 |
+
elif i == len(text_parts) - 1:
|
| 226 |
+
# If this is the last text part, predict im_start for final answer
|
| 227 |
+
next_token_label = self.im_start
|
| 228 |
+
else:
|
| 229 |
+
next_token_label = None
|
| 230 |
+
|
| 231 |
+
data = self._add_text(data, wrapped_text, need_loss=True, enable_cfg=True, next_token_label=next_token_label)
|
| 232 |
+
|
| 233 |
+
# Add image if available
|
| 234 |
+
if i < len(loaded_images):
|
| 235 |
+
# Add image with both VAE and VIT processing for full capability
|
| 236 |
+
data = self._add_image(
|
| 237 |
+
data,
|
| 238 |
+
loaded_images[i],
|
| 239 |
+
need_loss=True, # VAE generation loss
|
| 240 |
+
need_vae=True, # VAE conditioning
|
| 241 |
+
need_vit=True, # VIT understanding
|
| 242 |
+
enable_cfg=True,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# 5. Add final answer
|
| 246 |
+
data = self._add_text(data, final_answer, need_loss=True, enable_cfg=True)# ybq1025 need_loss=False
|
| 247 |
+
|
| 248 |
+
return data
|
| 249 |
+
|
| 250 |
+
def __iter__(self):
|
| 251 |
+
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
|
| 252 |
+
if self.data_status is not None:
|
| 253 |
+
row_start_id = self.data_status[worker_id] + 1
|
| 254 |
+
else:
|
| 255 |
+
row_start_id = 0
|
| 256 |
+
|
| 257 |
+
print(
|
| 258 |
+
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
|
| 259 |
+
f"resuming data at row#{row_start_id}"
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
while True:
|
| 263 |
+
data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
|
| 264 |
+
for row_idx, json_line in enumerate(data_paths_per_worker_, start=row_start_id):
|
| 265 |
+
try:
|
| 266 |
+
data = self.parse_row(json_line)
|
| 267 |
+
if len(data) == 0:
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
# Check if we have any loss
|
| 271 |
+
has_loss = any(item['loss'] for item in data['sequence_plan'])
|
| 272 |
+
if not has_loss:
|
| 273 |
+
print('No loss defined, skipped.')
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
data['data_indexes'] = {
|
| 277 |
+
"data_indexes": row_idx,
|
| 278 |
+
"worker_id": worker_id,
|
| 279 |
+
"dataset_name": self.dataset_name,
|
| 280 |
+
}
|
| 281 |
+
yield data
|
| 282 |
+
|
| 283 |
+
except Exception as e:
|
| 284 |
+
print(f"Error processing row {row_idx}: {e}")
|
| 285 |
+
traceback.print_exc()
|
| 286 |
+
continue
|
| 287 |
+
|
| 288 |
+
row_start_id = 0
|
| 289 |
+
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
|
modeling/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from . import bagel, qwen2, siglip, autoencoder
|
modeling/autoencoder.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Black Forest Labs.
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under Apache-2.0, with the full license text
|
| 8 |
+
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
from torch import Tensor, nn
|
| 17 |
+
from safetensors.torch import load_file as load_sft
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class AutoEncoderParams:
|
| 22 |
+
resolution: int
|
| 23 |
+
in_channels: int
|
| 24 |
+
downsample: int
|
| 25 |
+
ch: int
|
| 26 |
+
out_ch: int
|
| 27 |
+
ch_mult: list[int]
|
| 28 |
+
num_res_blocks: int
|
| 29 |
+
z_channels: int
|
| 30 |
+
scale_factor: float
|
| 31 |
+
shift_factor: float
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def swish(x: Tensor) -> Tensor:
|
| 35 |
+
return x * torch.sigmoid(x)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AttnBlock(nn.Module):
|
| 39 |
+
def __init__(self, in_channels: int):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.in_channels = in_channels
|
| 42 |
+
|
| 43 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 44 |
+
|
| 45 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 46 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 47 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 48 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 49 |
+
|
| 50 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 51 |
+
h_ = self.norm(h_)
|
| 52 |
+
q = self.q(h_)
|
| 53 |
+
k = self.k(h_)
|
| 54 |
+
v = self.v(h_)
|
| 55 |
+
|
| 56 |
+
b, c, h, w = q.shape
|
| 57 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 58 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 59 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 60 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 61 |
+
|
| 62 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 63 |
+
|
| 64 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 65 |
+
return x + self.proj_out(self.attention(x))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ResnetBlock(nn.Module):
|
| 69 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.in_channels = in_channels
|
| 72 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 73 |
+
self.out_channels = out_channels
|
| 74 |
+
|
| 75 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 76 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 77 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 78 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 79 |
+
if self.in_channels != self.out_channels:
|
| 80 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
h = x
|
| 84 |
+
h = self.norm1(h)
|
| 85 |
+
h = swish(h)
|
| 86 |
+
h = self.conv1(h)
|
| 87 |
+
|
| 88 |
+
h = self.norm2(h)
|
| 89 |
+
h = swish(h)
|
| 90 |
+
h = self.conv2(h)
|
| 91 |
+
|
| 92 |
+
if self.in_channels != self.out_channels:
|
| 93 |
+
x = self.nin_shortcut(x)
|
| 94 |
+
|
| 95 |
+
return x + h
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Downsample(nn.Module):
|
| 99 |
+
def __init__(self, in_channels: int):
|
| 100 |
+
super().__init__()
|
| 101 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 102 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 103 |
+
|
| 104 |
+
def forward(self, x: Tensor):
|
| 105 |
+
pad = (0, 1, 0, 1)
|
| 106 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 107 |
+
x = self.conv(x)
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Upsample(nn.Module):
|
| 112 |
+
def __init__(self, in_channels: int):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 115 |
+
|
| 116 |
+
def forward(self, x: Tensor):
|
| 117 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 118 |
+
x = self.conv(x)
|
| 119 |
+
return x
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Encoder(nn.Module):
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
resolution: int,
|
| 126 |
+
in_channels: int,
|
| 127 |
+
ch: int,
|
| 128 |
+
ch_mult: list[int],
|
| 129 |
+
num_res_blocks: int,
|
| 130 |
+
z_channels: int,
|
| 131 |
+
):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.ch = ch
|
| 134 |
+
self.num_resolutions = len(ch_mult)
|
| 135 |
+
self.num_res_blocks = num_res_blocks
|
| 136 |
+
self.resolution = resolution
|
| 137 |
+
self.in_channels = in_channels
|
| 138 |
+
# downsampling
|
| 139 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 140 |
+
|
| 141 |
+
curr_res = resolution
|
| 142 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 143 |
+
self.in_ch_mult = in_ch_mult
|
| 144 |
+
self.down = nn.ModuleList()
|
| 145 |
+
block_in = self.ch
|
| 146 |
+
for i_level in range(self.num_resolutions):
|
| 147 |
+
block = nn.ModuleList()
|
| 148 |
+
attn = nn.ModuleList()
|
| 149 |
+
block_in = ch * in_ch_mult[i_level]
|
| 150 |
+
block_out = ch * ch_mult[i_level]
|
| 151 |
+
for _ in range(self.num_res_blocks):
|
| 152 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 153 |
+
block_in = block_out
|
| 154 |
+
down = nn.Module()
|
| 155 |
+
down.block = block
|
| 156 |
+
down.attn = attn
|
| 157 |
+
if i_level != self.num_resolutions - 1:
|
| 158 |
+
down.downsample = Downsample(block_in)
|
| 159 |
+
curr_res = curr_res // 2
|
| 160 |
+
self.down.append(down)
|
| 161 |
+
|
| 162 |
+
# middle
|
| 163 |
+
self.mid = nn.Module()
|
| 164 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 165 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 166 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 167 |
+
|
| 168 |
+
# end
|
| 169 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 170 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 171 |
+
|
| 172 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 173 |
+
# downsampling
|
| 174 |
+
hs = [self.conv_in(x)]
|
| 175 |
+
for i_level in range(self.num_resolutions):
|
| 176 |
+
for i_block in range(self.num_res_blocks):
|
| 177 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 178 |
+
if len(self.down[i_level].attn) > 0:
|
| 179 |
+
h = self.down[i_level].attn[i_block](h)
|
| 180 |
+
hs.append(h)
|
| 181 |
+
if i_level != self.num_resolutions - 1:
|
| 182 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 183 |
+
|
| 184 |
+
# middle
|
| 185 |
+
h = hs[-1]
|
| 186 |
+
h = self.mid.block_1(h)
|
| 187 |
+
h = self.mid.attn_1(h)
|
| 188 |
+
h = self.mid.block_2(h)
|
| 189 |
+
# end
|
| 190 |
+
h = self.norm_out(h)
|
| 191 |
+
h = swish(h)
|
| 192 |
+
h = self.conv_out(h)
|
| 193 |
+
return h
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class Decoder(nn.Module):
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
ch: int,
|
| 200 |
+
out_ch: int,
|
| 201 |
+
ch_mult: list[int],
|
| 202 |
+
num_res_blocks: int,
|
| 203 |
+
in_channels: int,
|
| 204 |
+
resolution: int,
|
| 205 |
+
z_channels: int,
|
| 206 |
+
):
|
| 207 |
+
super().__init__()
|
| 208 |
+
self.ch = ch
|
| 209 |
+
self.num_resolutions = len(ch_mult)
|
| 210 |
+
self.num_res_blocks = num_res_blocks
|
| 211 |
+
self.resolution = resolution
|
| 212 |
+
self.in_channels = in_channels
|
| 213 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 214 |
+
|
| 215 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 216 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 217 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 218 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 219 |
+
|
| 220 |
+
# z to block_in
|
| 221 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 222 |
+
|
| 223 |
+
# middle
|
| 224 |
+
self.mid = nn.Module()
|
| 225 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 226 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 227 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 228 |
+
|
| 229 |
+
# upsampling
|
| 230 |
+
self.up = nn.ModuleList()
|
| 231 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 232 |
+
block = nn.ModuleList()
|
| 233 |
+
attn = nn.ModuleList()
|
| 234 |
+
block_out = ch * ch_mult[i_level]
|
| 235 |
+
for _ in range(self.num_res_blocks + 1):
|
| 236 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 237 |
+
block_in = block_out
|
| 238 |
+
up = nn.Module()
|
| 239 |
+
up.block = block
|
| 240 |
+
up.attn = attn
|
| 241 |
+
if i_level != 0:
|
| 242 |
+
up.upsample = Upsample(block_in)
|
| 243 |
+
curr_res = curr_res * 2
|
| 244 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 245 |
+
|
| 246 |
+
# end
|
| 247 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 248 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 249 |
+
|
| 250 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 251 |
+
# z to block_in
|
| 252 |
+
h = self.conv_in(z)
|
| 253 |
+
|
| 254 |
+
# middle
|
| 255 |
+
h = self.mid.block_1(h)
|
| 256 |
+
h = self.mid.attn_1(h)
|
| 257 |
+
h = self.mid.block_2(h)
|
| 258 |
+
|
| 259 |
+
# upsampling
|
| 260 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 261 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 262 |
+
h = self.up[i_level].block[i_block](h)
|
| 263 |
+
if len(self.up[i_level].attn) > 0:
|
| 264 |
+
h = self.up[i_level].attn[i_block](h)
|
| 265 |
+
if i_level != 0:
|
| 266 |
+
h = self.up[i_level].upsample(h)
|
| 267 |
+
|
| 268 |
+
# end
|
| 269 |
+
h = self.norm_out(h)
|
| 270 |
+
h = swish(h)
|
| 271 |
+
h = self.conv_out(h)
|
| 272 |
+
return h
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class DiagonalGaussian(nn.Module):
|
| 276 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
| 277 |
+
super().__init__()
|
| 278 |
+
self.sample = sample
|
| 279 |
+
self.chunk_dim = chunk_dim
|
| 280 |
+
|
| 281 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 282 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
| 283 |
+
if self.sample:
|
| 284 |
+
std = torch.exp(0.5 * logvar)
|
| 285 |
+
return mean + std * torch.randn_like(mean)
|
| 286 |
+
else:
|
| 287 |
+
return mean
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class AutoEncoder(nn.Module):
|
| 291 |
+
def __init__(self, params: AutoEncoderParams):
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.encoder = Encoder(
|
| 294 |
+
resolution=params.resolution,
|
| 295 |
+
in_channels=params.in_channels,
|
| 296 |
+
ch=params.ch,
|
| 297 |
+
ch_mult=params.ch_mult,
|
| 298 |
+
num_res_blocks=params.num_res_blocks,
|
| 299 |
+
z_channels=params.z_channels,
|
| 300 |
+
)
|
| 301 |
+
self.decoder = Decoder(
|
| 302 |
+
resolution=params.resolution,
|
| 303 |
+
in_channels=params.in_channels,
|
| 304 |
+
ch=params.ch,
|
| 305 |
+
out_ch=params.out_ch,
|
| 306 |
+
ch_mult=params.ch_mult,
|
| 307 |
+
num_res_blocks=params.num_res_blocks,
|
| 308 |
+
z_channels=params.z_channels,
|
| 309 |
+
)
|
| 310 |
+
self.reg = DiagonalGaussian()
|
| 311 |
+
|
| 312 |
+
self.scale_factor = params.scale_factor
|
| 313 |
+
self.shift_factor = params.shift_factor
|
| 314 |
+
|
| 315 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 316 |
+
z = self.reg(self.encoder(x))
|
| 317 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 318 |
+
return z
|
| 319 |
+
|
| 320 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 321 |
+
z = z / self.scale_factor + self.shift_factor
|
| 322 |
+
return self.decoder(z)
|
| 323 |
+
|
| 324 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 325 |
+
return self.decode(self.encode(x))
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
| 329 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
| 330 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 331 |
+
print("\n" + "-" * 79 + "\n")
|
| 332 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 333 |
+
elif len(missing) > 0:
|
| 334 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 335 |
+
elif len(unexpected) > 0:
|
| 336 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def load_ae(local_path: str) -> AutoEncoder:
|
| 340 |
+
ae_params = AutoEncoderParams(
|
| 341 |
+
resolution=256,
|
| 342 |
+
in_channels=3,
|
| 343 |
+
downsample=8,
|
| 344 |
+
ch=128,
|
| 345 |
+
out_ch=3,
|
| 346 |
+
ch_mult=[1, 2, 4, 4],
|
| 347 |
+
num_res_blocks=2,
|
| 348 |
+
z_channels=16,
|
| 349 |
+
scale_factor=0.3611,
|
| 350 |
+
shift_factor=0.1159,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Loading the autoencoder
|
| 354 |
+
ae = AutoEncoder(ae_params)
|
| 355 |
+
|
| 356 |
+
if local_path is not None:
|
| 357 |
+
sd = load_sft(local_path)
|
| 358 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
| 359 |
+
print_load_warning(missing, unexpected)
|
| 360 |
+
return ae, ae_params
|
modeling/bagel/bagel.py
ADDED
|
@@ -0,0 +1,1068 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import copy
|
| 5 |
+
from typing import List, Tuple, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
| 11 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 12 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 13 |
+
|
| 14 |
+
from data.data_utils import (
|
| 15 |
+
create_sparse_mask,
|
| 16 |
+
get_flattened_position_ids_extrapolate,
|
| 17 |
+
get_flattened_position_ids_interpolate,
|
| 18 |
+
patchify,
|
| 19 |
+
)
|
| 20 |
+
from .qwen2_navit import NaiveCache
|
| 21 |
+
from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
|
| 22 |
+
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BagelConfig(PretrainedConfig):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
visual_gen=True,
|
| 30 |
+
visual_und=True,
|
| 31 |
+
llm_config=None,
|
| 32 |
+
vit_config=None,
|
| 33 |
+
vae_config=None,
|
| 34 |
+
latent_patch_size=2,
|
| 35 |
+
max_latent_size=32,
|
| 36 |
+
vit_max_num_patch_per_side=70,
|
| 37 |
+
connector_act="gelu_pytorch_tanh",
|
| 38 |
+
interpolate_pos=False,
|
| 39 |
+
timestep_shift=1.0,
|
| 40 |
+
**kwargs
|
| 41 |
+
):
|
| 42 |
+
super().__init__(**kwargs)
|
| 43 |
+
self.visual_gen = visual_gen
|
| 44 |
+
self.visual_und = visual_und
|
| 45 |
+
self.llm_config = llm_config
|
| 46 |
+
self.vit_config = vit_config
|
| 47 |
+
self.vae_config = vae_config
|
| 48 |
+
self.latent_patch_size = latent_patch_size
|
| 49 |
+
self.max_latent_size = max_latent_size
|
| 50 |
+
self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
|
| 51 |
+
self.connector_act = connector_act
|
| 52 |
+
self.interpolate_pos = interpolate_pos
|
| 53 |
+
self.timestep_shift = timestep_shift
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Bagel(PreTrainedModel):
|
| 57 |
+
config_class = BagelConfig
|
| 58 |
+
base_model_prefix = 'bagel'
|
| 59 |
+
|
| 60 |
+
def __init__(self, language_model, vit_model, config: BagelConfig):
|
| 61 |
+
super().__init__(config)
|
| 62 |
+
self.language_model = language_model
|
| 63 |
+
self.hidden_size = config.llm_config.hidden_size
|
| 64 |
+
self.use_moe = "Mo" in config.llm_config.layer_module
|
| 65 |
+
self.num_heads = config.llm_config.num_attention_heads
|
| 66 |
+
|
| 67 |
+
if config.visual_gen:
|
| 68 |
+
self.latent_patch_size = config.latent_patch_size
|
| 69 |
+
self.timestep_shift = config.timestep_shift
|
| 70 |
+
self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
|
| 71 |
+
self.max_latent_size = config.max_latent_size
|
| 72 |
+
self.latent_channel = config.vae_config.z_channels
|
| 73 |
+
self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
|
| 74 |
+
self.time_embedder = TimestepEmbedder(self.hidden_size)
|
| 75 |
+
self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
|
| 76 |
+
self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
|
| 77 |
+
self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
|
| 78 |
+
|
| 79 |
+
if config.visual_und:
|
| 80 |
+
self.vit_model = vit_model
|
| 81 |
+
self.vit_patch_size = config.vit_config.patch_size
|
| 82 |
+
self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
|
| 83 |
+
self.vit_hidden_size = config.vit_config.hidden_size
|
| 84 |
+
self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
|
| 85 |
+
self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
|
| 86 |
+
|
| 87 |
+
if config.interpolate_pos:
|
| 88 |
+
self.get_flattened_position_ids = get_flattened_position_ids_interpolate
|
| 89 |
+
else:
|
| 90 |
+
self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
|
| 91 |
+
|
| 92 |
+
self.config = config
|
| 93 |
+
self._init_weights()
|
| 94 |
+
|
| 95 |
+
def _init_weights(self):
|
| 96 |
+
if self.config.visual_gen:
|
| 97 |
+
nn.init.constant_(self.llm2vae.weight, 0)
|
| 98 |
+
nn.init.constant_(self.llm2vae.bias, 0)
|
| 99 |
+
|
| 100 |
+
def forward(
|
| 101 |
+
self,
|
| 102 |
+
sequence_length: int,
|
| 103 |
+
packed_text_ids: torch.LongTensor,
|
| 104 |
+
packed_text_indexes: torch.LongTensor,
|
| 105 |
+
sample_lens: List[int],
|
| 106 |
+
packed_position_ids: torch.LongTensor,
|
| 107 |
+
nested_attention_masks: List[torch.Tensor] = None,
|
| 108 |
+
split_lens: List[int] = None,
|
| 109 |
+
attn_modes: List[str] = None,
|
| 110 |
+
# for visual understanding
|
| 111 |
+
ce_loss_indexes: Optional[torch.BoolTensor] = None,
|
| 112 |
+
packed_label_ids: Optional[torch.LongTensor] = None,
|
| 113 |
+
packed_vit_tokens: Optional[torch.Tensor] = None,
|
| 114 |
+
packed_vit_token_indexes: Optional[torch.LongTensor] = None,
|
| 115 |
+
packed_vit_position_ids: Optional[torch.LongTensor] = None,
|
| 116 |
+
vit_token_seqlens: Optional[torch.IntTensor] = None,
|
| 117 |
+
# for visual generation
|
| 118 |
+
padded_latent: Optional[torch.Tensor] = None,
|
| 119 |
+
patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
|
| 120 |
+
packed_latent_position_ids: Optional[torch.LongTensor] = None,
|
| 121 |
+
packed_vae_token_indexes: Optional[torch.LongTensor] = None,
|
| 122 |
+
packed_timesteps: Optional[torch.LongTensor] = None,
|
| 123 |
+
mse_loss_indexes: Optional[torch.BoolTensor] = None,
|
| 124 |
+
) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
Args:
|
| 127 |
+
sequence_length: length of sequence.
|
| 128 |
+
packed_text_ids: 1-D int tensor, packed text token ids.
|
| 129 |
+
packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
|
| 130 |
+
sample_lens: A list of N ints, length of each sample in packed_sequence.
|
| 131 |
+
nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
|
| 132 |
+
-inf means ignore.
|
| 133 |
+
packed_position_ids: packed 1-D positions, an image has only one global position shared
|
| 134 |
+
by all latent tokens.
|
| 135 |
+
|
| 136 |
+
packed_vit_tokens: packed patchified image tokens for vit model.
|
| 137 |
+
packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
|
| 138 |
+
packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
|
| 139 |
+
vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
|
| 140 |
+
packed_label_ids: 1-D int tensor, packed label token ids.
|
| 141 |
+
ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
|
| 142 |
+
|
| 143 |
+
padded_latent: padded latent from VAE encoder.
|
| 144 |
+
patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
|
| 145 |
+
packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
|
| 146 |
+
packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
|
| 147 |
+
packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
|
| 148 |
+
mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
|
| 149 |
+
"""
|
| 150 |
+
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
|
| 151 |
+
packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
|
| 152 |
+
packed_sequence[packed_text_indexes] = packed_text_embedding
|
| 153 |
+
|
| 154 |
+
if nested_attention_masks is None:
|
| 155 |
+
sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
|
| 156 |
+
seqlen = sum(sample_lens)
|
| 157 |
+
block_mask = create_block_mask(
|
| 158 |
+
sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
|
| 159 |
+
device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
|
| 160 |
+
)
|
| 161 |
+
attention_mask = block_mask
|
| 162 |
+
else:
|
| 163 |
+
attention_mask = nested_attention_masks
|
| 164 |
+
|
| 165 |
+
# if self.config.visual_und and vit_token_seqlens is not None:
|
| 166 |
+
if self.config.visual_und:
|
| 167 |
+
cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
|
| 168 |
+
cu_seqlens = cu_seqlens.to(torch.int32)
|
| 169 |
+
max_seqlen = torch.max(vit_token_seqlens).item()
|
| 170 |
+
packed_vit_token_embed = self.vit_model(
|
| 171 |
+
packed_pixel_values=packed_vit_tokens,
|
| 172 |
+
packed_flattened_position_ids=packed_vit_position_ids,
|
| 173 |
+
cu_seqlens=cu_seqlens,
|
| 174 |
+
max_seqlen=max_seqlen,
|
| 175 |
+
)
|
| 176 |
+
packed_vit_token_embed = self.connector(packed_vit_token_embed)
|
| 177 |
+
vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
|
| 178 |
+
packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
|
| 179 |
+
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
|
| 180 |
+
|
| 181 |
+
if self.config.visual_gen:
|
| 182 |
+
p = self.latent_patch_size
|
| 183 |
+
packed_latent = []
|
| 184 |
+
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
|
| 185 |
+
latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
|
| 186 |
+
latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
|
| 187 |
+
packed_latent.append(latent)
|
| 188 |
+
packed_latent_clean = torch.cat(packed_latent, dim=0)
|
| 189 |
+
|
| 190 |
+
noise = torch.randn_like(packed_latent_clean)
|
| 191 |
+
packed_timesteps = torch.sigmoid(packed_timesteps)
|
| 192 |
+
packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
|
| 193 |
+
packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
|
| 194 |
+
packed_timestep_embeds = self.time_embedder(packed_timesteps)
|
| 195 |
+
latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
|
| 196 |
+
packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
|
| 197 |
+
packed_sequence[packed_vae_token_indexes] = packed_latent
|
| 198 |
+
|
| 199 |
+
extra_inputs = {}
|
| 200 |
+
if self.use_moe:
|
| 201 |
+
packed_und_token_indexes = packed_text_indexes
|
| 202 |
+
if packed_vit_token_indexes is not None:
|
| 203 |
+
packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
|
| 204 |
+
extra_inputs.update(
|
| 205 |
+
packed_und_token_indexes=packed_und_token_indexes,
|
| 206 |
+
packed_gen_token_indexes=packed_vae_token_indexes,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
last_hidden_state = self.language_model(
|
| 210 |
+
packed_sequence=packed_sequence,
|
| 211 |
+
sample_lens=sample_lens,
|
| 212 |
+
attention_mask=attention_mask,
|
| 213 |
+
packed_position_ids=packed_position_ids,
|
| 214 |
+
**extra_inputs,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
mse = None
|
| 218 |
+
if self.config.visual_gen:
|
| 219 |
+
packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
|
| 220 |
+
target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
|
| 221 |
+
has_mse = packed_timesteps > 0
|
| 222 |
+
mse = (packed_mse_preds - target[has_mse]) ** 2
|
| 223 |
+
|
| 224 |
+
ce = None
|
| 225 |
+
if ce_loss_indexes is not None:
|
| 226 |
+
packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
|
| 227 |
+
ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
|
| 228 |
+
|
| 229 |
+
return dict(mse=mse, ce=ce)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
|
| 233 |
+
packed_text_ids = list()
|
| 234 |
+
packed_text_position_ids = list()
|
| 235 |
+
text_token_lens = list()
|
| 236 |
+
packed_text_indexes = list()
|
| 237 |
+
packed_key_value_indexes = list()
|
| 238 |
+
|
| 239 |
+
curr = 0
|
| 240 |
+
newlens, new_rope = list(), list()
|
| 241 |
+
for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
|
| 242 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 243 |
+
curr += curr_kvlen
|
| 244 |
+
|
| 245 |
+
text_ids = tokenizer.encode(prompt)
|
| 246 |
+
text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
|
| 247 |
+
text_token_lens.append(len(text_ids))
|
| 248 |
+
packed_text_ids.extend(text_ids)
|
| 249 |
+
packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
|
| 250 |
+
packed_text_indexes.extend(range(curr, curr + len(text_ids)))
|
| 251 |
+
newlens.append(curr_kvlen + len(text_ids))
|
| 252 |
+
new_rope.append(curr_position_id + len(text_ids))
|
| 253 |
+
curr += len(text_ids)
|
| 254 |
+
|
| 255 |
+
generation_input = {
|
| 256 |
+
"text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
|
| 257 |
+
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
|
| 258 |
+
"packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
|
| 259 |
+
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
|
| 260 |
+
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 261 |
+
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
return generation_input, newlens, new_rope
|
| 265 |
+
|
| 266 |
+
@torch.no_grad
|
| 267 |
+
def forward_cache_update_text(
|
| 268 |
+
self,
|
| 269 |
+
past_key_values: NaiveCache,
|
| 270 |
+
packed_text_ids: torch.IntTensor,
|
| 271 |
+
packed_text_position_ids: torch.LongTensor,
|
| 272 |
+
text_token_lens: torch.LongTensor,
|
| 273 |
+
packed_text_indexes: torch.LongTensor,
|
| 274 |
+
packed_key_value_indexes: torch.LongTensor,
|
| 275 |
+
key_values_lens: torch.IntTensor,
|
| 276 |
+
):
|
| 277 |
+
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
|
| 278 |
+
|
| 279 |
+
extra_inputs = {}
|
| 280 |
+
if self.use_moe:
|
| 281 |
+
extra_inputs = {"mode": "und"}
|
| 282 |
+
|
| 283 |
+
output = self.language_model.forward_inference(
|
| 284 |
+
packed_query_sequence=packed_text_embedding,
|
| 285 |
+
query_lens=text_token_lens,
|
| 286 |
+
packed_query_position_ids=packed_text_position_ids,
|
| 287 |
+
packed_query_indexes=packed_text_indexes,
|
| 288 |
+
past_key_values=past_key_values,
|
| 289 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 290 |
+
key_values_lens=key_values_lens,
|
| 291 |
+
update_past_key_values=True,
|
| 292 |
+
is_causal=True,
|
| 293 |
+
**extra_inputs,
|
| 294 |
+
)
|
| 295 |
+
past_key_values = output.past_key_values
|
| 296 |
+
|
| 297 |
+
return past_key_values
|
| 298 |
+
|
| 299 |
+
def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
|
| 300 |
+
packed_vit_token_indexes = list()
|
| 301 |
+
vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
|
| 302 |
+
packed_text_ids, packed_text_indexes = list(), list()
|
| 303 |
+
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
|
| 304 |
+
packed_key_value_indexes = list()
|
| 305 |
+
|
| 306 |
+
_curr = curr = 0
|
| 307 |
+
newlens, new_rope = list(), list()
|
| 308 |
+
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
|
| 309 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 310 |
+
curr += curr_kvlen
|
| 311 |
+
|
| 312 |
+
packed_text_ids.append(new_token_ids['start_of_image'])
|
| 313 |
+
packed_text_indexes.append(_curr)
|
| 314 |
+
packed_indexes.append(curr)
|
| 315 |
+
curr += 1
|
| 316 |
+
_curr += 1
|
| 317 |
+
|
| 318 |
+
image_tensor = transforms(image)
|
| 319 |
+
vit_position_ids = self.get_flattened_position_ids(
|
| 320 |
+
image_tensor.size(1), image_tensor.size(2),
|
| 321 |
+
self.vit_patch_size,
|
| 322 |
+
max_num_patches_per_side=self.vit_max_num_patch_per_side
|
| 323 |
+
)
|
| 324 |
+
vit_tokens = patchify(image_tensor, self.vit_patch_size)
|
| 325 |
+
packed_vit_tokens.append(vit_tokens)
|
| 326 |
+
num_img_tokens = vit_tokens.shape[0]
|
| 327 |
+
packed_vit_position_ids.append(vit_position_ids)
|
| 328 |
+
vit_token_seqlens.append(num_img_tokens)
|
| 329 |
+
packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
|
| 330 |
+
packed_indexes.extend(range(curr, curr + num_img_tokens))
|
| 331 |
+
curr += num_img_tokens
|
| 332 |
+
_curr += num_img_tokens
|
| 333 |
+
|
| 334 |
+
packed_text_ids.append(new_token_ids['end_of_image'])
|
| 335 |
+
packed_text_indexes.append(_curr)
|
| 336 |
+
packed_indexes.append(curr)
|
| 337 |
+
curr += 1
|
| 338 |
+
_curr += 1
|
| 339 |
+
|
| 340 |
+
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
|
| 341 |
+
packed_seqlens.append(num_img_tokens + 2)
|
| 342 |
+
newlens.append(curr_kvlen + num_img_tokens + 2)
|
| 343 |
+
new_rope.append(curr_position_id + 1)
|
| 344 |
+
|
| 345 |
+
generation_input = {
|
| 346 |
+
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
|
| 347 |
+
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
|
| 348 |
+
"vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
|
| 349 |
+
"packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
|
| 350 |
+
"packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
|
| 351 |
+
"packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
|
| 352 |
+
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
|
| 353 |
+
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
|
| 354 |
+
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
|
| 355 |
+
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 356 |
+
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
return generation_input, newlens, new_rope
|
| 360 |
+
|
| 361 |
+
@torch.no_grad
|
| 362 |
+
def forward_cache_update_vit(
|
| 363 |
+
self,
|
| 364 |
+
past_key_values: NaiveCache,
|
| 365 |
+
packed_text_ids: torch.LongTensor,
|
| 366 |
+
packed_text_indexes: torch.LongTensor,
|
| 367 |
+
packed_vit_tokens: torch.Tensor,
|
| 368 |
+
packed_vit_token_indexes: torch.LongTensor,
|
| 369 |
+
packed_vit_position_ids: torch.LongTensor,
|
| 370 |
+
vit_token_seqlens: torch.IntTensor,
|
| 371 |
+
packed_position_ids: torch.LongTensor,
|
| 372 |
+
packed_seqlens: torch.IntTensor,
|
| 373 |
+
packed_indexes: torch.LongTensor,
|
| 374 |
+
packed_key_value_indexes: torch.LongTensor,
|
| 375 |
+
key_values_lens: torch.IntTensor,
|
| 376 |
+
):
|
| 377 |
+
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
|
| 378 |
+
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
|
| 379 |
+
packed_sequence[packed_text_indexes] = packed_text_embedding
|
| 380 |
+
|
| 381 |
+
cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
|
| 382 |
+
cu_seqlens = cu_seqlens.to(torch.int32)
|
| 383 |
+
max_seqlen = torch.max(vit_token_seqlens).item()
|
| 384 |
+
packed_vit_token_embed = self.vit_model(
|
| 385 |
+
packed_pixel_values=packed_vit_tokens,
|
| 386 |
+
packed_flattened_position_ids=packed_vit_position_ids,
|
| 387 |
+
cu_seqlens=cu_seqlens,
|
| 388 |
+
max_seqlen=max_seqlen,
|
| 389 |
+
)
|
| 390 |
+
packed_vit_token_embed = self.connector(packed_vit_token_embed)
|
| 391 |
+
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
|
| 392 |
+
packed_vit_token_embed = packed_vit_token_embed + pos_emb
|
| 393 |
+
if packed_vit_token_embed.dtype != packed_sequence.dtype:
|
| 394 |
+
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
|
| 395 |
+
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
|
| 396 |
+
|
| 397 |
+
extra_inputs = {}
|
| 398 |
+
if self.use_moe:
|
| 399 |
+
extra_inputs = {"mode": "und"}
|
| 400 |
+
|
| 401 |
+
output = self.language_model.forward_inference(
|
| 402 |
+
packed_query_sequence=packed_sequence,
|
| 403 |
+
query_lens=packed_seqlens,
|
| 404 |
+
packed_query_position_ids=packed_position_ids,
|
| 405 |
+
packed_query_indexes=packed_indexes,
|
| 406 |
+
past_key_values=past_key_values,
|
| 407 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 408 |
+
key_values_lens=key_values_lens,
|
| 409 |
+
update_past_key_values=True,
|
| 410 |
+
is_causal=False,
|
| 411 |
+
**extra_inputs,
|
| 412 |
+
)
|
| 413 |
+
past_key_values = output.past_key_values
|
| 414 |
+
|
| 415 |
+
return past_key_values
|
| 416 |
+
|
| 417 |
+
def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
|
| 418 |
+
patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
|
| 419 |
+
packed_vae_token_indexes = list()
|
| 420 |
+
packed_text_ids, packed_text_indexes = list(), list()
|
| 421 |
+
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
|
| 422 |
+
packed_key_value_indexes = list()
|
| 423 |
+
|
| 424 |
+
_curr = curr = 0
|
| 425 |
+
vae_image_tensors = list()
|
| 426 |
+
newlens, new_rope = list(), list()
|
| 427 |
+
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
|
| 428 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 429 |
+
curr += curr_kvlen
|
| 430 |
+
|
| 431 |
+
packed_text_ids.append(new_token_ids['start_of_image'])
|
| 432 |
+
packed_text_indexes.append(_curr)
|
| 433 |
+
packed_indexes.append(curr)
|
| 434 |
+
curr += 1
|
| 435 |
+
_curr += 1
|
| 436 |
+
|
| 437 |
+
image_tensor = transforms(image)
|
| 438 |
+
vae_image_tensors.append(image_tensor)
|
| 439 |
+
vae_posiiton_ids = self.get_flattened_position_ids(
|
| 440 |
+
image_tensor.size(1), image_tensor.size(2),
|
| 441 |
+
self.latent_downsample,
|
| 442 |
+
max_num_patches_per_side=self.max_latent_size
|
| 443 |
+
)
|
| 444 |
+
packed_vae_position_ids.append(vae_posiiton_ids)
|
| 445 |
+
H, W = image_tensor.shape[1:]
|
| 446 |
+
h = H // self.latent_downsample
|
| 447 |
+
w = W // self.latent_downsample
|
| 448 |
+
patchified_vae_latent_shapes.append((h, w))
|
| 449 |
+
|
| 450 |
+
num_img_tokens = w * h
|
| 451 |
+
packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
|
| 452 |
+
packed_indexes.extend(range(curr, curr + num_img_tokens))
|
| 453 |
+
curr += num_img_tokens
|
| 454 |
+
_curr += num_img_tokens
|
| 455 |
+
|
| 456 |
+
packed_text_ids.append(new_token_ids['end_of_image'])
|
| 457 |
+
packed_text_indexes.append(_curr)
|
| 458 |
+
packed_indexes.append(curr)
|
| 459 |
+
curr += 1
|
| 460 |
+
_curr += 1
|
| 461 |
+
|
| 462 |
+
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
|
| 463 |
+
packed_seqlens.append(num_img_tokens + 2)
|
| 464 |
+
newlens.append(curr_kvlen + num_img_tokens + 2)
|
| 465 |
+
new_rope.append(curr_position_id + 1)
|
| 466 |
+
|
| 467 |
+
image_sizes = [item.shape for item in vae_image_tensors]
|
| 468 |
+
max_image_size = [max(item) for item in list(zip(*image_sizes))]
|
| 469 |
+
padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
|
| 470 |
+
for i, image_tensor in enumerate(vae_image_tensors):
|
| 471 |
+
padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
|
| 472 |
+
|
| 473 |
+
generation_input = {
|
| 474 |
+
"padded_images": padded_images,
|
| 475 |
+
"patchified_vae_latent_shapes": patchified_vae_latent_shapes,
|
| 476 |
+
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
|
| 477 |
+
"packed_timesteps": torch.tensor([timestep]),
|
| 478 |
+
"packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
|
| 479 |
+
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
|
| 480 |
+
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
|
| 481 |
+
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
|
| 482 |
+
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
|
| 483 |
+
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
|
| 484 |
+
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 485 |
+
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
return generation_input, newlens, new_rope
|
| 489 |
+
|
| 490 |
+
@torch.no_grad
|
| 491 |
+
def forward_cache_update_vae(
|
| 492 |
+
self,
|
| 493 |
+
vae_model,
|
| 494 |
+
past_key_values: NaiveCache,
|
| 495 |
+
padded_images: torch.Tensor,
|
| 496 |
+
patchified_vae_latent_shapes: List,
|
| 497 |
+
packed_vae_position_ids: torch.LongTensor,
|
| 498 |
+
packed_timesteps: torch.Tensor,
|
| 499 |
+
packed_vae_token_indexes: torch.LongTensor,
|
| 500 |
+
packed_text_ids: torch.LongTensor,
|
| 501 |
+
packed_text_indexes: torch.LongTensor,
|
| 502 |
+
packed_position_ids: torch.LongTensor,
|
| 503 |
+
packed_seqlens: torch.IntTensor,
|
| 504 |
+
packed_indexes: torch.LongTensor,
|
| 505 |
+
key_values_lens: torch.IntTensor,
|
| 506 |
+
packed_key_value_indexes: torch.Tensor,
|
| 507 |
+
):
|
| 508 |
+
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
|
| 509 |
+
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
|
| 510 |
+
packed_sequence[packed_text_indexes] = packed_text_embedding
|
| 511 |
+
|
| 512 |
+
padded_latent = vae_model.encode(padded_images)
|
| 513 |
+
|
| 514 |
+
p = self.latent_patch_size
|
| 515 |
+
packed_latent = list()
|
| 516 |
+
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
|
| 517 |
+
latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
|
| 518 |
+
latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
|
| 519 |
+
packed_latent.append(latent)
|
| 520 |
+
packed_latent = torch.cat(packed_latent, dim=0)
|
| 521 |
+
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
|
| 522 |
+
packed_timestep_embeds = self.time_embedder(packed_timesteps)
|
| 523 |
+
packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
|
| 524 |
+
if packed_latent.dtype != packed_sequence.dtype:
|
| 525 |
+
packed_latent = packed_latent.to(packed_sequence.dtype)
|
| 526 |
+
packed_sequence[packed_vae_token_indexes] = packed_latent
|
| 527 |
+
|
| 528 |
+
extra_inputs = {}
|
| 529 |
+
if self.use_moe:
|
| 530 |
+
extra_inputs = {
|
| 531 |
+
"mode": "gen",
|
| 532 |
+
"packed_vae_token_indexes": packed_vae_token_indexes,
|
| 533 |
+
"packed_text_indexes": packed_text_indexes
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
output = self.language_model.forward_inference(
|
| 537 |
+
packed_query_sequence=packed_sequence,
|
| 538 |
+
query_lens=packed_seqlens,
|
| 539 |
+
packed_query_position_ids=packed_position_ids,
|
| 540 |
+
packed_query_indexes=packed_indexes,
|
| 541 |
+
past_key_values=past_key_values,
|
| 542 |
+
key_values_lens=key_values_lens,
|
| 543 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 544 |
+
update_past_key_values=True,
|
| 545 |
+
is_causal=False,
|
| 546 |
+
**extra_inputs,
|
| 547 |
+
)
|
| 548 |
+
past_key_values = output.past_key_values
|
| 549 |
+
|
| 550 |
+
return past_key_values
|
| 551 |
+
|
| 552 |
+
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
|
| 553 |
+
packed_text_ids, packed_text_indexes = list(), list()
|
| 554 |
+
packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
|
| 555 |
+
packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
|
| 556 |
+
packed_key_value_indexes = list()
|
| 557 |
+
|
| 558 |
+
query_curr = curr = 0
|
| 559 |
+
for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
|
| 560 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 561 |
+
curr += curr_kvlen
|
| 562 |
+
|
| 563 |
+
packed_text_ids.append(new_token_ids['start_of_image'])
|
| 564 |
+
packed_text_indexes.append(query_curr)
|
| 565 |
+
packed_indexes.append(curr)
|
| 566 |
+
curr += 1
|
| 567 |
+
query_curr += 1
|
| 568 |
+
|
| 569 |
+
vae_posiiton_ids = self.get_flattened_position_ids(
|
| 570 |
+
H, W,
|
| 571 |
+
self.latent_downsample,
|
| 572 |
+
max_num_patches_per_side=self.max_latent_size
|
| 573 |
+
)
|
| 574 |
+
packed_vae_position_ids.append(vae_posiiton_ids)
|
| 575 |
+
|
| 576 |
+
h, w = H // self.latent_downsample, W // self.latent_downsample
|
| 577 |
+
num_image_tokens = h * w
|
| 578 |
+
packed_init_noises.append(
|
| 579 |
+
torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
|
| 580 |
+
)
|
| 581 |
+
packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
|
| 582 |
+
packed_indexes.extend(range(curr, curr + num_image_tokens))
|
| 583 |
+
curr += num_image_tokens
|
| 584 |
+
query_curr += num_image_tokens
|
| 585 |
+
|
| 586 |
+
packed_text_ids.append(new_token_ids['end_of_image'])
|
| 587 |
+
packed_text_indexes.append(query_curr)
|
| 588 |
+
packed_indexes.append(curr)
|
| 589 |
+
curr += 1
|
| 590 |
+
query_curr += 1
|
| 591 |
+
|
| 592 |
+
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
|
| 593 |
+
packed_seqlens.append(num_image_tokens + 2)
|
| 594 |
+
|
| 595 |
+
generation_input = {
|
| 596 |
+
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
|
| 597 |
+
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
|
| 598 |
+
"packed_init_noises": torch.cat(packed_init_noises, dim=0),
|
| 599 |
+
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
|
| 600 |
+
"packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
|
| 601 |
+
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
|
| 602 |
+
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
|
| 603 |
+
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 604 |
+
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
|
| 605 |
+
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
return generation_input
|
| 609 |
+
|
| 610 |
+
def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
|
| 611 |
+
packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
|
| 612 |
+
|
| 613 |
+
query_curr = curr = 0
|
| 614 |
+
for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
|
| 615 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 616 |
+
curr += curr_kvlen
|
| 617 |
+
|
| 618 |
+
packed_indexes.append(curr)
|
| 619 |
+
curr += 1
|
| 620 |
+
query_curr += 1
|
| 621 |
+
|
| 622 |
+
h, w = H // self.latent_downsample, W // self.latent_downsample
|
| 623 |
+
num_image_tokens = h * w
|
| 624 |
+
packed_indexes.extend(range(curr, curr + num_image_tokens))
|
| 625 |
+
curr += num_image_tokens
|
| 626 |
+
query_curr += num_image_tokens
|
| 627 |
+
|
| 628 |
+
packed_indexes.append(curr)
|
| 629 |
+
curr += 1
|
| 630 |
+
query_curr += 1
|
| 631 |
+
|
| 632 |
+
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
|
| 633 |
+
|
| 634 |
+
generation_input = {
|
| 635 |
+
"cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
|
| 636 |
+
"cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 637 |
+
"cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
|
| 638 |
+
"cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
return generation_input
|
| 642 |
+
|
| 643 |
+
@torch.no_grad
|
| 644 |
+
def generate_image(
|
| 645 |
+
self,
|
| 646 |
+
packed_text_ids: torch.LongTensor,
|
| 647 |
+
packed_text_indexes: torch.LongTensor,
|
| 648 |
+
packed_init_noises: torch.Tensor,
|
| 649 |
+
packed_vae_position_ids: torch.LongTensor,
|
| 650 |
+
packed_vae_token_indexes: torch.LongTensor,
|
| 651 |
+
packed_seqlens: torch.IntTensor,
|
| 652 |
+
packed_position_ids: torch.LongTensor,
|
| 653 |
+
packed_indexes: torch.LongTensor,
|
| 654 |
+
past_key_values: NaiveCache,
|
| 655 |
+
key_values_lens: torch.IntTensor,
|
| 656 |
+
packed_key_value_indexes: torch.LongTensor,
|
| 657 |
+
num_timesteps: int = 24,
|
| 658 |
+
timestep_shift: float = 1.0,
|
| 659 |
+
cfg_renorm_min: float = 0.0,
|
| 660 |
+
cfg_renorm_type: str = "global",
|
| 661 |
+
cfg_interval: Optional[Tuple[float, float]] = [0, 1],
|
| 662 |
+
# cfg_text
|
| 663 |
+
cfg_text_scale: float = 1.0,
|
| 664 |
+
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
|
| 665 |
+
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
|
| 666 |
+
cfg_text_past_key_values: Optional[NaiveCache] = None,
|
| 667 |
+
cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
|
| 668 |
+
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
|
| 669 |
+
# cfg_img
|
| 670 |
+
cfg_img_scale: float = 1.0,
|
| 671 |
+
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
|
| 672 |
+
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
|
| 673 |
+
cfg_img_past_key_values: Optional[NaiveCache] = None,
|
| 674 |
+
cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
|
| 675 |
+
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
|
| 676 |
+
cfg_type: str = "parallel",
|
| 677 |
+
):
|
| 678 |
+
x_t = packed_init_noises
|
| 679 |
+
|
| 680 |
+
timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
|
| 681 |
+
timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
|
| 682 |
+
dts = timesteps[:-1] - timesteps[1:]
|
| 683 |
+
timesteps = timesteps[:-1]
|
| 684 |
+
|
| 685 |
+
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
|
| 686 |
+
|
| 687 |
+
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
|
| 688 |
+
if t > cfg_interval[0] and t <= cfg_interval[1]:
|
| 689 |
+
cfg_text_scale_ = cfg_text_scale
|
| 690 |
+
cfg_img_scale_ = cfg_img_scale
|
| 691 |
+
else:
|
| 692 |
+
cfg_text_scale_ = 1.0
|
| 693 |
+
cfg_img_scale_ = 1.0
|
| 694 |
+
v_t = self._forward_flow(
|
| 695 |
+
x_t=x_t,
|
| 696 |
+
timestep=timestep,
|
| 697 |
+
packed_vae_token_indexes=packed_vae_token_indexes,
|
| 698 |
+
packed_vae_position_ids=packed_vae_position_ids,
|
| 699 |
+
packed_text_ids=packed_text_ids,
|
| 700 |
+
packed_text_indexes=packed_text_indexes,
|
| 701 |
+
packed_position_ids=packed_position_ids,
|
| 702 |
+
packed_indexes=packed_indexes,
|
| 703 |
+
packed_seqlens=packed_seqlens,
|
| 704 |
+
key_values_lens=key_values_lens,
|
| 705 |
+
past_key_values=past_key_values,
|
| 706 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 707 |
+
cfg_renorm_min=cfg_renorm_min,
|
| 708 |
+
cfg_renorm_type=cfg_renorm_type,
|
| 709 |
+
# cfg_text
|
| 710 |
+
cfg_text_scale=cfg_text_scale_,
|
| 711 |
+
cfg_text_packed_position_ids=cfg_text_packed_position_ids,
|
| 712 |
+
cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
|
| 713 |
+
cfg_text_key_values_lens=cfg_text_key_values_lens,
|
| 714 |
+
cfg_text_past_key_values=cfg_text_past_key_values,
|
| 715 |
+
cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
|
| 716 |
+
# cfg_img
|
| 717 |
+
cfg_img_scale=cfg_img_scale_,
|
| 718 |
+
cfg_img_packed_position_ids=cfg_img_packed_position_ids,
|
| 719 |
+
cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
|
| 720 |
+
cfg_img_key_values_lens=cfg_img_key_values_lens,
|
| 721 |
+
cfg_img_past_key_values=cfg_img_past_key_values,
|
| 722 |
+
cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
|
| 723 |
+
cfg_type=cfg_type,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
|
| 727 |
+
|
| 728 |
+
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
|
| 729 |
+
return unpacked_latent
|
| 730 |
+
|
| 731 |
+
@torch.no_grad
|
| 732 |
+
def _forward_flow(
|
| 733 |
+
self,
|
| 734 |
+
x_t: torch.Tensor,
|
| 735 |
+
timestep: torch.LongTensor,
|
| 736 |
+
packed_vae_token_indexes: torch.LongTensor,
|
| 737 |
+
packed_vae_position_ids: torch.LongTensor,
|
| 738 |
+
packed_text_ids: torch.LongTensor,
|
| 739 |
+
packed_text_indexes: torch.LongTensor,
|
| 740 |
+
packed_indexes: torch.LongTensor,
|
| 741 |
+
packed_position_ids: torch.LongTensor,
|
| 742 |
+
packed_seqlens: torch.IntTensor,
|
| 743 |
+
key_values_lens: torch.IntTensor,
|
| 744 |
+
past_key_values: NaiveCache,
|
| 745 |
+
packed_key_value_indexes: torch.LongTensor,
|
| 746 |
+
cfg_renorm_min: float = 0.0,
|
| 747 |
+
cfg_renorm_type: str = "global",
|
| 748 |
+
# cfg_text
|
| 749 |
+
cfg_text_scale: float = 1.0,
|
| 750 |
+
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
|
| 751 |
+
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
|
| 752 |
+
cfg_text_key_values_lens: Optional[torch.Tensor] = None,
|
| 753 |
+
cfg_text_past_key_values: Optional[NaiveCache] = None,
|
| 754 |
+
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
|
| 755 |
+
# cfg_img
|
| 756 |
+
cfg_img_scale: float = 1.0,
|
| 757 |
+
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
|
| 758 |
+
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
|
| 759 |
+
cfg_img_key_values_lens: Optional[torch.Tensor] = None,
|
| 760 |
+
cfg_img_past_key_values: Optional[NaiveCache] = None,
|
| 761 |
+
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
|
| 762 |
+
cfg_type: str = "parallel",
|
| 763 |
+
):
|
| 764 |
+
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
|
| 765 |
+
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
|
| 766 |
+
packed_sequence[packed_text_indexes] = packed_text_embedding
|
| 767 |
+
|
| 768 |
+
assert timestep.unique().shape[0] == 1
|
| 769 |
+
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
|
| 770 |
+
packed_timestep_embeds = self.time_embedder(timestep)
|
| 771 |
+
x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
|
| 772 |
+
if x_t.dtype != packed_sequence.dtype:
|
| 773 |
+
x_t = x_t.to(packed_sequence.dtype)
|
| 774 |
+
packed_sequence[packed_vae_token_indexes] = x_t
|
| 775 |
+
|
| 776 |
+
extra_inputs = {}
|
| 777 |
+
if self.use_moe:
|
| 778 |
+
extra_inputs = {
|
| 779 |
+
"mode": "gen",
|
| 780 |
+
"packed_vae_token_indexes": packed_vae_token_indexes,
|
| 781 |
+
"packed_text_indexes": packed_text_indexes
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
output = self.language_model.forward_inference(
|
| 785 |
+
packed_query_sequence=packed_sequence,
|
| 786 |
+
query_lens=packed_seqlens,
|
| 787 |
+
packed_query_position_ids=packed_position_ids,
|
| 788 |
+
packed_query_indexes=packed_indexes,
|
| 789 |
+
past_key_values=past_key_values,
|
| 790 |
+
key_values_lens=key_values_lens,
|
| 791 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 792 |
+
update_past_key_values=False,
|
| 793 |
+
is_causal=False,
|
| 794 |
+
**extra_inputs,
|
| 795 |
+
)
|
| 796 |
+
v_t = self.llm2vae(output.packed_query_sequence)
|
| 797 |
+
v_t = v_t[packed_vae_token_indexes]
|
| 798 |
+
|
| 799 |
+
if cfg_text_scale > 1.0:
|
| 800 |
+
cfg_text_output = self.language_model.forward_inference(
|
| 801 |
+
packed_query_sequence=packed_sequence,
|
| 802 |
+
query_lens=packed_seqlens,
|
| 803 |
+
packed_query_position_ids=cfg_text_packed_position_ids,
|
| 804 |
+
packed_query_indexes=cfg_text_packed_query_indexes,
|
| 805 |
+
past_key_values=cfg_text_past_key_values,
|
| 806 |
+
key_values_lens=cfg_text_key_values_lens,
|
| 807 |
+
packed_key_value_indexes=cfg_text_packed_key_value_indexes,
|
| 808 |
+
update_past_key_values=False,
|
| 809 |
+
is_causal=False,
|
| 810 |
+
**extra_inputs,
|
| 811 |
+
)
|
| 812 |
+
cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
|
| 813 |
+
cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
|
| 814 |
+
|
| 815 |
+
if cfg_img_scale > 1.0:
|
| 816 |
+
cfg_img_output = self.language_model.forward_inference(
|
| 817 |
+
packed_query_sequence=packed_sequence,
|
| 818 |
+
query_lens=packed_seqlens,
|
| 819 |
+
packed_query_position_ids=cfg_img_packed_position_ids,
|
| 820 |
+
packed_query_indexes=cfg_img_packed_query_indexes,
|
| 821 |
+
past_key_values=cfg_img_past_key_values,
|
| 822 |
+
key_values_lens=cfg_img_key_values_lens,
|
| 823 |
+
packed_key_value_indexes=cfg_img_packed_key_value_indexes,
|
| 824 |
+
update_past_key_values=False,
|
| 825 |
+
is_causal=False,
|
| 826 |
+
**extra_inputs,
|
| 827 |
+
)
|
| 828 |
+
cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
|
| 829 |
+
cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
|
| 830 |
+
|
| 831 |
+
if cfg_text_scale > 1.0:
|
| 832 |
+
if cfg_renorm_type == "text_channel":
|
| 833 |
+
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
|
| 834 |
+
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
|
| 835 |
+
norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
|
| 836 |
+
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
|
| 837 |
+
v_t_text = v_t_text_ * scale
|
| 838 |
+
if cfg_img_scale > 1.0:
|
| 839 |
+
v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
|
| 840 |
+
else:
|
| 841 |
+
v_t = v_t_text
|
| 842 |
+
else:
|
| 843 |
+
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
|
| 844 |
+
|
| 845 |
+
if cfg_img_scale > 1.0:
|
| 846 |
+
v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
|
| 847 |
+
else:
|
| 848 |
+
v_t_ = v_t_text_
|
| 849 |
+
|
| 850 |
+
# NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
|
| 851 |
+
if cfg_renorm_type == "global":
|
| 852 |
+
norm_v_t = torch.norm(v_t)
|
| 853 |
+
norm_v_t_ = torch.norm(v_t_)
|
| 854 |
+
elif cfg_renorm_type == "channel":
|
| 855 |
+
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
|
| 856 |
+
norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
|
| 857 |
+
else:
|
| 858 |
+
raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted")
|
| 859 |
+
scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
|
| 860 |
+
v_t = v_t_ * scale
|
| 861 |
+
else:
|
| 862 |
+
# No CFG
|
| 863 |
+
pass
|
| 864 |
+
|
| 865 |
+
return v_t
|
| 866 |
+
|
| 867 |
+
def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
|
| 868 |
+
packed_start_tokens, packed_key_value_indexes = list(), list()
|
| 869 |
+
packed_query_position_ids = list()
|
| 870 |
+
|
| 871 |
+
curr = 0
|
| 872 |
+
for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
|
| 873 |
+
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
|
| 874 |
+
packed_start_tokens.append(new_token_ids['bos_token_id'])
|
| 875 |
+
packed_query_position_ids.append(curr_position_id)
|
| 876 |
+
curr += curr_kvlen
|
| 877 |
+
|
| 878 |
+
generation_input = {
|
| 879 |
+
"packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
|
| 880 |
+
"packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
|
| 881 |
+
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
|
| 882 |
+
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
return generation_input
|
| 886 |
+
|
| 887 |
+
@torch.no_grad
|
| 888 |
+
def generate_text(
|
| 889 |
+
self,
|
| 890 |
+
past_key_values: NaiveCache,
|
| 891 |
+
packed_key_value_indexes: torch.LongTensor,
|
| 892 |
+
key_values_lens: torch.IntTensor,
|
| 893 |
+
packed_start_tokens: torch.LongTensor,
|
| 894 |
+
packed_query_position_ids: torch.LongTensor,
|
| 895 |
+
max_length: int,
|
| 896 |
+
do_sample: bool = False,
|
| 897 |
+
temperature: float = 1.0,
|
| 898 |
+
end_token_id: int = None,
|
| 899 |
+
):
|
| 900 |
+
step = 0
|
| 901 |
+
generated_sequence = []
|
| 902 |
+
curr_tokens = packed_start_tokens
|
| 903 |
+
while step < max_length:
|
| 904 |
+
generated_sequence.append(curr_tokens)
|
| 905 |
+
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
|
| 906 |
+
query_lens = torch.ones_like(curr_tokens)
|
| 907 |
+
packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
|
| 908 |
+
0, len(key_values_lens),
|
| 909 |
+
device=key_values_lens.device,
|
| 910 |
+
dtype=key_values_lens.dtype
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
|
| 914 |
+
for i in range(len(uppacked)):
|
| 915 |
+
uppacked[i] += i
|
| 916 |
+
packed_key_value_indexes = torch.cat(uppacked, dim=0)
|
| 917 |
+
|
| 918 |
+
extra_inputs = {}
|
| 919 |
+
if self.use_moe:
|
| 920 |
+
extra_inputs = {"mode": "und"}
|
| 921 |
+
|
| 922 |
+
output = self.language_model.forward_inference(
|
| 923 |
+
packed_query_sequence=packed_text_embedding,
|
| 924 |
+
query_lens=query_lens,
|
| 925 |
+
packed_query_position_ids=packed_query_position_ids,
|
| 926 |
+
packed_query_indexes=packed_query_indexes,
|
| 927 |
+
past_key_values=past_key_values,
|
| 928 |
+
key_values_lens=key_values_lens,
|
| 929 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 930 |
+
update_past_key_values=True,
|
| 931 |
+
is_causal=True,
|
| 932 |
+
**extra_inputs,
|
| 933 |
+
)
|
| 934 |
+
past_key_values = output.past_key_values
|
| 935 |
+
packed_query_sequence = output.packed_query_sequence
|
| 936 |
+
pred_logits = self.language_model.lm_head(packed_query_sequence)
|
| 937 |
+
|
| 938 |
+
if do_sample:
|
| 939 |
+
probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
|
| 940 |
+
curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 941 |
+
else:
|
| 942 |
+
curr_tokens = torch.argmax(pred_logits, dim=-1)
|
| 943 |
+
|
| 944 |
+
uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
|
| 945 |
+
for i in range(len(uppacked)):
|
| 946 |
+
uppacked[i] = torch.cat(
|
| 947 |
+
[uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
|
| 948 |
+
)
|
| 949 |
+
packed_key_value_indexes = torch.cat(uppacked, dim=0)
|
| 950 |
+
key_values_lens = key_values_lens + 1
|
| 951 |
+
packed_query_position_ids = packed_query_position_ids + 1
|
| 952 |
+
step += 1
|
| 953 |
+
|
| 954 |
+
if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
|
| 955 |
+
# Check if next token would be vision_start (151652)
|
| 956 |
+
generated_sequence.append(curr_tokens) # Add the end token
|
| 957 |
+
|
| 958 |
+
# Generate one more token to check if it's vision_start
|
| 959 |
+
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
|
| 960 |
+
|
| 961 |
+
uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
|
| 962 |
+
for i in range(len(uppacked)):
|
| 963 |
+
uppacked[i] += i
|
| 964 |
+
packed_key_value_indexes = torch.cat(uppacked, dim=0)
|
| 965 |
+
|
| 966 |
+
output = self.language_model.forward_inference(
|
| 967 |
+
packed_query_sequence=packed_text_embedding,
|
| 968 |
+
query_lens=query_lens,
|
| 969 |
+
packed_query_position_ids=packed_query_position_ids,
|
| 970 |
+
packed_query_indexes=packed_query_indexes,
|
| 971 |
+
past_key_values=past_key_values,
|
| 972 |
+
key_values_lens=key_values_lens,
|
| 973 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 974 |
+
update_past_key_values=False,
|
| 975 |
+
is_causal=True,
|
| 976 |
+
**extra_inputs,
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
pred_logits = self.language_model.lm_head(output.packed_query_sequence)
|
| 980 |
+
if do_sample:
|
| 981 |
+
probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
|
| 982 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 983 |
+
else:
|
| 984 |
+
next_token = torch.argmax(pred_logits, dim=-1)
|
| 985 |
+
|
| 986 |
+
# If next token is vision_start (151652), include it
|
| 987 |
+
if next_token[0] == 151652:
|
| 988 |
+
generated_sequence.append(next_token)
|
| 989 |
+
|
| 990 |
+
break
|
| 991 |
+
|
| 992 |
+
output_device = generated_sequence[0].device
|
| 993 |
+
return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
|
| 994 |
+
|
| 995 |
+
# for evaluation
|
| 996 |
+
@torch.no_grad()
|
| 997 |
+
def chat(
|
| 998 |
+
self,
|
| 999 |
+
tokenizer,
|
| 1000 |
+
new_token_ids,
|
| 1001 |
+
image_transform,
|
| 1002 |
+
images,
|
| 1003 |
+
prompt,
|
| 1004 |
+
max_length: int,
|
| 1005 |
+
do_sample: bool = False,
|
| 1006 |
+
temperature: float = 1.0,
|
| 1007 |
+
):
|
| 1008 |
+
device = next(self.parameters()).device
|
| 1009 |
+
|
| 1010 |
+
if isinstance(new_token_ids, dict):
|
| 1011 |
+
for k, v in new_token_ids.items():
|
| 1012 |
+
if torch.is_tensor(v):
|
| 1013 |
+
new_token_ids[k] = v.to(device)
|
| 1014 |
+
elif torch.is_tensor(new_token_ids):
|
| 1015 |
+
new_token_ids = new_token_ids.to(device)
|
| 1016 |
+
|
| 1017 |
+
# prefill
|
| 1018 |
+
past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
|
| 1019 |
+
newlens = [0]
|
| 1020 |
+
new_rope = [0]
|
| 1021 |
+
|
| 1022 |
+
# add images
|
| 1023 |
+
for image in images:
|
| 1024 |
+
generation_input, newlens, new_rope = self.prepare_vit_images(
|
| 1025 |
+
curr_kvlens=newlens,
|
| 1026 |
+
curr_rope=new_rope,
|
| 1027 |
+
images=[image],
|
| 1028 |
+
transforms=image_transform,
|
| 1029 |
+
new_token_ids=new_token_ids,
|
| 1030 |
+
)
|
| 1031 |
+
for k, v in generation_input.items():
|
| 1032 |
+
if torch.is_tensor(v):
|
| 1033 |
+
generation_input[k] = v.to(device)
|
| 1034 |
+
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 1035 |
+
past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
|
| 1036 |
+
|
| 1037 |
+
# add text
|
| 1038 |
+
generation_input, newlens, new_rope = self.prepare_prompts(
|
| 1039 |
+
curr_kvlens=newlens,
|
| 1040 |
+
curr_rope=new_rope,
|
| 1041 |
+
prompts=[prompt],
|
| 1042 |
+
tokenizer=tokenizer,
|
| 1043 |
+
new_token_ids=new_token_ids,
|
| 1044 |
+
)
|
| 1045 |
+
for k, v in generation_input.items():
|
| 1046 |
+
if torch.is_tensor(v):
|
| 1047 |
+
generation_input[k] = v.to(device)
|
| 1048 |
+
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 1049 |
+
past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
|
| 1050 |
+
|
| 1051 |
+
# decode
|
| 1052 |
+
generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
|
| 1053 |
+
for k, v in generation_input.items():
|
| 1054 |
+
if torch.is_tensor(v):
|
| 1055 |
+
generation_input[k] = v.to(device)
|
| 1056 |
+
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 1057 |
+
unpacked_latent = self.generate_text(
|
| 1058 |
+
past_key_values=past_key_values,
|
| 1059 |
+
max_length=max_length,
|
| 1060 |
+
do_sample=do_sample,
|
| 1061 |
+
temperature=temperature,
|
| 1062 |
+
end_token_id=new_token_ids['eos_token_id'],
|
| 1063 |
+
**generation_input,
|
| 1064 |
+
)
|
| 1065 |
+
output = tokenizer.decode(unpacked_latent[:,0])
|
| 1066 |
+
output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
|
| 1067 |
+
|
| 1068 |
+
return output
|
modeling/bagel/modeling_utils.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: CC BY-NC 4.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under CC BY-NC 4.0, with the full license text
|
| 8 |
+
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn
|
| 17 |
+
from transformers.activations import ACT2FN
|
| 18 |
+
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
# 2D sine-cosine position embedding
|
| 21 |
+
# References:
|
| 22 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
| 23 |
+
# --------------------------------------------------------
|
| 24 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 25 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 26 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 27 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 28 |
+
grid = np.stack(grid, axis=0)
|
| 29 |
+
|
| 30 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 31 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 32 |
+
if cls_token and extra_tokens > 0:
|
| 33 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 34 |
+
return pos_embed
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 38 |
+
assert embed_dim % 2 == 0
|
| 39 |
+
|
| 40 |
+
# use half of dimensions to encode grid_h
|
| 41 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 42 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 43 |
+
|
| 44 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 45 |
+
return emb
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 49 |
+
"""
|
| 50 |
+
embed_dim: output dimension for each position
|
| 51 |
+
pos: a list of positions to be encoded: size (M,)
|
| 52 |
+
out: (M, D)
|
| 53 |
+
"""
|
| 54 |
+
assert embed_dim % 2 == 0
|
| 55 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 56 |
+
omega /= embed_dim / 2.
|
| 57 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 58 |
+
|
| 59 |
+
pos = pos.reshape(-1) # (M,)
|
| 60 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 61 |
+
|
| 62 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 63 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 64 |
+
|
| 65 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 66 |
+
return emb
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --------------------------------------------------------
|
| 70 |
+
# TimestepEmbedder
|
| 71 |
+
# Reference:
|
| 72 |
+
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
| 73 |
+
# --------------------------------------------------------
|
| 74 |
+
class TimestepEmbedder(nn.Module):
|
| 75 |
+
"""
|
| 76 |
+
Embeds scalar timesteps into vector representations.
|
| 77 |
+
"""
|
| 78 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.mlp = nn.Sequential(
|
| 81 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 82 |
+
nn.SiLU(),
|
| 83 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 84 |
+
)
|
| 85 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 89 |
+
"""
|
| 90 |
+
Create sinusoidal timestep embeddings.
|
| 91 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 92 |
+
These may be fractional.
|
| 93 |
+
:param dim: the dimension of the output.
|
| 94 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 95 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 96 |
+
"""
|
| 97 |
+
half = dim // 2
|
| 98 |
+
freqs = torch.exp(
|
| 99 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 100 |
+
).to(device=t.device)
|
| 101 |
+
args = t[:, None].float() * freqs[None]
|
| 102 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 103 |
+
if dim % 2:
|
| 104 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 105 |
+
return embedding
|
| 106 |
+
|
| 107 |
+
def forward(self, t):
|
| 108 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 109 |
+
t_emb = self.mlp(t_freq)
|
| 110 |
+
return t_emb
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class MLPconnector(nn.Module):
|
| 114 |
+
def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.activation_fn = ACT2FN[hidden_act]
|
| 117 |
+
self.fc1 = nn.Linear(in_dim, out_dim)
|
| 118 |
+
self.fc2 = nn.Linear(out_dim, out_dim)
|
| 119 |
+
|
| 120 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
hidden_states = self.fc1(hidden_states)
|
| 122 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 123 |
+
hidden_states = self.fc2(hidden_states)
|
| 124 |
+
return hidden_states
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class PositionEmbedding(nn.Module):
|
| 128 |
+
def __init__(self, max_num_patch_per_side, hidden_size):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.max_num_patch_per_side = max_num_patch_per_side
|
| 131 |
+
self.hidden_size = hidden_size
|
| 132 |
+
self.pos_embed = nn.Parameter(
|
| 133 |
+
torch.zeros(max_num_patch_per_side ** 2, hidden_size),
|
| 134 |
+
requires_grad=False
|
| 135 |
+
)
|
| 136 |
+
self._init_weights()
|
| 137 |
+
|
| 138 |
+
def _init_weights(self):
|
| 139 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
| 140 |
+
pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
|
| 141 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
|
| 142 |
+
|
| 143 |
+
def forward(self, position_ids):
|
| 144 |
+
return self.pos_embed[position_ids]
|
modeling/bagel/qwen2_navit.py
ADDED
|
@@ -0,0 +1,1157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under Apache-2.0, with the full license text
|
| 8 |
+
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 20 |
+
from torch.nn.attention.flex_attention import flex_attention
|
| 21 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 22 |
+
from transformers.utils import ModelOutput
|
| 23 |
+
|
| 24 |
+
from flash_attn import flash_attn_varlen_func
|
| 25 |
+
from modeling.qwen2.modeling_qwen2 import (
|
| 26 |
+
Qwen2Attention,
|
| 27 |
+
Qwen2MLP,
|
| 28 |
+
Qwen2PreTrainedModel,
|
| 29 |
+
Qwen2RMSNorm,
|
| 30 |
+
Qwen2RotaryEmbedding,
|
| 31 |
+
apply_rotary_pos_emb,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
torch._dynamo.config.cache_size_limit = 512
|
| 38 |
+
torch._dynamo.config.accumulated_cache_size_limit = 4096
|
| 39 |
+
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
|
| 40 |
+
flex_attention = torch.compile(flex_attention)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Qwen2Config(_Qwen2Config):
|
| 44 |
+
r"""
|
| 45 |
+
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
|
| 46 |
+
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 47 |
+
with the defaults will yield a similar configuration to that of
|
| 48 |
+
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
|
| 49 |
+
|
| 50 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 51 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
vocab_size (`int`, *optional*, defaults to 151936):
|
| 55 |
+
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
|
| 56 |
+
`inputs_ids` passed when calling [`Qwen2Model`]
|
| 57 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 58 |
+
Dimension of the hidden representations.
|
| 59 |
+
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 60 |
+
Dimension of the MLP representations.
|
| 61 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 62 |
+
Number of hidden layers in the Transformer encoder.
|
| 63 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 64 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 65 |
+
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 66 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 67 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 68 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 69 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 70 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 71 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
| 72 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 73 |
+
The non-linear activation function (function or string) in the decoder.
|
| 74 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 75 |
+
The maximum sequence length that this model might ever be used with.
|
| 76 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 77 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 78 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 79 |
+
The epsilon used by the rms normalization layers.
|
| 80 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 81 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 82 |
+
relevant if `config.is_decoder=True`.
|
| 83 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 84 |
+
Whether the model's input and output word embeddings should be tied.
|
| 85 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 86 |
+
The base period of the RoPE embeddings.
|
| 87 |
+
rope_scaling (`Dict`, *optional*):
|
| 88 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 89 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 90 |
+
accordingly.
|
| 91 |
+
Expected contents:
|
| 92 |
+
`rope_type` (`str`):
|
| 93 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 94 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 95 |
+
`factor` (`float`, *optional*):
|
| 96 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 97 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 98 |
+
original maximum pre-trained length.
|
| 99 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 100 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 101 |
+
pretraining.
|
| 102 |
+
`attention_factor` (`float`, *optional*):
|
| 103 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 104 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 105 |
+
`factor` field to infer the suggested value.
|
| 106 |
+
`beta_fast` (`float`, *optional*):
|
| 107 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 108 |
+
ramp function. If unspecified, it defaults to 32.
|
| 109 |
+
`beta_slow` (`float`, *optional*):
|
| 110 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 111 |
+
ramp function. If unspecified, it defaults to 1.
|
| 112 |
+
`short_factor` (`List[float]`, *optional*):
|
| 113 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 114 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 115 |
+
size divided by the number of attention heads divided by 2
|
| 116 |
+
`long_factor` (`List[float]`, *optional*):
|
| 117 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 118 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 119 |
+
size divided by the number of attention heads divided by 2
|
| 120 |
+
`low_freq_factor` (`float`, *optional*):
|
| 121 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 122 |
+
`high_freq_factor` (`float`, *optional*):
|
| 123 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 124 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 125 |
+
Whether to use sliding window attention.
|
| 126 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 127 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 128 |
+
max_window_layers (`int`, *optional*, defaults to 28):
|
| 129 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
| 130 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 131 |
+
The dropout ratio for the attention probabilities.
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
>>> from transformers import Qwen2Model, Qwen2Config
|
| 135 |
+
|
| 136 |
+
>>> # Initializing a Qwen2 style configuration
|
| 137 |
+
>>> configuration = Qwen2Config()
|
| 138 |
+
|
| 139 |
+
>>> # Initializing a model from the Qwen2-7B style configuration
|
| 140 |
+
>>> model = Qwen2Model(configuration)
|
| 141 |
+
|
| 142 |
+
>>> # Accessing the model configuration
|
| 143 |
+
>>> configuration = model.config
|
| 144 |
+
```"""
|
| 145 |
+
|
| 146 |
+
model_type = "qwen2"
|
| 147 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 148 |
+
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
vocab_size=151936,
|
| 152 |
+
hidden_size=4096,
|
| 153 |
+
intermediate_size=22016,
|
| 154 |
+
num_hidden_layers=32,
|
| 155 |
+
num_attention_heads=32,
|
| 156 |
+
num_key_value_heads=32,
|
| 157 |
+
hidden_act="silu",
|
| 158 |
+
max_position_embeddings=32768,
|
| 159 |
+
initializer_range=0.02,
|
| 160 |
+
rms_norm_eps=1e-6,
|
| 161 |
+
use_cache=True,
|
| 162 |
+
tie_word_embeddings=False,
|
| 163 |
+
rope_theta=10000.0,
|
| 164 |
+
rope_scaling=None,
|
| 165 |
+
use_sliding_window=False,
|
| 166 |
+
sliding_window=4096,
|
| 167 |
+
max_window_layers=28,
|
| 168 |
+
attention_dropout=0.0,
|
| 169 |
+
is_causal=True,
|
| 170 |
+
_attn_implementation="flash_attention_2",
|
| 171 |
+
qk_norm=True,
|
| 172 |
+
layer_module="Qwen2DecoderLayer",
|
| 173 |
+
freeze_und=False,
|
| 174 |
+
**kwargs,
|
| 175 |
+
):
|
| 176 |
+
super().__init__(
|
| 177 |
+
vocab_size=vocab_size,
|
| 178 |
+
hidden_size=hidden_size,
|
| 179 |
+
intermediate_size=intermediate_size,
|
| 180 |
+
num_hidden_layers=num_hidden_layers,
|
| 181 |
+
num_attention_heads=num_attention_heads,
|
| 182 |
+
num_key_value_heads=num_key_value_heads,
|
| 183 |
+
hidden_act=hidden_act,
|
| 184 |
+
max_position_embeddings=max_position_embeddings,
|
| 185 |
+
initializer_range=initializer_range,
|
| 186 |
+
rms_norm_eps=rms_norm_eps,
|
| 187 |
+
use_cache=use_cache,
|
| 188 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 189 |
+
rope_theta=rope_theta,
|
| 190 |
+
rope_scaling=rope_scaling,
|
| 191 |
+
use_sliding_window=use_sliding_window,
|
| 192 |
+
sliding_window=sliding_window,
|
| 193 |
+
max_window_layers=max_window_layers,
|
| 194 |
+
attention_dropout=attention_dropout,
|
| 195 |
+
is_causal=is_causal,
|
| 196 |
+
_attn_implementation=_attn_implementation,
|
| 197 |
+
**kwargs,
|
| 198 |
+
)
|
| 199 |
+
self.qk_norm = qk_norm
|
| 200 |
+
self.layer_module = layer_module
|
| 201 |
+
self.freeze_und = freeze_und
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NaiveCache:
|
| 205 |
+
def __init__(self, num_layers):
|
| 206 |
+
self.key_cache = {k: None for k in range(num_layers)}
|
| 207 |
+
self.value_cache = {k: None for k in range(num_layers)}
|
| 208 |
+
|
| 209 |
+
@property
|
| 210 |
+
def num_layers(self):
|
| 211 |
+
return len(self.key_cache)
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
def seq_lens(self):
|
| 215 |
+
if self.key_cache[0] is not None:
|
| 216 |
+
return self.key_cache[0].shape[0]
|
| 217 |
+
else:
|
| 218 |
+
return 0
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@dataclass
|
| 222 |
+
class BaseNavitOutputWithPast(ModelOutput):
|
| 223 |
+
packed_query_sequence: torch.FloatTensor = None
|
| 224 |
+
past_key_values: Optional[NaiveCache] = None
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def pad_sequence(tensor, pad_size):
|
| 228 |
+
H, L, D = tensor.shape
|
| 229 |
+
pad_tensor = tensor.new_zeros((H, pad_size, D))
|
| 230 |
+
return torch.cat([tensor, pad_tensor], dim=1)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class PackedAttention(Qwen2Attention):
|
| 234 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 235 |
+
super().__init__(config, layer_idx)
|
| 236 |
+
if self.config.qk_norm:
|
| 237 |
+
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 238 |
+
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 239 |
+
else:
|
| 240 |
+
self.q_norm = nn.Identity()
|
| 241 |
+
self.k_norm = nn.Identity()
|
| 242 |
+
|
| 243 |
+
def forward(self, *args, **kwargs):
|
| 244 |
+
if self.training:
|
| 245 |
+
return self.forward_train(*args, **kwargs)
|
| 246 |
+
else:
|
| 247 |
+
return self.forward_inference(*args, **kwargs)
|
| 248 |
+
|
| 249 |
+
def forward_train(
|
| 250 |
+
self,
|
| 251 |
+
packed_sequence: torch.Tensor,
|
| 252 |
+
sample_lens: List[int],
|
| 253 |
+
attention_mask: List[torch.Tensor],
|
| 254 |
+
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 255 |
+
):
|
| 256 |
+
packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
|
| 257 |
+
packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 258 |
+
packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 259 |
+
|
| 260 |
+
packed_query_states = self.q_norm(packed_query_states)
|
| 261 |
+
packed_key_states = self.k_norm(packed_key_states)
|
| 262 |
+
|
| 263 |
+
packed_cos, packed_sin = packed_position_embeddings
|
| 264 |
+
packed_query_states, packed_key_states = apply_rotary_pos_emb(
|
| 265 |
+
packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if isinstance(attention_mask, List):
|
| 269 |
+
packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
|
| 270 |
+
packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
|
| 271 |
+
packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
|
| 272 |
+
packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
|
| 273 |
+
|
| 274 |
+
unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
|
| 275 |
+
unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
|
| 276 |
+
unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
|
| 277 |
+
upacked_attn_output = []
|
| 278 |
+
for query_states, key_states, value_states, attention_mask_per_sample in zip(
|
| 279 |
+
unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
|
| 280 |
+
):
|
| 281 |
+
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
|
| 282 |
+
attn_output = scaled_dot_product_attention(
|
| 283 |
+
query_states.to(torch.bfloat16).unsqueeze(0),
|
| 284 |
+
key_states.to(torch.bfloat16).unsqueeze(0),
|
| 285 |
+
value_states.to(torch.bfloat16).unsqueeze(0),
|
| 286 |
+
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
|
| 287 |
+
)
|
| 288 |
+
upacked_attn_output.append(attn_output.squeeze(0))
|
| 289 |
+
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
|
| 290 |
+
else:
|
| 291 |
+
pad_size = sum(sample_lens) - packed_query_states.shape[0]
|
| 292 |
+
packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
|
| 293 |
+
packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
|
| 294 |
+
packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
|
| 295 |
+
packed_attn_output = flex_attention(
|
| 296 |
+
packed_query_states.unsqueeze(0),
|
| 297 |
+
packed_key_states.unsqueeze(0),
|
| 298 |
+
packed_value_states.unsqueeze(0),
|
| 299 |
+
enable_gqa=True,
|
| 300 |
+
block_mask=attention_mask,
|
| 301 |
+
)
|
| 302 |
+
end_index = packed_attn_output.shape[2] - pad_size
|
| 303 |
+
packed_attn_output = packed_attn_output[0, :, :end_index, :]
|
| 304 |
+
|
| 305 |
+
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
|
| 306 |
+
packed_attn_output = self.o_proj(packed_attn_output)
|
| 307 |
+
|
| 308 |
+
return packed_attn_output
|
| 309 |
+
|
| 310 |
+
def forward_inference(
|
| 311 |
+
self,
|
| 312 |
+
packed_query_sequence: torch.Tensor,
|
| 313 |
+
query_lens: torch.Tensor,
|
| 314 |
+
packed_query_position_embeddings: torch.Tensor,
|
| 315 |
+
packed_query_indexes: torch.Tensor,
|
| 316 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 317 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 318 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 319 |
+
update_past_key_values=True,
|
| 320 |
+
is_causal=True,
|
| 321 |
+
):
|
| 322 |
+
packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
|
| 323 |
+
packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 324 |
+
packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 325 |
+
|
| 326 |
+
packed_query_states = self.q_norm(packed_query_states)
|
| 327 |
+
packed_key_states = self.k_norm(packed_key_states)
|
| 328 |
+
|
| 329 |
+
packed_cos, packed_sin = packed_query_position_embeddings
|
| 330 |
+
packed_query_states, packed_key_states = apply_rotary_pos_emb(
|
| 331 |
+
packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
packed_query_states = packed_query_states.to(torch.bfloat16)
|
| 335 |
+
packed_key_states = packed_key_states.to(torch.bfloat16)
|
| 336 |
+
packed_value_states = packed_value_states.to(torch.bfloat16)
|
| 337 |
+
|
| 338 |
+
if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
|
| 339 |
+
past_key_states = past_key_values.key_cache[self.layer_idx]
|
| 340 |
+
past_value_states = past_key_values.value_cache[self.layer_idx]
|
| 341 |
+
|
| 342 |
+
seqlens = sum(query_lens) + sum(key_values_lens)
|
| 343 |
+
merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
|
| 344 |
+
merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
|
| 345 |
+
merged_key_states[packed_query_indexes] = packed_key_states
|
| 346 |
+
merged_key_states[packed_key_value_indexes] = past_key_states
|
| 347 |
+
merged_value_states[packed_query_indexes] = packed_value_states
|
| 348 |
+
merged_value_states[packed_key_value_indexes] = past_value_states
|
| 349 |
+
key_values_lens = key_values_lens + query_lens
|
| 350 |
+
else:
|
| 351 |
+
merged_key_states = packed_key_states
|
| 352 |
+
merged_value_states = packed_value_states
|
| 353 |
+
key_values_lens = query_lens
|
| 354 |
+
|
| 355 |
+
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
|
| 356 |
+
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
|
| 357 |
+
|
| 358 |
+
packed_attn_output = flash_attn_varlen_func(
|
| 359 |
+
q=packed_query_states,
|
| 360 |
+
k=merged_key_states,
|
| 361 |
+
v=merged_value_states,
|
| 362 |
+
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
|
| 363 |
+
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
|
| 364 |
+
max_seqlen_q=max(query_lens).item(),
|
| 365 |
+
max_seqlen_k=max(key_values_lens).item(),
|
| 366 |
+
causal=is_causal,
|
| 367 |
+
)
|
| 368 |
+
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
|
| 369 |
+
packed_attn_output = self.o_proj(packed_attn_output)
|
| 370 |
+
|
| 371 |
+
if update_past_key_values:
|
| 372 |
+
past_key_values.key_cache[self.layer_idx] = merged_key_states
|
| 373 |
+
past_key_values.value_cache[self.layer_idx] = merged_value_states
|
| 374 |
+
|
| 375 |
+
return packed_attn_output, past_key_values
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class PackedAttentionMoT(Qwen2Attention):
|
| 379 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 380 |
+
super().__init__(config, layer_idx)
|
| 381 |
+
if self.config.qk_norm:
|
| 382 |
+
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 383 |
+
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 384 |
+
self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 385 |
+
self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 386 |
+
else:
|
| 387 |
+
self.q_norm = nn.Identity()
|
| 388 |
+
self.k_norm = nn.Identity()
|
| 389 |
+
self.q_norm_moe_gen = nn.Identity()
|
| 390 |
+
self.k_norm_moe_gen = nn.Identity()
|
| 391 |
+
|
| 392 |
+
self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 393 |
+
self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 394 |
+
self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 395 |
+
self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 396 |
+
|
| 397 |
+
def forward(self, *args, **kwargs):
|
| 398 |
+
if self.training:
|
| 399 |
+
return self.forward_train(*args, **kwargs)
|
| 400 |
+
else:
|
| 401 |
+
return self.forward_inference(*args, **kwargs)
|
| 402 |
+
|
| 403 |
+
def forward_train(
|
| 404 |
+
self,
|
| 405 |
+
packed_sequence: torch.Tensor,
|
| 406 |
+
sample_lens: List[int],
|
| 407 |
+
attention_mask,
|
| 408 |
+
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 409 |
+
packed_und_token_indexes: torch.LongTensor,
|
| 410 |
+
packed_gen_token_indexes: torch.LongTensor,
|
| 411 |
+
):
|
| 412 |
+
packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
|
| 413 |
+
packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
|
| 414 |
+
packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
|
| 415 |
+
|
| 416 |
+
packed_sequence_und = packed_sequence[packed_und_token_indexes]
|
| 417 |
+
packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
|
| 418 |
+
|
| 419 |
+
packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
|
| 420 |
+
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)
|
| 421 |
+
|
| 422 |
+
packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
|
| 423 |
+
packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)
|
| 424 |
+
|
| 425 |
+
packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
|
| 426 |
+
packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)
|
| 427 |
+
|
| 428 |
+
packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
|
| 429 |
+
packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
|
| 430 |
+
packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
|
| 431 |
+
if self.config.freeze_und:
|
| 432 |
+
packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()
|
| 433 |
+
|
| 434 |
+
packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
|
| 435 |
+
packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
|
| 436 |
+
|
| 437 |
+
packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
|
| 438 |
+
if self.config.freeze_und:
|
| 439 |
+
packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
|
| 440 |
+
packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])
|
| 441 |
+
|
| 442 |
+
packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
|
| 443 |
+
if self.config.freeze_und:
|
| 444 |
+
packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
|
| 445 |
+
packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])
|
| 446 |
+
|
| 447 |
+
packed_cos, packed_sin = packed_position_embeddings
|
| 448 |
+
packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
|
| 449 |
+
packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
if isinstance(attention_mask, List):
|
| 453 |
+
packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
|
| 454 |
+
packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
|
| 455 |
+
packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
|
| 456 |
+
packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
|
| 457 |
+
|
| 458 |
+
unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
|
| 459 |
+
unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
|
| 460 |
+
unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
|
| 461 |
+
upacked_attn_output = []
|
| 462 |
+
for query_states, key_states, value_states, attention_mask_per_sample in zip(
|
| 463 |
+
unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
|
| 464 |
+
):
|
| 465 |
+
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
|
| 466 |
+
attn_output = scaled_dot_product_attention(
|
| 467 |
+
query_states.to(torch.bfloat16).unsqueeze(0),
|
| 468 |
+
key_states.to(torch.bfloat16).unsqueeze(0),
|
| 469 |
+
value_states.to(torch.bfloat16).unsqueeze(0),
|
| 470 |
+
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
|
| 471 |
+
)
|
| 472 |
+
upacked_attn_output.append(attn_output.squeeze(0))
|
| 473 |
+
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
|
| 474 |
+
else:
|
| 475 |
+
pad_size = sum(sample_lens) - packed_query_states.shape[0]
|
| 476 |
+
packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
|
| 477 |
+
packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
|
| 478 |
+
packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
|
| 479 |
+
packed_attn_output = flex_attention(
|
| 480 |
+
packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
|
| 481 |
+
packed_key_states_.unsqueeze(0),
|
| 482 |
+
packed_value_states.unsqueeze(0),
|
| 483 |
+
enable_gqa=True,
|
| 484 |
+
block_mask=attention_mask,
|
| 485 |
+
)
|
| 486 |
+
end_index = packed_attn_output.shape[2] - pad_size
|
| 487 |
+
packed_attn_output = packed_attn_output[0, :, :end_index, :]
|
| 488 |
+
|
| 489 |
+
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
|
| 490 |
+
packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
|
| 491 |
+
packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
|
| 492 |
+
packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])
|
| 493 |
+
|
| 494 |
+
return packed_attn_output_
|
| 495 |
+
|
| 496 |
+
def forward_inference(
|
| 497 |
+
self,
|
| 498 |
+
packed_query_sequence: torch.Tensor,
|
| 499 |
+
query_lens: torch.Tensor,
|
| 500 |
+
packed_query_position_embeddings: torch.Tensor,
|
| 501 |
+
packed_query_indexes: torch.Tensor,
|
| 502 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 503 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 504 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 505 |
+
update_past_key_values=True,
|
| 506 |
+
is_causal=True,
|
| 507 |
+
mode="und",
|
| 508 |
+
packed_vae_token_indexes=None,
|
| 509 |
+
packed_text_indexes=None,
|
| 510 |
+
):
|
| 511 |
+
if mode == 'und':
|
| 512 |
+
packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
|
| 513 |
+
packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 514 |
+
packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
|
| 515 |
+
packed_query_states = self.q_norm(packed_query_states)
|
| 516 |
+
packed_key_states = self.k_norm(packed_key_states)
|
| 517 |
+
elif mode == 'gen':
|
| 518 |
+
packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
|
| 519 |
+
packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
|
| 520 |
+
packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
|
| 521 |
+
packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
|
| 522 |
+
|
| 523 |
+
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
|
| 524 |
+
packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
|
| 525 |
+
|
| 526 |
+
packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
|
| 527 |
+
packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
|
| 528 |
+
|
| 529 |
+
packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
|
| 530 |
+
packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
|
| 531 |
+
|
| 532 |
+
packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
|
| 533 |
+
packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)
|
| 534 |
+
|
| 535 |
+
packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
|
| 536 |
+
packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
|
| 537 |
+
packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
|
| 538 |
+
|
| 539 |
+
packed_query_states = packed_query_states.to(torch.float32)
|
| 540 |
+
packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
|
| 541 |
+
packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])
|
| 542 |
+
|
| 543 |
+
packed_key_states = packed_key_states.to(torch.float32)
|
| 544 |
+
packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
|
| 545 |
+
packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])
|
| 546 |
+
|
| 547 |
+
packed_cos, packed_sin = packed_query_position_embeddings
|
| 548 |
+
packed_query_states, packed_key_states = apply_rotary_pos_emb(
|
| 549 |
+
packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
packed_query_states = packed_query_states.to(torch.bfloat16)
|
| 553 |
+
packed_key_states = packed_key_states.to(torch.bfloat16)
|
| 554 |
+
packed_value_states = packed_value_states.to(torch.bfloat16)
|
| 555 |
+
|
| 556 |
+
if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
|
| 557 |
+
past_key_states = past_key_values.key_cache[self.layer_idx]
|
| 558 |
+
past_value_states = past_key_values.value_cache[self.layer_idx]
|
| 559 |
+
|
| 560 |
+
seqlens = sum(query_lens) + sum(key_values_lens)
|
| 561 |
+
merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
|
| 562 |
+
merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
|
| 563 |
+
merged_key_states[packed_query_indexes] = packed_key_states
|
| 564 |
+
merged_key_states[packed_key_value_indexes] = past_key_states
|
| 565 |
+
merged_value_states[packed_query_indexes] = packed_value_states
|
| 566 |
+
merged_value_states[packed_key_value_indexes] = past_value_states
|
| 567 |
+
key_values_lens = key_values_lens + query_lens
|
| 568 |
+
else:
|
| 569 |
+
merged_key_states = packed_key_states
|
| 570 |
+
merged_value_states = packed_value_states
|
| 571 |
+
key_values_lens = query_lens
|
| 572 |
+
|
| 573 |
+
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
|
| 574 |
+
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
|
| 575 |
+
|
| 576 |
+
packed_attn_output = flash_attn_varlen_func(
|
| 577 |
+
q=packed_query_states,
|
| 578 |
+
k=merged_key_states,
|
| 579 |
+
v=merged_value_states,
|
| 580 |
+
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
|
| 581 |
+
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
|
| 582 |
+
max_seqlen_q=max(query_lens).item(),
|
| 583 |
+
max_seqlen_k=max(key_values_lens).item(),
|
| 584 |
+
causal=is_causal,
|
| 585 |
+
)
|
| 586 |
+
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
|
| 587 |
+
if mode == 'und':
|
| 588 |
+
packed_attn_output = self.o_proj(packed_attn_output)
|
| 589 |
+
elif mode == 'gen':
|
| 590 |
+
packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
|
| 591 |
+
packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])
|
| 592 |
+
|
| 593 |
+
if update_past_key_values:
|
| 594 |
+
past_key_values.key_cache[self.layer_idx] = merged_key_states
|
| 595 |
+
past_key_values.value_cache[self.layer_idx] = merged_value_states
|
| 596 |
+
|
| 597 |
+
return packed_attn_output, past_key_values
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class Qwen2DecoderLayer(nn.Module):
|
| 601 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 602 |
+
super().__init__()
|
| 603 |
+
self.hidden_size = config.hidden_size
|
| 604 |
+
|
| 605 |
+
self.self_attn = PackedAttention(config, layer_idx)
|
| 606 |
+
|
| 607 |
+
self.mlp = Qwen2MLP(config)
|
| 608 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 609 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 610 |
+
|
| 611 |
+
def forward(self, *args, **kwargs):
|
| 612 |
+
if self.training:
|
| 613 |
+
return self.forward_train(*args, **kwargs)
|
| 614 |
+
else:
|
| 615 |
+
return self.forward_inference(*args, **kwargs)
|
| 616 |
+
|
| 617 |
+
def forward_train(
|
| 618 |
+
self,
|
| 619 |
+
packed_sequence: torch.Tensor,
|
| 620 |
+
sample_lens: List[int],
|
| 621 |
+
attention_mask,
|
| 622 |
+
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 623 |
+
) -> torch.Tensor:
|
| 624 |
+
|
| 625 |
+
residual = packed_sequence
|
| 626 |
+
packed_sequence = self.input_layernorm(packed_sequence)
|
| 627 |
+
|
| 628 |
+
# Self Attention
|
| 629 |
+
packed_sequence = self.self_attn(
|
| 630 |
+
packed_sequence=packed_sequence,
|
| 631 |
+
sample_lens=sample_lens,
|
| 632 |
+
attention_mask=attention_mask,
|
| 633 |
+
packed_position_embeddings=packed_position_embeddings,
|
| 634 |
+
)
|
| 635 |
+
packed_sequence = residual + packed_sequence
|
| 636 |
+
|
| 637 |
+
# Fully Connected
|
| 638 |
+
residual = packed_sequence
|
| 639 |
+
packed_sequence = self.post_attention_layernorm(packed_sequence)
|
| 640 |
+
packed_sequence = self.mlp(packed_sequence)
|
| 641 |
+
packed_sequence = residual + packed_sequence
|
| 642 |
+
|
| 643 |
+
return packed_sequence
|
| 644 |
+
|
| 645 |
+
def forward_inference(
|
| 646 |
+
self,
|
| 647 |
+
packed_query_sequence: torch.Tensor,
|
| 648 |
+
query_lens: torch.Tensor,
|
| 649 |
+
packed_query_position_embeddings: torch.Tensor,
|
| 650 |
+
packed_query_indexes: torch.Tensor,
|
| 651 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 652 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 653 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 654 |
+
update_past_key_values=True,
|
| 655 |
+
is_causal=True,
|
| 656 |
+
) -> BaseNavitOutputWithPast:
|
| 657 |
+
|
| 658 |
+
residual = packed_query_sequence
|
| 659 |
+
packed_query_sequence = self.input_layernorm(packed_query_sequence)
|
| 660 |
+
|
| 661 |
+
# Self Attention
|
| 662 |
+
packed_query_sequence, past_key_values = self.self_attn(
|
| 663 |
+
packed_query_sequence=packed_query_sequence,
|
| 664 |
+
query_lens=query_lens,
|
| 665 |
+
packed_query_position_embeddings=packed_query_position_embeddings,
|
| 666 |
+
packed_query_indexes=packed_query_indexes,
|
| 667 |
+
past_key_values=past_key_values,
|
| 668 |
+
key_values_lens=key_values_lens,
|
| 669 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 670 |
+
update_past_key_values=update_past_key_values,
|
| 671 |
+
is_causal=is_causal,
|
| 672 |
+
)
|
| 673 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 674 |
+
|
| 675 |
+
# Fully Connected
|
| 676 |
+
residual = packed_query_sequence
|
| 677 |
+
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
|
| 678 |
+
packed_query_sequence = self.mlp(packed_query_sequence)
|
| 679 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 680 |
+
|
| 681 |
+
return packed_query_sequence, past_key_values
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
class Qwen2MoTDecoderLayer(nn.Module):
|
| 685 |
+
def __init__(
|
| 686 |
+
self,
|
| 687 |
+
config,
|
| 688 |
+
layer_idx: Optional[int] = None,
|
| 689 |
+
attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
|
| 690 |
+
):
|
| 691 |
+
super().__init__()
|
| 692 |
+
self.hidden_size = config.hidden_size
|
| 693 |
+
self.freeze_und = config.freeze_und
|
| 694 |
+
|
| 695 |
+
self.self_attn = attn_module(config, layer_idx)
|
| 696 |
+
|
| 697 |
+
self.mlp = Qwen2MLP(config)
|
| 698 |
+
self.mlp_moe_gen = Qwen2MLP(config)
|
| 699 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 700 |
+
self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 701 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 702 |
+
self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 703 |
+
|
| 704 |
+
def forward(self, *args, **kwargs):
|
| 705 |
+
if self.training:
|
| 706 |
+
return self.forward_train(*args, **kwargs)
|
| 707 |
+
else:
|
| 708 |
+
return self.forward_inference(*args, **kwargs)
|
| 709 |
+
|
| 710 |
+
def forward_train(
|
| 711 |
+
self,
|
| 712 |
+
packed_sequence: torch.Tensor,
|
| 713 |
+
sample_lens: List[int],
|
| 714 |
+
attention_mask,
|
| 715 |
+
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 716 |
+
packed_und_token_indexes: torch.LongTensor,
|
| 717 |
+
packed_gen_token_indexes: torch.LongTensor,
|
| 718 |
+
) -> torch.Tensor:
|
| 719 |
+
|
| 720 |
+
residual = packed_sequence
|
| 721 |
+
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
|
| 722 |
+
packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
|
| 723 |
+
packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
|
| 724 |
+
|
| 725 |
+
# Self Attention
|
| 726 |
+
packed_sequence_ = self.self_attn(
|
| 727 |
+
packed_sequence=packed_sequence_,
|
| 728 |
+
sample_lens=sample_lens,
|
| 729 |
+
attention_mask=attention_mask,
|
| 730 |
+
packed_position_embeddings=packed_position_embeddings,
|
| 731 |
+
packed_und_token_indexes=packed_und_token_indexes,
|
| 732 |
+
packed_gen_token_indexes=packed_gen_token_indexes,
|
| 733 |
+
)
|
| 734 |
+
if self.freeze_und:
|
| 735 |
+
packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
|
| 736 |
+
packed_sequence = residual + packed_sequence_
|
| 737 |
+
|
| 738 |
+
# Fully Connected
|
| 739 |
+
residual = packed_sequence
|
| 740 |
+
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
|
| 741 |
+
packed_sequence_[packed_und_token_indexes] = self.mlp(
|
| 742 |
+
self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
|
| 743 |
+
)
|
| 744 |
+
if self.freeze_und:
|
| 745 |
+
packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
|
| 746 |
+
|
| 747 |
+
packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
|
| 748 |
+
self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
|
| 749 |
+
)
|
| 750 |
+
packed_sequence = residual + packed_sequence_
|
| 751 |
+
|
| 752 |
+
return packed_sequence
|
| 753 |
+
|
| 754 |
+
def forward_inference(
|
| 755 |
+
self,
|
| 756 |
+
packed_query_sequence: torch.Tensor,
|
| 757 |
+
query_lens: torch.Tensor,
|
| 758 |
+
packed_query_position_embeddings: torch.Tensor,
|
| 759 |
+
packed_query_indexes: torch.Tensor,
|
| 760 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 761 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 762 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 763 |
+
update_past_key_values=True,
|
| 764 |
+
is_causal=True,
|
| 765 |
+
mode="und",
|
| 766 |
+
packed_vae_token_indexes=None,
|
| 767 |
+
packed_text_indexes=None,
|
| 768 |
+
) -> BaseNavitOutputWithPast:
|
| 769 |
+
|
| 770 |
+
residual = packed_query_sequence
|
| 771 |
+
if mode == "und":
|
| 772 |
+
packed_query_sequence = self.input_layernorm(packed_query_sequence)
|
| 773 |
+
elif mode == "gen":
|
| 774 |
+
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
|
| 775 |
+
packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
|
| 776 |
+
packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
|
| 777 |
+
packed_query_sequence = packed_query_sequence_
|
| 778 |
+
|
| 779 |
+
# Self Attention
|
| 780 |
+
packed_query_sequence, past_key_values = self.self_attn(
|
| 781 |
+
packed_query_sequence=packed_query_sequence,
|
| 782 |
+
query_lens=query_lens,
|
| 783 |
+
packed_query_position_embeddings=packed_query_position_embeddings,
|
| 784 |
+
packed_query_indexes=packed_query_indexes,
|
| 785 |
+
past_key_values=past_key_values,
|
| 786 |
+
key_values_lens=key_values_lens,
|
| 787 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 788 |
+
update_past_key_values=update_past_key_values,
|
| 789 |
+
is_causal=is_causal,
|
| 790 |
+
mode=mode,
|
| 791 |
+
packed_vae_token_indexes=packed_vae_token_indexes,
|
| 792 |
+
packed_text_indexes=packed_text_indexes,
|
| 793 |
+
)
|
| 794 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 795 |
+
|
| 796 |
+
# Fully Connected
|
| 797 |
+
residual = packed_query_sequence
|
| 798 |
+
if mode == "und":
|
| 799 |
+
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
|
| 800 |
+
packed_query_sequence = self.mlp(packed_query_sequence)
|
| 801 |
+
elif mode == "gen":
|
| 802 |
+
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
|
| 803 |
+
packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
|
| 804 |
+
packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
|
| 805 |
+
packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)
|
| 806 |
+
|
| 807 |
+
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
|
| 808 |
+
packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
|
| 809 |
+
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
|
| 810 |
+
packed_query_sequence = packed_query_sequence_
|
| 811 |
+
|
| 812 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 813 |
+
return packed_query_sequence, past_key_values
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
class Qwen2MoEDecoderLayer(nn.Module):
|
| 817 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
| 818 |
+
super().__init__()
|
| 819 |
+
self.hidden_size = config.hidden_size
|
| 820 |
+
|
| 821 |
+
self.self_attn = PackedAttention(config, layer_idx)
|
| 822 |
+
|
| 823 |
+
self.mlp = Qwen2MLP(config)
|
| 824 |
+
self.mlp_moe_gen = Qwen2MLP(config)
|
| 825 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 826 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 827 |
+
|
| 828 |
+
def forward(self, *args, **kwargs):
|
| 829 |
+
if self.training:
|
| 830 |
+
return self.forward_train(*args, **kwargs)
|
| 831 |
+
else:
|
| 832 |
+
return self.forward_inference(*args, **kwargs)
|
| 833 |
+
|
| 834 |
+
def forward_train(
|
| 835 |
+
self,
|
| 836 |
+
packed_sequence: torch.Tensor,
|
| 837 |
+
sample_lens: List[int],
|
| 838 |
+
attention_mask,
|
| 839 |
+
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 840 |
+
packed_und_token_indexes: torch.LongTensor,
|
| 841 |
+
packed_gen_token_indexes: torch.LongTensor,
|
| 842 |
+
) -> torch.Tensor:
|
| 843 |
+
|
| 844 |
+
residual = packed_sequence
|
| 845 |
+
packed_sequence = self.input_layernorm(packed_sequence)
|
| 846 |
+
|
| 847 |
+
# Self Attention
|
| 848 |
+
packed_sequence = self.self_attn(
|
| 849 |
+
packed_sequence=packed_sequence,
|
| 850 |
+
sample_lens=sample_lens,
|
| 851 |
+
attention_mask=attention_mask,
|
| 852 |
+
packed_position_embeddings=packed_position_embeddings,
|
| 853 |
+
)
|
| 854 |
+
packed_sequence = residual + packed_sequence
|
| 855 |
+
|
| 856 |
+
# Fully Connected
|
| 857 |
+
residual = packed_sequence
|
| 858 |
+
packed_sequence = self.post_attention_layernorm(packed_sequence)
|
| 859 |
+
|
| 860 |
+
packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
|
| 861 |
+
packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
|
| 862 |
+
packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
|
| 863 |
+
packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
|
| 864 |
+
packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
|
| 865 |
+
|
| 866 |
+
packed_sequence = residual + packed_sequence_new
|
| 867 |
+
|
| 868 |
+
return packed_sequence
|
| 869 |
+
|
| 870 |
+
def forward_inference(
|
| 871 |
+
self,
|
| 872 |
+
packed_query_sequence: torch.Tensor,
|
| 873 |
+
query_lens: torch.Tensor,
|
| 874 |
+
packed_query_position_embeddings: torch.Tensor,
|
| 875 |
+
packed_query_indexes: torch.Tensor,
|
| 876 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 877 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 878 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 879 |
+
update_past_key_values=True,
|
| 880 |
+
is_causal=True,
|
| 881 |
+
mode="und",
|
| 882 |
+
packed_vae_token_indexes=None,
|
| 883 |
+
packed_text_indexes=None,
|
| 884 |
+
) -> BaseNavitOutputWithPast:
|
| 885 |
+
|
| 886 |
+
residual = packed_query_sequence
|
| 887 |
+
packed_query_sequence = self.input_layernorm(packed_query_sequence)
|
| 888 |
+
|
| 889 |
+
# Self Attention
|
| 890 |
+
packed_query_sequence, past_key_values = self.self_attn(
|
| 891 |
+
packed_query_sequence=packed_query_sequence,
|
| 892 |
+
query_lens=query_lens,
|
| 893 |
+
packed_query_position_embeddings=packed_query_position_embeddings,
|
| 894 |
+
packed_query_indexes=packed_query_indexes,
|
| 895 |
+
past_key_values=past_key_values,
|
| 896 |
+
key_values_lens=key_values_lens,
|
| 897 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 898 |
+
update_past_key_values=update_past_key_values,
|
| 899 |
+
is_causal=is_causal,
|
| 900 |
+
)
|
| 901 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 902 |
+
|
| 903 |
+
# Fully Connected
|
| 904 |
+
residual = packed_query_sequence
|
| 905 |
+
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
|
| 906 |
+
if mode == "und":
|
| 907 |
+
packed_query_sequence = self.mlp(packed_query_sequence)
|
| 908 |
+
elif mode == "gen":
|
| 909 |
+
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
|
| 910 |
+
packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
|
| 911 |
+
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
|
| 912 |
+
packed_query_sequence = packed_query_sequence_
|
| 913 |
+
packed_query_sequence = residual + packed_query_sequence
|
| 914 |
+
|
| 915 |
+
return packed_query_sequence, past_key_values
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
Decoder_layer_dict = {
|
| 919 |
+
"Qwen2DecoderLayer": Qwen2DecoderLayer,
|
| 920 |
+
"Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
|
| 921 |
+
"Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
class Qwen2Model(Qwen2PreTrainedModel):
|
| 926 |
+
def __init__(self, config):
|
| 927 |
+
super().__init__(config)
|
| 928 |
+
self.padding_idx = config.pad_token_id
|
| 929 |
+
self.vocab_size = config.vocab_size
|
| 930 |
+
self.use_moe = 'Mo' in config.layer_module
|
| 931 |
+
|
| 932 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 933 |
+
layer_module = Decoder_layer_dict[config.layer_module]
|
| 934 |
+
self.layers = nn.ModuleList(
|
| 935 |
+
[layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 939 |
+
if self.use_moe:
|
| 940 |
+
self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 941 |
+
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
| 942 |
+
|
| 943 |
+
# Initialize weights and apply final processing
|
| 944 |
+
self.post_init()
|
| 945 |
+
|
| 946 |
+
def forward(self, *args, **kwargs):
|
| 947 |
+
if self.training:
|
| 948 |
+
return self.forward_train(*args, **kwargs)
|
| 949 |
+
else:
|
| 950 |
+
return self.forward_inference(*args, **kwargs)
|
| 951 |
+
|
| 952 |
+
def forward_train(
|
| 953 |
+
self,
|
| 954 |
+
packed_sequence: torch.Tensor,
|
| 955 |
+
sample_lens: List[int],
|
| 956 |
+
attention_mask,
|
| 957 |
+
packed_position_ids: torch.Tensor,
|
| 958 |
+
packed_und_token_indexes: Optional[torch.LongTensor] = None,
|
| 959 |
+
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
|
| 960 |
+
) -> torch.Tensor:
|
| 961 |
+
|
| 962 |
+
if self.config.freeze_und:
|
| 963 |
+
packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()
|
| 964 |
+
|
| 965 |
+
# create position embeddings to be shared across the decoder layers
|
| 966 |
+
cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
|
| 967 |
+
cos = cos.squeeze(0)
|
| 968 |
+
sin = sin.squeeze(0)
|
| 969 |
+
packed_position_embeddings = (cos, sin)
|
| 970 |
+
|
| 971 |
+
extra_inputs = {}
|
| 972 |
+
if self.use_moe:
|
| 973 |
+
assert packed_und_token_indexes is not None
|
| 974 |
+
if packed_gen_token_indexes is None:
|
| 975 |
+
packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
|
| 976 |
+
extra_inputs.update(
|
| 977 |
+
packed_und_token_indexes=packed_und_token_indexes,
|
| 978 |
+
packed_gen_token_indexes=packed_gen_token_indexes,
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
for decoder_layer in self.layers:
|
| 982 |
+
packed_sequence = decoder_layer(
|
| 983 |
+
packed_sequence=packed_sequence,
|
| 984 |
+
sample_lens=sample_lens,
|
| 985 |
+
attention_mask=attention_mask,
|
| 986 |
+
packed_position_embeddings=packed_position_embeddings,
|
| 987 |
+
**extra_inputs
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
if self.use_moe:
|
| 991 |
+
packed_sequence_ = torch.zeros_like(packed_sequence)
|
| 992 |
+
packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
|
| 993 |
+
if self.config.freeze_und:
|
| 994 |
+
packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
|
| 995 |
+
packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
|
| 996 |
+
return packed_sequence_
|
| 997 |
+
else:
|
| 998 |
+
return self.norm(packed_sequence)
|
| 999 |
+
|
| 1000 |
+
def forward_inference(
|
| 1001 |
+
self,
|
| 1002 |
+
packed_query_sequence: torch.Tensor,
|
| 1003 |
+
query_lens: torch.Tensor,
|
| 1004 |
+
packed_query_position_ids: torch.Tensor,
|
| 1005 |
+
packed_query_indexes: torch.Tensor,
|
| 1006 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 1007 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 1008 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 1009 |
+
update_past_key_values=True,
|
| 1010 |
+
is_causal=True,
|
| 1011 |
+
mode="und",
|
| 1012 |
+
packed_vae_token_indexes=None,
|
| 1013 |
+
packed_text_indexes=None,
|
| 1014 |
+
) -> BaseNavitOutputWithPast:
|
| 1015 |
+
|
| 1016 |
+
# create position embeddings to be shared across the decoder layers
|
| 1017 |
+
cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))
|
| 1018 |
+
cos = cos.squeeze(0)
|
| 1019 |
+
sin = sin.squeeze(0)
|
| 1020 |
+
packed_query_position_embeddings = (cos, sin)
|
| 1021 |
+
|
| 1022 |
+
extra_inputs = {}
|
| 1023 |
+
if self.use_moe:
|
| 1024 |
+
extra_inputs.update(mode=mode)
|
| 1025 |
+
if mode == 'gen':
|
| 1026 |
+
assert packed_vae_token_indexes is not None
|
| 1027 |
+
assert packed_text_indexes is not None
|
| 1028 |
+
extra_inputs.update(
|
| 1029 |
+
packed_vae_token_indexes=packed_vae_token_indexes,
|
| 1030 |
+
packed_text_indexes=packed_text_indexes,
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
for decoder_layer in self.layers:
|
| 1034 |
+
packed_query_sequence, past_key_values = decoder_layer(
|
| 1035 |
+
packed_query_sequence=packed_query_sequence,
|
| 1036 |
+
query_lens=query_lens,
|
| 1037 |
+
packed_query_position_embeddings=packed_query_position_embeddings,
|
| 1038 |
+
packed_query_indexes=packed_query_indexes,
|
| 1039 |
+
past_key_values=past_key_values,
|
| 1040 |
+
key_values_lens=key_values_lens,
|
| 1041 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 1042 |
+
update_past_key_values=update_past_key_values,
|
| 1043 |
+
is_causal=is_causal,
|
| 1044 |
+
**extra_inputs,
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
if self.use_moe:
|
| 1048 |
+
if mode == "und":
|
| 1049 |
+
packed_query_sequence = self.norm(packed_query_sequence)
|
| 1050 |
+
elif mode == "gen":
|
| 1051 |
+
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
|
| 1052 |
+
packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
|
| 1053 |
+
packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
|
| 1054 |
+
packed_query_sequence = packed_query_sequence_
|
| 1055 |
+
else:
|
| 1056 |
+
packed_query_sequence = self.norm(packed_query_sequence)
|
| 1057 |
+
|
| 1058 |
+
return BaseNavitOutputWithPast(
|
| 1059 |
+
packed_query_sequence=packed_query_sequence,
|
| 1060 |
+
past_key_values=past_key_values,
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
|
| 1065 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1066 |
+
|
| 1067 |
+
def __init__(self, config):
|
| 1068 |
+
super().__init__(config)
|
| 1069 |
+
self.model = Qwen2Model(config)
|
| 1070 |
+
self.vocab_size = config.vocab_size
|
| 1071 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1072 |
+
|
| 1073 |
+
# Initialize weights and apply final processing
|
| 1074 |
+
self.post_init()
|
| 1075 |
+
|
| 1076 |
+
def init_moe(self):
|
| 1077 |
+
for name, param in self.named_parameters():
|
| 1078 |
+
if "moe_gen" in name:
|
| 1079 |
+
original_name = name.replace("_moe_gen", "")
|
| 1080 |
+
param.data.copy_(self.state_dict()[original_name].data)
|
| 1081 |
+
|
| 1082 |
+
def get_input_embeddings(self):
|
| 1083 |
+
return self.model.embed_tokens
|
| 1084 |
+
|
| 1085 |
+
def set_input_embeddings(self, value):
|
| 1086 |
+
self.model.embed_tokens = value
|
| 1087 |
+
|
| 1088 |
+
def get_output_embeddings(self):
|
| 1089 |
+
return self.lm_head
|
| 1090 |
+
|
| 1091 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1092 |
+
self.lm_head = new_embeddings
|
| 1093 |
+
|
| 1094 |
+
def set_decoder(self, decoder):
|
| 1095 |
+
self.model = decoder
|
| 1096 |
+
|
| 1097 |
+
def get_decoder(self):
|
| 1098 |
+
return self.model
|
| 1099 |
+
|
| 1100 |
+
def forward(self, *args, **kwargs):
|
| 1101 |
+
if self.training:
|
| 1102 |
+
return self.forward_train(*args, **kwargs)
|
| 1103 |
+
else:
|
| 1104 |
+
return self.forward_inference(*args, **kwargs)
|
| 1105 |
+
|
| 1106 |
+
def forward_train(
|
| 1107 |
+
self,
|
| 1108 |
+
packed_sequence: torch.Tensor,
|
| 1109 |
+
sample_lens: List[int],
|
| 1110 |
+
attention_mask,
|
| 1111 |
+
packed_position_ids: torch.Tensor,
|
| 1112 |
+
packed_und_token_indexes: Optional[torch.LongTensor] = None,
|
| 1113 |
+
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
|
| 1114 |
+
) -> torch.Tensor:
|
| 1115 |
+
|
| 1116 |
+
outputs = self.model(
|
| 1117 |
+
packed_sequence=packed_sequence,
|
| 1118 |
+
sample_lens=sample_lens,
|
| 1119 |
+
packed_position_ids=packed_position_ids,
|
| 1120 |
+
attention_mask=attention_mask,
|
| 1121 |
+
packed_und_token_indexes=packed_und_token_indexes,
|
| 1122 |
+
packed_gen_token_indexes=packed_gen_token_indexes,
|
| 1123 |
+
)
|
| 1124 |
+
return outputs
|
| 1125 |
+
|
| 1126 |
+
def forward_inference(
|
| 1127 |
+
self,
|
| 1128 |
+
packed_query_sequence: torch.Tensor,
|
| 1129 |
+
query_lens: torch.Tensor,
|
| 1130 |
+
packed_query_position_ids: torch.Tensor,
|
| 1131 |
+
packed_query_indexes: torch.Tensor,
|
| 1132 |
+
past_key_values: Optional[NaiveCache] = None,
|
| 1133 |
+
key_values_lens: Optional[torch.Tensor] = None,
|
| 1134 |
+
packed_key_value_indexes: Optional[torch.Tensor] = None,
|
| 1135 |
+
update_past_key_values=True,
|
| 1136 |
+
is_causal=True,
|
| 1137 |
+
mode="und",
|
| 1138 |
+
packed_vae_token_indexes=None,
|
| 1139 |
+
packed_text_indexes=None,
|
| 1140 |
+
) -> BaseNavitOutputWithPast:
|
| 1141 |
+
|
| 1142 |
+
outputs = self.model(
|
| 1143 |
+
packed_query_sequence=packed_query_sequence,
|
| 1144 |
+
query_lens=query_lens,
|
| 1145 |
+
packed_query_position_ids=packed_query_position_ids,
|
| 1146 |
+
packed_query_indexes=packed_query_indexes,
|
| 1147 |
+
past_key_values=past_key_values,
|
| 1148 |
+
key_values_lens=key_values_lens,
|
| 1149 |
+
packed_key_value_indexes=packed_key_value_indexes,
|
| 1150 |
+
update_past_key_values=update_past_key_values,
|
| 1151 |
+
is_causal=is_causal,
|
| 1152 |
+
mode=mode,
|
| 1153 |
+
packed_vae_token_indexes=packed_vae_token_indexes,
|
| 1154 |
+
packed_text_indexes=packed_text_indexes,
|
| 1155 |
+
)
|
| 1156 |
+
|
| 1157 |
+
return outputs
|
modeling/bagel/siglip_navit.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 The HuggingFace Inc. team.
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under Apache-2.0, with the full license text
|
| 8 |
+
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
from transformers.activations import ACT2FN
|
| 16 |
+
from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
|
| 17 |
+
from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
|
| 18 |
+
from flash_attn import flash_attn_varlen_func
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SiglipVisionConfig(_SiglipVisionConfig):
|
| 22 |
+
r"""
|
| 23 |
+
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
|
| 24 |
+
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 25 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
|
| 26 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 27 |
+
|
| 28 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 29 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 34 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 35 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 36 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 37 |
+
Number of hidden layers in the Transformer encoder.
|
| 38 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 39 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 40 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 41 |
+
Number of channels in the input images.
|
| 42 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 43 |
+
The size (resolution) of each image.
|
| 44 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 45 |
+
The size (resolution) of each patch.
|
| 46 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 47 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 48 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 49 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 50 |
+
The epsilon used by the layer normalization layers.
|
| 51 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 52 |
+
The dropout ratio for the attention probabilities.
|
| 53 |
+
|
| 54 |
+
Example:
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
|
| 58 |
+
|
| 59 |
+
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
|
| 60 |
+
>>> configuration = SiglipVisionConfig()
|
| 61 |
+
|
| 62 |
+
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 63 |
+
>>> model = SiglipVisionModel(configuration)
|
| 64 |
+
|
| 65 |
+
>>> # Accessing the model configuration
|
| 66 |
+
>>> configuration = model.config
|
| 67 |
+
```"""
|
| 68 |
+
|
| 69 |
+
model_type = "siglip_vision_model"
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
hidden_size=768,
|
| 74 |
+
intermediate_size=3072,
|
| 75 |
+
num_hidden_layers=12,
|
| 76 |
+
num_attention_heads=12,
|
| 77 |
+
num_channels=3,
|
| 78 |
+
image_size=224,
|
| 79 |
+
patch_size=16,
|
| 80 |
+
hidden_act="gelu_pytorch_tanh",
|
| 81 |
+
layer_norm_eps=1e-6,
|
| 82 |
+
attention_dropout=0.0,
|
| 83 |
+
rope=True,
|
| 84 |
+
**kwargs,
|
| 85 |
+
):
|
| 86 |
+
super().__init__(
|
| 87 |
+
hidden_size=hidden_size,
|
| 88 |
+
intermediate_size=intermediate_size,
|
| 89 |
+
num_hidden_layers=num_hidden_layers,
|
| 90 |
+
num_attention_heads=num_attention_heads,
|
| 91 |
+
num_channels=num_channels,
|
| 92 |
+
image_size=image_size,
|
| 93 |
+
patch_size=patch_size,
|
| 94 |
+
hidden_act=hidden_act,
|
| 95 |
+
layer_norm_eps=layer_norm_eps,
|
| 96 |
+
attention_dropout=attention_dropout,
|
| 97 |
+
**kwargs)
|
| 98 |
+
|
| 99 |
+
self.rope = rope
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class RotaryEmbedding2D(torch.nn.Module):
|
| 103 |
+
def __init__(self, dim, max_h, max_w, base=10000):
|
| 104 |
+
super().__init__()
|
| 105 |
+
freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
|
| 106 |
+
inv_freq = 1.0 / (base ** freq)
|
| 107 |
+
|
| 108 |
+
grid_h = torch.arange(0, max_h)
|
| 109 |
+
grid_h = grid_h.to(inv_freq.dtype)
|
| 110 |
+
grid_h = grid_h[:, None].repeat(1, max_w)
|
| 111 |
+
|
| 112 |
+
grid_w = torch.arange(0, max_w)
|
| 113 |
+
grid_w = grid_w.to(inv_freq.dtype)
|
| 114 |
+
grid_w = grid_w[None, :].repeat(max_h, 1)
|
| 115 |
+
|
| 116 |
+
cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
|
| 117 |
+
cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
|
| 118 |
+
|
| 119 |
+
self.register_buffer("cos_h", cos_h)
|
| 120 |
+
self.register_buffer("sin_h", sin_h)
|
| 121 |
+
self.register_buffer("cos_w", cos_w)
|
| 122 |
+
self.register_buffer("sin_w", sin_w)
|
| 123 |
+
|
| 124 |
+
def _forward_one_side(self, grid, inv_freq):
|
| 125 |
+
freqs = grid[..., None] * inv_freq[None, None, :]
|
| 126 |
+
emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
|
| 127 |
+
return emb.cos(), emb.sin()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def rotate_half(x):
|
| 131 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 132 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 133 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def apply_rotary_pos_emb(q, k, cos, sin):
|
| 137 |
+
# unsqueeze due to the head dimension
|
| 138 |
+
cos = cos.unsqueeze(1)
|
| 139 |
+
sin = sin.unsqueeze(1)
|
| 140 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 141 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 142 |
+
return q_embed, k_embed
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class SiglipVisionEmbeddings(nn.Module):
|
| 146 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.config = config
|
| 149 |
+
self.embed_dim = config.hidden_size
|
| 150 |
+
self.image_size = config.image_size
|
| 151 |
+
self.patch_size = config.patch_size
|
| 152 |
+
|
| 153 |
+
self.patch_embedding = nn.Conv2d(
|
| 154 |
+
in_channels=config.num_channels,
|
| 155 |
+
out_channels=self.embed_dim,
|
| 156 |
+
kernel_size=self.patch_size,
|
| 157 |
+
stride=self.patch_size,
|
| 158 |
+
padding="valid",
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
self.num_patches_per_side = self.image_size // self.patch_size
|
| 162 |
+
self.num_patches = self.num_patches_per_side**2
|
| 163 |
+
self.num_positions = self.num_patches
|
| 164 |
+
if not config.rope:
|
| 165 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
| 166 |
+
|
| 167 |
+
def convert_conv2d_to_linear(self, config, meta=False):
|
| 168 |
+
if meta:
|
| 169 |
+
linear_patch_embedding = nn.Linear(
|
| 170 |
+
config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
linear_patch_embedding = nn.Linear(
|
| 174 |
+
config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
|
| 175 |
+
)
|
| 176 |
+
W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
|
| 177 |
+
self.embed_dim, config.num_channels * self.patch_size ** 2
|
| 178 |
+
)
|
| 179 |
+
linear_patch_embedding.weight.data = W
|
| 180 |
+
linear_patch_embedding.bias.data = self.patch_embedding.bias.data
|
| 181 |
+
del self.patch_embedding
|
| 182 |
+
self.patch_embedding = linear_patch_embedding
|
| 183 |
+
|
| 184 |
+
def forward(
|
| 185 |
+
self,
|
| 186 |
+
packed_pixel_values: torch.FloatTensor,
|
| 187 |
+
packed_flattened_position_ids: torch.LongTensor
|
| 188 |
+
) -> torch.Tensor:
|
| 189 |
+
|
| 190 |
+
patch_embeds = self.patch_embedding(packed_pixel_values)
|
| 191 |
+
if not self.config.rope:
|
| 192 |
+
embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
|
| 193 |
+
else:
|
| 194 |
+
embeddings = patch_embeds
|
| 195 |
+
return embeddings
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class SiglipFlashAttention2(SiglipAttention):
|
| 199 |
+
def __init__(self, *args, **kwargs):
|
| 200 |
+
super().__init__(*args, **kwargs)
|
| 201 |
+
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
hidden_states: torch.Tensor,
|
| 205 |
+
cu_seqlens: torch.IntTensor,
|
| 206 |
+
max_seqlen: int,
|
| 207 |
+
cos_h: torch.Tensor = None,
|
| 208 |
+
sin_h: torch.Tensor = None,
|
| 209 |
+
cos_w: torch.Tensor = None,
|
| 210 |
+
sin_w: torch.Tensor = None,
|
| 211 |
+
**kwargs,
|
| 212 |
+
) -> torch.Tensor:
|
| 213 |
+
|
| 214 |
+
total_q_len, _ = hidden_states.size()
|
| 215 |
+
|
| 216 |
+
query_states = self.q_proj(hidden_states)
|
| 217 |
+
key_states = self.k_proj(hidden_states)
|
| 218 |
+
value_states = self.v_proj(hidden_states)
|
| 219 |
+
|
| 220 |
+
query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
|
| 221 |
+
key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
|
| 222 |
+
value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
|
| 223 |
+
|
| 224 |
+
if self.config.rope:
|
| 225 |
+
qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
|
| 226 |
+
kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
|
| 227 |
+
qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
|
| 228 |
+
qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
|
| 229 |
+
query_states = torch.cat([qh, qw], dim=-1)
|
| 230 |
+
key_states = torch.cat([kh, kw], dim=-1)
|
| 231 |
+
|
| 232 |
+
attn_output = flash_attn_varlen_func(
|
| 233 |
+
query_states.to(torch.bfloat16),
|
| 234 |
+
key_states.to(torch.bfloat16),
|
| 235 |
+
value_states.to(torch.bfloat16),
|
| 236 |
+
cu_seqlens_q=cu_seqlens,
|
| 237 |
+
cu_seqlens_k=cu_seqlens,
|
| 238 |
+
max_seqlen_q=max_seqlen,
|
| 239 |
+
max_seqlen_k=max_seqlen,
|
| 240 |
+
causal=False,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
|
| 244 |
+
return attn_output
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class SiglipMLP(nn.Module):
|
| 248 |
+
def __init__(self, config):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.config = config
|
| 251 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 252 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 253 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 254 |
+
|
| 255 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 256 |
+
hidden_states = self.fc1(hidden_states)
|
| 257 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 258 |
+
hidden_states = self.fc2(hidden_states)
|
| 259 |
+
return hidden_states
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class SiglipEncoderLayer(nn.Module):
|
| 263 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.embed_dim = config.hidden_size
|
| 266 |
+
self.self_attn = SiglipFlashAttention2(config)
|
| 267 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 268 |
+
self.mlp = SiglipMLP(config)
|
| 269 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 270 |
+
|
| 271 |
+
def forward(
|
| 272 |
+
self,
|
| 273 |
+
hidden_states: torch.Tensor,
|
| 274 |
+
cu_seqlens: torch.IntTensor,
|
| 275 |
+
max_seqlen: int,
|
| 276 |
+
cos_h: torch.Tensor = None,
|
| 277 |
+
sin_h: torch.Tensor = None,
|
| 278 |
+
cos_w: torch.Tensor = None,
|
| 279 |
+
sin_w: torch.Tensor = None
|
| 280 |
+
) -> torch.Tensor:
|
| 281 |
+
residual = hidden_states
|
| 282 |
+
|
| 283 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 284 |
+
hidden_states = self.self_attn(
|
| 285 |
+
hidden_states=hidden_states,
|
| 286 |
+
cu_seqlens=cu_seqlens,
|
| 287 |
+
max_seqlen=max_seqlen,
|
| 288 |
+
cos_h=cos_h,
|
| 289 |
+
sin_h=sin_h,
|
| 290 |
+
cos_w=cos_w,
|
| 291 |
+
sin_w=sin_w
|
| 292 |
+
)
|
| 293 |
+
hidden_states = residual + hidden_states
|
| 294 |
+
|
| 295 |
+
residual = hidden_states
|
| 296 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 297 |
+
hidden_states = self.mlp(hidden_states)
|
| 298 |
+
hidden_states = residual + hidden_states
|
| 299 |
+
|
| 300 |
+
return hidden_states
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class SiglipEncoder(nn.Module):
|
| 304 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 305 |
+
super().__init__()
|
| 306 |
+
self.config = config
|
| 307 |
+
self.layers = nn.ModuleList(
|
| 308 |
+
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
def forward(
|
| 312 |
+
self,
|
| 313 |
+
inputs_embeds: torch.Tensor,
|
| 314 |
+
cu_seqlens: torch.IntTensor,
|
| 315 |
+
max_seqlen: int,
|
| 316 |
+
cos_h: torch.Tensor = None,
|
| 317 |
+
sin_h: torch.Tensor = None,
|
| 318 |
+
cos_w: torch.Tensor = None,
|
| 319 |
+
sin_w: torch.Tensor = None,
|
| 320 |
+
) -> torch.Tensor:
|
| 321 |
+
|
| 322 |
+
hidden_states = inputs_embeds
|
| 323 |
+
for encoder_layer in self.layers:
|
| 324 |
+
hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
|
| 325 |
+
cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
|
| 326 |
+
|
| 327 |
+
return hidden_states
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class SiglipVisionTransformer(nn.Module):
|
| 331 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 332 |
+
super().__init__()
|
| 333 |
+
self.config = config
|
| 334 |
+
embed_dim = config.hidden_size
|
| 335 |
+
|
| 336 |
+
self.embeddings = SiglipVisionEmbeddings(config)
|
| 337 |
+
if config.rope:
|
| 338 |
+
max_size = config.image_size // config.patch_size
|
| 339 |
+
dim_head = config.hidden_size // config.num_attention_heads
|
| 340 |
+
self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
|
| 341 |
+
|
| 342 |
+
self.encoder = SiglipEncoder(config)
|
| 343 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 344 |
+
|
| 345 |
+
def forward(
|
| 346 |
+
self,
|
| 347 |
+
packed_pixel_values: torch.Tensor,
|
| 348 |
+
packed_flattened_position_ids: torch.LongTensor,
|
| 349 |
+
cu_seqlens: torch.IntTensor,
|
| 350 |
+
max_seqlen: int,
|
| 351 |
+
) -> torch.Tensor:
|
| 352 |
+
hidden_states = self.embeddings(
|
| 353 |
+
packed_pixel_values=packed_pixel_values,
|
| 354 |
+
packed_flattened_position_ids=packed_flattened_position_ids
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
extra_inputs = {}
|
| 358 |
+
if self.config.rope:
|
| 359 |
+
extra_inputs.update(
|
| 360 |
+
cos_h = self.rope.cos_h[packed_flattened_position_ids],
|
| 361 |
+
sin_h = self.rope.sin_h[packed_flattened_position_ids],
|
| 362 |
+
cos_w = self.rope.cos_w[packed_flattened_position_ids],
|
| 363 |
+
sin_w = self.rope.sin_w[packed_flattened_position_ids]
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
last_hidden_state = self.encoder(
|
| 367 |
+
inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
|
| 368 |
+
**extra_inputs
|
| 369 |
+
)
|
| 370 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
| 371 |
+
return last_hidden_state
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class SiglipVisionModel(SiglipPreTrainedModel):
|
| 375 |
+
config_class = SiglipVisionConfig
|
| 376 |
+
main_input_name = "packed_pixel_values"
|
| 377 |
+
|
| 378 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 379 |
+
super().__init__(config)
|
| 380 |
+
|
| 381 |
+
self.vision_model = SiglipVisionTransformer(config)
|
| 382 |
+
|
| 383 |
+
# Initialize weights and apply final processing
|
| 384 |
+
self.post_init()
|
| 385 |
+
|
| 386 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 387 |
+
return self.vision_model.embeddings.patch_embedding
|
| 388 |
+
|
| 389 |
+
def forward(
|
| 390 |
+
self,
|
| 391 |
+
packed_pixel_values: torch.Tensor,
|
| 392 |
+
packed_flattened_position_ids: torch.LongTensor,
|
| 393 |
+
cu_seqlens: torch.IntTensor,
|
| 394 |
+
max_seqlen: int,
|
| 395 |
+
) -> torch.Tensor:
|
| 396 |
+
|
| 397 |
+
return self.vision_model(
|
| 398 |
+
packed_pixel_values=packed_pixel_values,
|
| 399 |
+
packed_flattened_position_ids=packed_flattened_position_ids,
|
| 400 |
+
cu_seqlens=cu_seqlens,
|
| 401 |
+
max_seqlen=max_seqlen,
|
| 402 |
+
)
|
modeling/qwen2/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from transformers.utils import (
|
| 7 |
+
OptionalDependencyNotAvailable,
|
| 8 |
+
_LazyModule,
|
| 9 |
+
is_tokenizers_available,
|
| 10 |
+
is_torch_available,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_import_structure = {
|
| 15 |
+
"configuration_qwen2": ["Qwen2Config"],
|
| 16 |
+
"tokenization_qwen2": ["Qwen2Tokenizer"],
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
if not is_tokenizers_available():
|
| 21 |
+
raise OptionalDependencyNotAvailable()
|
| 22 |
+
except OptionalDependencyNotAvailable:
|
| 23 |
+
pass
|
| 24 |
+
else:
|
| 25 |
+
_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
if not is_torch_available():
|
| 29 |
+
raise OptionalDependencyNotAvailable()
|
| 30 |
+
except OptionalDependencyNotAvailable:
|
| 31 |
+
pass
|
| 32 |
+
else:
|
| 33 |
+
_import_structure["modeling_qwen2"] = [
|
| 34 |
+
"Qwen2ForCausalLM",
|
| 35 |
+
"Qwen2Model",
|
| 36 |
+
"Qwen2PreTrainedModel",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if TYPE_CHECKING:
|
| 41 |
+
from .configuration_qwen2 import Qwen2Config
|
| 42 |
+
from .tokenization_qwen2 import Qwen2Tokenizer
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
if not is_tokenizers_available():
|
| 46 |
+
raise OptionalDependencyNotAvailable()
|
| 47 |
+
except OptionalDependencyNotAvailable:
|
| 48 |
+
pass
|
| 49 |
+
else:
|
| 50 |
+
from .tokenization_qwen2_fast import Qwen2TokenizerFast
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
if not is_torch_available():
|
| 54 |
+
raise OptionalDependencyNotAvailable()
|
| 55 |
+
except OptionalDependencyNotAvailable:
|
| 56 |
+
pass
|
| 57 |
+
else:
|
| 58 |
+
from .modeling_qwen2 import (
|
| 59 |
+
Qwen2ForCausalLM,
|
| 60 |
+
Qwen2Model,
|
| 61 |
+
Qwen2PreTrainedModel,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
else:
|
| 66 |
+
import sys
|
| 67 |
+
|
| 68 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
modeling/qwen2/configuration_qwen2.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Qwen2 model configuration"""
|
| 5 |
+
|
| 6 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 8 |
+
from transformers.utils import logging
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Qwen2Config(PretrainedConfig):
|
| 15 |
+
r"""
|
| 16 |
+
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
|
| 17 |
+
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 18 |
+
with the defaults will yield a similar configuration to that of
|
| 19 |
+
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
|
| 20 |
+
|
| 21 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 22 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
vocab_size (`int`, *optional*, defaults to 151936):
|
| 27 |
+
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
|
| 28 |
+
`inputs_ids` passed when calling [`Qwen2Model`]
|
| 29 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 30 |
+
Dimension of the hidden representations.
|
| 31 |
+
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 32 |
+
Dimension of the MLP representations.
|
| 33 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 34 |
+
Number of hidden layers in the Transformer encoder.
|
| 35 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 36 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 37 |
+
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 38 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 39 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 40 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 41 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 42 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 43 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 45 |
+
The non-linear activation function (function or string) in the decoder.
|
| 46 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 47 |
+
The maximum sequence length that this model might ever be used with.
|
| 48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 50 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 51 |
+
The epsilon used by the rms normalization layers.
|
| 52 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 53 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 54 |
+
relevant if `config.is_decoder=True`.
|
| 55 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 56 |
+
Whether the model's input and output word embeddings should be tied.
|
| 57 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 58 |
+
The base period of the RoPE embeddings.
|
| 59 |
+
rope_scaling (`Dict`, *optional*):
|
| 60 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 61 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 62 |
+
accordingly.
|
| 63 |
+
Expected contents:
|
| 64 |
+
`rope_type` (`str`):
|
| 65 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 66 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 67 |
+
`factor` (`float`, *optional*):
|
| 68 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 69 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 70 |
+
original maximum pre-trained length.
|
| 71 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 72 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 73 |
+
pretraining.
|
| 74 |
+
`attention_factor` (`float`, *optional*):
|
| 75 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 76 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 77 |
+
`factor` field to infer the suggested value.
|
| 78 |
+
`beta_fast` (`float`, *optional*):
|
| 79 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 80 |
+
ramp function. If unspecified, it defaults to 32.
|
| 81 |
+
`beta_slow` (`float`, *optional*):
|
| 82 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 83 |
+
ramp function. If unspecified, it defaults to 1.
|
| 84 |
+
`short_factor` (`List[float]`, *optional*):
|
| 85 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 86 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 87 |
+
size divided by the number of attention heads divided by 2
|
| 88 |
+
`long_factor` (`List[float]`, *optional*):
|
| 89 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 90 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 91 |
+
size divided by the number of attention heads divided by 2
|
| 92 |
+
`low_freq_factor` (`float`, *optional*):
|
| 93 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 94 |
+
`high_freq_factor` (`float`, *optional*):
|
| 95 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 96 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 97 |
+
Whether to use sliding window attention.
|
| 98 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 99 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 100 |
+
max_window_layers (`int`, *optional*, defaults to 28):
|
| 101 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
| 102 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 103 |
+
The dropout ratio for the attention probabilities.
|
| 104 |
+
|
| 105 |
+
```python
|
| 106 |
+
>>> from transformers import Qwen2Model, Qwen2Config
|
| 107 |
+
|
| 108 |
+
>>> # Initializing a Qwen2 style configuration
|
| 109 |
+
>>> configuration = Qwen2Config()
|
| 110 |
+
|
| 111 |
+
>>> # Initializing a model from the Qwen2-7B style configuration
|
| 112 |
+
>>> model = Qwen2Model(configuration)
|
| 113 |
+
|
| 114 |
+
>>> # Accessing the model configuration
|
| 115 |
+
>>> configuration = model.config
|
| 116 |
+
```"""
|
| 117 |
+
|
| 118 |
+
model_type = "qwen2"
|
| 119 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 120 |
+
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
vocab_size=151936,
|
| 124 |
+
hidden_size=4096,
|
| 125 |
+
intermediate_size=22016,
|
| 126 |
+
num_hidden_layers=32,
|
| 127 |
+
num_attention_heads=32,
|
| 128 |
+
num_key_value_heads=32,
|
| 129 |
+
hidden_act="silu",
|
| 130 |
+
max_position_embeddings=32768,
|
| 131 |
+
initializer_range=0.02,
|
| 132 |
+
rms_norm_eps=1e-6,
|
| 133 |
+
use_cache=True,
|
| 134 |
+
tie_word_embeddings=False,
|
| 135 |
+
rope_theta=10000.0,
|
| 136 |
+
rope_scaling=None,
|
| 137 |
+
use_sliding_window=False,
|
| 138 |
+
sliding_window=4096,
|
| 139 |
+
max_window_layers=28,
|
| 140 |
+
attention_dropout=0.0,
|
| 141 |
+
is_causal=True,
|
| 142 |
+
_attn_implementation="flash_attention_2",
|
| 143 |
+
**kwargs,
|
| 144 |
+
):
|
| 145 |
+
self.vocab_size = vocab_size
|
| 146 |
+
self.max_position_embeddings = max_position_embeddings
|
| 147 |
+
self.hidden_size = hidden_size
|
| 148 |
+
self.intermediate_size = intermediate_size
|
| 149 |
+
self.num_hidden_layers = num_hidden_layers
|
| 150 |
+
self.num_attention_heads = num_attention_heads
|
| 151 |
+
self.use_sliding_window = use_sliding_window
|
| 152 |
+
self.sliding_window = sliding_window if use_sliding_window else None
|
| 153 |
+
self.max_window_layers = max_window_layers
|
| 154 |
+
|
| 155 |
+
# for backward compatibility
|
| 156 |
+
if num_key_value_heads is None:
|
| 157 |
+
num_key_value_heads = num_attention_heads
|
| 158 |
+
|
| 159 |
+
self.num_key_value_heads = num_key_value_heads
|
| 160 |
+
self.hidden_act = hidden_act
|
| 161 |
+
self.initializer_range = initializer_range
|
| 162 |
+
self.rms_norm_eps = rms_norm_eps
|
| 163 |
+
self.use_cache = use_cache
|
| 164 |
+
self.rope_theta = rope_theta
|
| 165 |
+
self.rope_scaling = rope_scaling
|
| 166 |
+
self.attention_dropout = attention_dropout
|
| 167 |
+
self.is_causal = is_causal
|
| 168 |
+
self._attn_implementation = _attn_implementation
|
| 169 |
+
|
| 170 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 171 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 172 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 173 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 174 |
+
rope_config_validation(self)
|
| 175 |
+
|
| 176 |
+
super().__init__(
|
| 177 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 178 |
+
**kwargs,
|
| 179 |
+
)
|
modeling/qwen2/modeling_qwen2.py
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""PyTorch Qwen2 model."""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils.checkpoint
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from transformers.activations import ACT2FN
|
| 14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 15 |
+
from transformers.generation import GenerationMixin
|
| 16 |
+
from transformers.modeling_outputs import (
|
| 17 |
+
BaseModelOutputWithPast,
|
| 18 |
+
CausalLMOutputWithPast,
|
| 19 |
+
)
|
| 20 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 21 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 22 |
+
from transformers.utils import (
|
| 23 |
+
add_start_docstrings,
|
| 24 |
+
add_start_docstrings_to_model_forward,
|
| 25 |
+
is_flash_attn_2_available,
|
| 26 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 27 |
+
logging,
|
| 28 |
+
replace_return_docstrings,
|
| 29 |
+
)
|
| 30 |
+
from .configuration_qwen2 import Qwen2Config
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if is_flash_attn_2_available():
|
| 34 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
|
| 41 |
+
_CONFIG_FOR_DOC = "Qwen2Config"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
|
| 45 |
+
class Qwen2RMSNorm(nn.Module):
|
| 46 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 47 |
+
"""
|
| 48 |
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
| 49 |
+
"""
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 52 |
+
self.variance_epsilon = eps
|
| 53 |
+
|
| 54 |
+
def forward(self, hidden_states):
|
| 55 |
+
input_dtype = hidden_states.dtype
|
| 56 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 57 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 58 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 59 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 60 |
+
|
| 61 |
+
def extra_repr(self):
|
| 62 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
|
| 66 |
+
class Qwen2RotaryEmbedding(nn.Module):
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
dim=None,
|
| 70 |
+
max_position_embeddings=2048,
|
| 71 |
+
base=10000,
|
| 72 |
+
device=None,
|
| 73 |
+
scaling_factor=1.0,
|
| 74 |
+
rope_type="default",
|
| 75 |
+
config: Optional[Qwen2Config] = None,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 79 |
+
self.rope_kwargs = {}
|
| 80 |
+
if config is None:
|
| 81 |
+
logger.warning_once(
|
| 82 |
+
"`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 83 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 84 |
+
)
|
| 85 |
+
self.rope_kwargs = {
|
| 86 |
+
"rope_type": rope_type,
|
| 87 |
+
"factor": scaling_factor,
|
| 88 |
+
"dim": dim,
|
| 89 |
+
"base": base,
|
| 90 |
+
"max_position_embeddings": max_position_embeddings,
|
| 91 |
+
}
|
| 92 |
+
self.rope_type = rope_type
|
| 93 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 94 |
+
self.original_max_seq_len = max_position_embeddings
|
| 95 |
+
else:
|
| 96 |
+
# BC: "rope_type" was originally "type"
|
| 97 |
+
if config.rope_scaling is not None:
|
| 98 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 99 |
+
else:
|
| 100 |
+
self.rope_type = "default"
|
| 101 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 102 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 103 |
+
|
| 104 |
+
self.config = config
|
| 105 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 106 |
+
|
| 107 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 108 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 109 |
+
self.original_inv_freq = self.inv_freq
|
| 110 |
+
|
| 111 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 112 |
+
"""
|
| 113 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 114 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 115 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 116 |
+
"""
|
| 117 |
+
seq_len = torch.max(position_ids) + 1
|
| 118 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 119 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 120 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 121 |
+
)
|
| 122 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 123 |
+
self.max_seq_len_cached = seq_len
|
| 124 |
+
|
| 125 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 126 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 127 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 128 |
+
|
| 129 |
+
@torch.no_grad()
|
| 130 |
+
def forward(self, x, position_ids):
|
| 131 |
+
if "dynamic" in self.rope_type:
|
| 132 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 133 |
+
|
| 134 |
+
# Core RoPE block
|
| 135 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 136 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 137 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 138 |
+
device_type = x.device.type
|
| 139 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 140 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 141 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 142 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 143 |
+
cos = emb.cos()
|
| 144 |
+
sin = emb.sin()
|
| 145 |
+
|
| 146 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 147 |
+
cos = cos * self.attention_scaling
|
| 148 |
+
sin = sin * self.attention_scaling
|
| 149 |
+
|
| 150 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 154 |
+
def rotate_half(x):
|
| 155 |
+
"""Rotates half the hidden dims of the input."""
|
| 156 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 157 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 158 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 162 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 163 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
q (`torch.Tensor`): The query tensor.
|
| 167 |
+
k (`torch.Tensor`): The key tensor.
|
| 168 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 169 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 170 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 171 |
+
Deprecated and unused.
|
| 172 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 173 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 174 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 175 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 176 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 177 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 178 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 179 |
+
Returns:
|
| 180 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 181 |
+
"""
|
| 182 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 183 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 184 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 185 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 186 |
+
return q_embed, k_embed
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
|
| 190 |
+
class Qwen2MLP(nn.Module):
|
| 191 |
+
def __init__(self, config):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.hidden_size = config.hidden_size
|
| 194 |
+
self.intermediate_size = config.intermediate_size
|
| 195 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 196 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 197 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 198 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 199 |
+
|
| 200 |
+
def forward(self, hidden_state):
|
| 201 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 205 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 206 |
+
"""
|
| 207 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 208 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 209 |
+
"""
|
| 210 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 211 |
+
if n_rep == 1:
|
| 212 |
+
return hidden_states
|
| 213 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 214 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class Qwen2Attention(nn.Module):
|
| 218 |
+
"""
|
| 219 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 220 |
+
and "Generating Long Sequences with Sparse Transformers".
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.config = config
|
| 226 |
+
self.layer_idx = layer_idx
|
| 227 |
+
if layer_idx is None:
|
| 228 |
+
logger.warning_once(
|
| 229 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
| 230 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
| 231 |
+
"when creating this class."
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
self.hidden_size = config.hidden_size
|
| 235 |
+
self.num_heads = config.num_attention_heads
|
| 236 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 237 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 238 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 239 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 240 |
+
self.rope_theta = config.rope_theta
|
| 241 |
+
self.is_causal = config.is_causal
|
| 242 |
+
self.attention_dropout = config.attention_dropout
|
| 243 |
+
|
| 244 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 245 |
+
raise ValueError(
|
| 246 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 247 |
+
f" and `num_heads`: {self.num_heads})."
|
| 248 |
+
)
|
| 249 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 250 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 251 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 252 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 253 |
+
|
| 254 |
+
def forward(
|
| 255 |
+
self,
|
| 256 |
+
hidden_states: torch.Tensor,
|
| 257 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 258 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 259 |
+
past_key_value: Optional[Cache] = None,
|
| 260 |
+
output_attentions: bool = False,
|
| 261 |
+
use_cache: bool = False,
|
| 262 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 263 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 264 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 265 |
+
bsz, q_len, _ = hidden_states.size()
|
| 266 |
+
|
| 267 |
+
query_states = self.q_proj(hidden_states)
|
| 268 |
+
key_states = self.k_proj(hidden_states)
|
| 269 |
+
value_states = self.v_proj(hidden_states)
|
| 270 |
+
|
| 271 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 272 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 273 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 274 |
+
|
| 275 |
+
if position_embeddings is None:
|
| 276 |
+
logger.warning_once(
|
| 277 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 278 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 279 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 280 |
+
"removed and `position_embeddings` will be mandatory."
|
| 281 |
+
)
|
| 282 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 283 |
+
else:
|
| 284 |
+
cos, sin = position_embeddings
|
| 285 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 286 |
+
|
| 287 |
+
if past_key_value is not None:
|
| 288 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 289 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 290 |
+
|
| 291 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 292 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 293 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 294 |
+
|
| 295 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 296 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 297 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 298 |
+
attn_weights = attn_weights + causal_mask
|
| 299 |
+
|
| 300 |
+
# upcast attention to fp32
|
| 301 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 302 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 303 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 304 |
+
|
| 305 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 306 |
+
raise ValueError(
|
| 307 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 308 |
+
f" {attn_output.size()}"
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 312 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 313 |
+
|
| 314 |
+
attn_output = self.o_proj(attn_output)
|
| 315 |
+
|
| 316 |
+
if not output_attentions:
|
| 317 |
+
attn_weights = None
|
| 318 |
+
|
| 319 |
+
return attn_output, attn_weights, past_key_value
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class Qwen2FlashAttention2(Qwen2Attention):
|
| 323 |
+
"""
|
| 324 |
+
Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
|
| 325 |
+
as the weights of the module stays untouched. The only required change would be on the forward pass
|
| 326 |
+
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
| 327 |
+
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
| 328 |
+
config.max_window_layers layers.
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
| 332 |
+
def __init__(self, *args, **kwargs):
|
| 333 |
+
super().__init__(*args, **kwargs)
|
| 334 |
+
|
| 335 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 336 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 337 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 338 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 339 |
+
|
| 340 |
+
def forward(
|
| 341 |
+
self,
|
| 342 |
+
hidden_states: torch.Tensor,
|
| 343 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 344 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 345 |
+
past_key_value: Optional[Cache] = None,
|
| 346 |
+
output_attentions: bool = False,
|
| 347 |
+
use_cache: bool = False,
|
| 348 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 349 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 350 |
+
):
|
| 351 |
+
bsz, q_len, _ = hidden_states.size()
|
| 352 |
+
|
| 353 |
+
query_states = self.q_proj(hidden_states)
|
| 354 |
+
key_states = self.k_proj(hidden_states)
|
| 355 |
+
value_states = self.v_proj(hidden_states)
|
| 356 |
+
|
| 357 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 358 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 359 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 360 |
+
|
| 361 |
+
if position_embeddings is None:
|
| 362 |
+
logger.warning_once(
|
| 363 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 364 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 365 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 366 |
+
"removed and `position_embeddings` will be mandatory."
|
| 367 |
+
)
|
| 368 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 369 |
+
else:
|
| 370 |
+
cos, sin = position_embeddings
|
| 371 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 372 |
+
|
| 373 |
+
if past_key_value is not None:
|
| 374 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
| 375 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 376 |
+
|
| 377 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 378 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 379 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 380 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
| 381 |
+
|
| 382 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 383 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 384 |
+
# cast them back in float16 just to be sure everything works as expected.
|
| 385 |
+
input_dtype = query_states.dtype
|
| 386 |
+
if input_dtype == torch.float32:
|
| 387 |
+
if torch.is_autocast_enabled():
|
| 388 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 389 |
+
# Handle the case where the model is quantized
|
| 390 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 391 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 392 |
+
else:
|
| 393 |
+
target_dtype = self.q_proj.weight.dtype
|
| 394 |
+
|
| 395 |
+
logger.warning_once(
|
| 396 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 397 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 398 |
+
f" {target_dtype}."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
query_states = query_states.to(target_dtype)
|
| 402 |
+
key_states = key_states.to(target_dtype)
|
| 403 |
+
value_states = value_states.to(target_dtype)
|
| 404 |
+
|
| 405 |
+
# Reashape to the expected shape for Flash Attention
|
| 406 |
+
query_states = query_states.transpose(1, 2)
|
| 407 |
+
key_states = key_states.transpose(1, 2)
|
| 408 |
+
value_states = value_states.transpose(1, 2)
|
| 409 |
+
|
| 410 |
+
if (
|
| 411 |
+
self.config.use_sliding_window
|
| 412 |
+
and getattr(self.config, "sliding_window", None) is not None
|
| 413 |
+
and self.layer_idx >= self.config.max_window_layers
|
| 414 |
+
):
|
| 415 |
+
sliding_window = self.config.sliding_window
|
| 416 |
+
else:
|
| 417 |
+
sliding_window = None
|
| 418 |
+
|
| 419 |
+
attn_output = _flash_attention_forward(
|
| 420 |
+
query_states,
|
| 421 |
+
key_states,
|
| 422 |
+
value_states,
|
| 423 |
+
attention_mask,
|
| 424 |
+
q_len,
|
| 425 |
+
position_ids=position_ids,
|
| 426 |
+
dropout=dropout_rate,
|
| 427 |
+
sliding_window=sliding_window,
|
| 428 |
+
is_causal=self.is_causal,
|
| 429 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 433 |
+
attn_output = self.o_proj(attn_output)
|
| 434 |
+
|
| 435 |
+
if not output_attentions:
|
| 436 |
+
attn_weights = None
|
| 437 |
+
|
| 438 |
+
return attn_output, attn_weights, past_key_value
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
QWEN2_ATTENTION_CLASSES = {
|
| 442 |
+
"eager": Qwen2Attention,
|
| 443 |
+
"flash_attention_2": Qwen2FlashAttention2,
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class Qwen2DecoderLayer(nn.Module):
|
| 448 |
+
def __init__(self, config: Qwen2Config, layer_idx: int):
|
| 449 |
+
super().__init__()
|
| 450 |
+
self.hidden_size = config.hidden_size
|
| 451 |
+
|
| 452 |
+
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
| 453 |
+
logger.warning_once(
|
| 454 |
+
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
| 455 |
+
"unexpected results may be encountered."
|
| 456 |
+
)
|
| 457 |
+
self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 458 |
+
|
| 459 |
+
self.mlp = Qwen2MLP(config)
|
| 460 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 461 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 462 |
+
|
| 463 |
+
def forward(
|
| 464 |
+
self,
|
| 465 |
+
hidden_states: torch.Tensor,
|
| 466 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 467 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 468 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 469 |
+
output_attentions: Optional[bool] = False,
|
| 470 |
+
use_cache: Optional[bool] = False,
|
| 471 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 472 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 473 |
+
**kwargs,
|
| 474 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 475 |
+
"""
|
| 476 |
+
Args:
|
| 477 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 478 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 479 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
| 480 |
+
output_attentions (`bool`, *optional*):
|
| 481 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 482 |
+
returned tensors for more detail.
|
| 483 |
+
use_cache (`bool`, *optional*):
|
| 484 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 485 |
+
(see `past_key_values`).
|
| 486 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 487 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 488 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 489 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 490 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 491 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 492 |
+
kwargs (`dict`, *optional*):
|
| 493 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 494 |
+
into the model
|
| 495 |
+
"""
|
| 496 |
+
|
| 497 |
+
residual = hidden_states
|
| 498 |
+
|
| 499 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 500 |
+
|
| 501 |
+
# Self Attention
|
| 502 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 503 |
+
hidden_states=hidden_states,
|
| 504 |
+
attention_mask=attention_mask,
|
| 505 |
+
position_ids=position_ids,
|
| 506 |
+
past_key_value=past_key_value,
|
| 507 |
+
output_attentions=output_attentions,
|
| 508 |
+
use_cache=use_cache,
|
| 509 |
+
cache_position=cache_position,
|
| 510 |
+
position_embeddings=position_embeddings,
|
| 511 |
+
)
|
| 512 |
+
hidden_states = residual + hidden_states
|
| 513 |
+
|
| 514 |
+
# Fully Connected
|
| 515 |
+
residual = hidden_states
|
| 516 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 517 |
+
hidden_states = self.mlp(hidden_states)
|
| 518 |
+
hidden_states = residual + hidden_states
|
| 519 |
+
|
| 520 |
+
outputs = (hidden_states,)
|
| 521 |
+
|
| 522 |
+
if output_attentions:
|
| 523 |
+
outputs += (self_attn_weights,)
|
| 524 |
+
|
| 525 |
+
if use_cache:
|
| 526 |
+
outputs += (present_key_value,)
|
| 527 |
+
|
| 528 |
+
return outputs
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
QWEN2_START_DOCSTRING = r"""
|
| 532 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 533 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 534 |
+
etc.)
|
| 535 |
+
|
| 536 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 537 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 538 |
+
and behavior.
|
| 539 |
+
|
| 540 |
+
Parameters:
|
| 541 |
+
config ([`Qwen2Config`]):
|
| 542 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 543 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 544 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
@add_start_docstrings(
|
| 549 |
+
"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
|
| 550 |
+
QWEN2_START_DOCSTRING,
|
| 551 |
+
)
|
| 552 |
+
class Qwen2PreTrainedModel(PreTrainedModel):
|
| 553 |
+
config_class = Qwen2Config
|
| 554 |
+
base_model_prefix = "model"
|
| 555 |
+
supports_gradient_checkpointing = True
|
| 556 |
+
_no_split_modules = ["Qwen2DecoderLayer"]
|
| 557 |
+
_skip_keys_device_placement = "past_key_values"
|
| 558 |
+
_supports_flash_attn_2 = True
|
| 559 |
+
_supports_cache_class = True
|
| 560 |
+
_supports_quantized_cache = True
|
| 561 |
+
_supports_static_cache = True
|
| 562 |
+
|
| 563 |
+
def _init_weights(self, module):
|
| 564 |
+
std = self.config.initializer_range
|
| 565 |
+
if isinstance(module, nn.Linear):
|
| 566 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 567 |
+
if module.bias is not None:
|
| 568 |
+
module.bias.data.zero_()
|
| 569 |
+
elif isinstance(module, nn.Embedding):
|
| 570 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 571 |
+
if module.padding_idx is not None:
|
| 572 |
+
module.weight.data[module.padding_idx].zero_()
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
QWEN2_INPUTS_DOCSTRING = r"""
|
| 576 |
+
Args:
|
| 577 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 578 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 579 |
+
it.
|
| 580 |
+
|
| 581 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 582 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 583 |
+
|
| 584 |
+
[What are input IDs?](../glossary#input-ids)
|
| 585 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 586 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 587 |
+
|
| 588 |
+
- 1 for tokens that are **not masked**,
|
| 589 |
+
- 0 for tokens that are **masked**.
|
| 590 |
+
|
| 591 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 592 |
+
|
| 593 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 594 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 595 |
+
|
| 596 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
| 597 |
+
`past_key_values`).
|
| 598 |
+
|
| 599 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 600 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 601 |
+
information on the default strategy.
|
| 602 |
+
|
| 603 |
+
- 1 indicates the head is **not masked**,
|
| 604 |
+
- 0 indicates the head is **masked**.
|
| 605 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 606 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 607 |
+
config.n_positions - 1]`.
|
| 608 |
+
|
| 609 |
+
[What are position IDs?](../glossary#position-ids)
|
| 610 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 611 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 612 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 613 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 614 |
+
|
| 615 |
+
Two formats are allowed:
|
| 616 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 617 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 618 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 619 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 620 |
+
cache format.
|
| 621 |
+
|
| 622 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 623 |
+
legacy cache format will be returned.
|
| 624 |
+
|
| 625 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 626 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 627 |
+
of shape `(batch_size, sequence_length)`.
|
| 628 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 629 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 630 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 631 |
+
model's internal embedding lookup matrix.
|
| 632 |
+
use_cache (`bool`, *optional*):
|
| 633 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 634 |
+
`past_key_values`).
|
| 635 |
+
output_attentions (`bool`, *optional*):
|
| 636 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 637 |
+
tensors for more detail.
|
| 638 |
+
output_hidden_states (`bool`, *optional*):
|
| 639 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 640 |
+
more detail.
|
| 641 |
+
return_dict (`bool`, *optional*):
|
| 642 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 643 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 644 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 645 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 646 |
+
the complete sequence length.
|
| 647 |
+
"""
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
@add_start_docstrings(
|
| 651 |
+
"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
|
| 652 |
+
QWEN2_START_DOCSTRING,
|
| 653 |
+
)
|
| 654 |
+
class Qwen2Model(Qwen2PreTrainedModel):
|
| 655 |
+
"""
|
| 656 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
|
| 657 |
+
|
| 658 |
+
Args:
|
| 659 |
+
config: Qwen2Config
|
| 660 |
+
"""
|
| 661 |
+
|
| 662 |
+
def __init__(self, config: Qwen2Config):
|
| 663 |
+
super().__init__(config)
|
| 664 |
+
self.padding_idx = config.pad_token_id
|
| 665 |
+
self.vocab_size = config.vocab_size
|
| 666 |
+
|
| 667 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 668 |
+
self.layers = nn.ModuleList(
|
| 669 |
+
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 670 |
+
)
|
| 671 |
+
self._attn_implementation = config._attn_implementation
|
| 672 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 673 |
+
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
| 674 |
+
|
| 675 |
+
self.gradient_checkpointing = False
|
| 676 |
+
# Initialize weights and apply final processing
|
| 677 |
+
self.post_init()
|
| 678 |
+
|
| 679 |
+
def get_input_embeddings(self):
|
| 680 |
+
return self.embed_tokens
|
| 681 |
+
|
| 682 |
+
def set_input_embeddings(self, value):
|
| 683 |
+
self.embed_tokens = value
|
| 684 |
+
|
| 685 |
+
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
| 686 |
+
def forward(
|
| 687 |
+
self,
|
| 688 |
+
input_ids: torch.LongTensor = None,
|
| 689 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 690 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 691 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 692 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 693 |
+
use_cache: Optional[bool] = None,
|
| 694 |
+
output_attentions: Optional[bool] = None,
|
| 695 |
+
output_hidden_states: Optional[bool] = None,
|
| 696 |
+
return_dict: Optional[bool] = None,
|
| 697 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 698 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 699 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 700 |
+
output_hidden_states = (
|
| 701 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 702 |
+
)
|
| 703 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 704 |
+
|
| 705 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 706 |
+
|
| 707 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 708 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 709 |
+
|
| 710 |
+
if self.gradient_checkpointing and self.training:
|
| 711 |
+
if use_cache:
|
| 712 |
+
logger.warning_once(
|
| 713 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 714 |
+
)
|
| 715 |
+
use_cache = False
|
| 716 |
+
|
| 717 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 718 |
+
return_legacy_cache = False
|
| 719 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 720 |
+
return_legacy_cache = True
|
| 721 |
+
if past_key_values is None:
|
| 722 |
+
past_key_values = DynamicCache()
|
| 723 |
+
else:
|
| 724 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 725 |
+
logger.warning_once(
|
| 726 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 727 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 728 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
if inputs_embeds is None:
|
| 732 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 733 |
+
|
| 734 |
+
if cache_position is None:
|
| 735 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 736 |
+
cache_position = torch.arange(
|
| 737 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 738 |
+
)
|
| 739 |
+
if position_ids is None:
|
| 740 |
+
position_ids = cache_position.unsqueeze(0)
|
| 741 |
+
|
| 742 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 743 |
+
causal_mask = attention_mask
|
| 744 |
+
else:
|
| 745 |
+
causal_mask = None
|
| 746 |
+
|
| 747 |
+
hidden_states = inputs_embeds
|
| 748 |
+
# create position embeddings to be shared across the decoder layers
|
| 749 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 750 |
+
|
| 751 |
+
# decoder layers
|
| 752 |
+
all_hidden_states = () if output_hidden_states else None
|
| 753 |
+
all_self_attns = () if output_attentions else None
|
| 754 |
+
next_decoder_cache = None
|
| 755 |
+
|
| 756 |
+
for decoder_layer in self.layers:
|
| 757 |
+
if output_hidden_states:
|
| 758 |
+
all_hidden_states += (hidden_states,)
|
| 759 |
+
|
| 760 |
+
if self.gradient_checkpointing and self.training:
|
| 761 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 762 |
+
decoder_layer.__call__,
|
| 763 |
+
hidden_states,
|
| 764 |
+
causal_mask,
|
| 765 |
+
position_ids,
|
| 766 |
+
past_key_values,
|
| 767 |
+
output_attentions,
|
| 768 |
+
use_cache,
|
| 769 |
+
cache_position,
|
| 770 |
+
position_embeddings,
|
| 771 |
+
)
|
| 772 |
+
else:
|
| 773 |
+
layer_outputs = decoder_layer(
|
| 774 |
+
hidden_states,
|
| 775 |
+
attention_mask=causal_mask,
|
| 776 |
+
position_ids=position_ids,
|
| 777 |
+
past_key_value=past_key_values,
|
| 778 |
+
output_attentions=output_attentions,
|
| 779 |
+
use_cache=use_cache,
|
| 780 |
+
cache_position=cache_position,
|
| 781 |
+
position_embeddings=position_embeddings,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
hidden_states = layer_outputs[0]
|
| 785 |
+
|
| 786 |
+
if use_cache:
|
| 787 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 788 |
+
|
| 789 |
+
if output_attentions:
|
| 790 |
+
all_self_attns += (layer_outputs[1],)
|
| 791 |
+
|
| 792 |
+
hidden_states = self.norm(hidden_states)
|
| 793 |
+
|
| 794 |
+
# add hidden states from the last decoder layer
|
| 795 |
+
if output_hidden_states:
|
| 796 |
+
all_hidden_states += (hidden_states,)
|
| 797 |
+
|
| 798 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 799 |
+
if return_legacy_cache:
|
| 800 |
+
next_cache = next_cache.to_legacy_cache()
|
| 801 |
+
|
| 802 |
+
if not return_dict:
|
| 803 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 804 |
+
return BaseModelOutputWithPast(
|
| 805 |
+
last_hidden_state=hidden_states,
|
| 806 |
+
past_key_values=next_cache,
|
| 807 |
+
hidden_states=all_hidden_states,
|
| 808 |
+
attentions=all_self_attns,
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
| 813 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 814 |
+
|
| 815 |
+
def __init__(self, config):
|
| 816 |
+
super().__init__(config)
|
| 817 |
+
self.model = Qwen2Model(config)
|
| 818 |
+
self.vocab_size = config.vocab_size
|
| 819 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 820 |
+
|
| 821 |
+
# Initialize weights and apply final processing
|
| 822 |
+
self.post_init()
|
| 823 |
+
|
| 824 |
+
def get_input_embeddings(self):
|
| 825 |
+
return self.model.embed_tokens
|
| 826 |
+
|
| 827 |
+
def set_input_embeddings(self, value):
|
| 828 |
+
self.model.embed_tokens = value
|
| 829 |
+
|
| 830 |
+
def get_output_embeddings(self):
|
| 831 |
+
return self.lm_head
|
| 832 |
+
|
| 833 |
+
def set_output_embeddings(self, new_embeddings):
|
| 834 |
+
self.lm_head = new_embeddings
|
| 835 |
+
|
| 836 |
+
def set_decoder(self, decoder):
|
| 837 |
+
self.model = decoder
|
| 838 |
+
|
| 839 |
+
def get_decoder(self):
|
| 840 |
+
return self.model
|
| 841 |
+
|
| 842 |
+
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
| 843 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 844 |
+
def forward(
|
| 845 |
+
self,
|
| 846 |
+
input_ids: torch.LongTensor = None,
|
| 847 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 848 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 849 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 850 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 851 |
+
labels: Optional[torch.LongTensor] = None,
|
| 852 |
+
use_cache: Optional[bool] = None,
|
| 853 |
+
output_attentions: Optional[bool] = None,
|
| 854 |
+
output_hidden_states: Optional[bool] = None,
|
| 855 |
+
return_dict: Optional[bool] = None,
|
| 856 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 857 |
+
num_logits_to_keep: int = 0,
|
| 858 |
+
**loss_kwargs,
|
| 859 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 860 |
+
r"""
|
| 861 |
+
Args:
|
| 862 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 863 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 864 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 865 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 866 |
+
|
| 867 |
+
num_logits_to_keep (`int`, *optional*):
|
| 868 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
| 869 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 870 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 871 |
+
|
| 872 |
+
Returns:
|
| 873 |
+
|
| 874 |
+
Example:
|
| 875 |
+
|
| 876 |
+
```python
|
| 877 |
+
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
| 878 |
+
|
| 879 |
+
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
| 880 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
| 881 |
+
|
| 882 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 883 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 884 |
+
|
| 885 |
+
>>> # Generate
|
| 886 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 887 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 888 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 889 |
+
```"""
|
| 890 |
+
|
| 891 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 892 |
+
output_hidden_states = (
|
| 893 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 894 |
+
)
|
| 895 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 896 |
+
|
| 897 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 898 |
+
outputs = self.model(
|
| 899 |
+
input_ids=input_ids,
|
| 900 |
+
attention_mask=attention_mask,
|
| 901 |
+
position_ids=position_ids,
|
| 902 |
+
past_key_values=past_key_values,
|
| 903 |
+
inputs_embeds=inputs_embeds,
|
| 904 |
+
use_cache=use_cache,
|
| 905 |
+
output_attentions=output_attentions,
|
| 906 |
+
output_hidden_states=output_hidden_states,
|
| 907 |
+
return_dict=return_dict,
|
| 908 |
+
cache_position=cache_position,
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
hidden_states = outputs[0]
|
| 912 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 913 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 914 |
+
|
| 915 |
+
loss = None
|
| 916 |
+
if labels is not None:
|
| 917 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 918 |
+
|
| 919 |
+
if not return_dict:
|
| 920 |
+
output = (logits,) + outputs[1:]
|
| 921 |
+
return (loss,) + output if loss is not None else output
|
| 922 |
+
|
| 923 |
+
return CausalLMOutputWithPast(
|
| 924 |
+
loss=loss,
|
| 925 |
+
logits=logits,
|
| 926 |
+
past_key_values=outputs.past_key_values,
|
| 927 |
+
hidden_states=outputs.hidden_states,
|
| 928 |
+
attentions=outputs.attentions,
|
| 929 |
+
)
|
modeling/qwen2/tokenization_qwen2.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Tokenization classes for Qwen2."""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import unicodedata
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import regex as re
|
| 13 |
+
|
| 14 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 15 |
+
from transformers.utils import logging
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
VOCAB_FILES_NAMES = {
|
| 21 |
+
"vocab_file": "vocab.json",
|
| 22 |
+
"merges_file": "merges.txt",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
|
| 27 |
+
|
| 28 |
+
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@lru_cache()
|
| 32 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
| 33 |
+
def bytes_to_unicode():
|
| 34 |
+
"""
|
| 35 |
+
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
|
| 36 |
+
characters the bpe code barfs on.
|
| 37 |
+
|
| 38 |
+
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
| 39 |
+
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
| 40 |
+
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
| 41 |
+
tables between utf-8 bytes and unicode strings.
|
| 42 |
+
"""
|
| 43 |
+
bs = (
|
| 44 |
+
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
| 45 |
+
)
|
| 46 |
+
cs = bs[:]
|
| 47 |
+
n = 0
|
| 48 |
+
for b in range(2**8):
|
| 49 |
+
if b not in bs:
|
| 50 |
+
bs.append(b)
|
| 51 |
+
cs.append(2**8 + n)
|
| 52 |
+
n += 1
|
| 53 |
+
cs = [chr(n) for n in cs]
|
| 54 |
+
return dict(zip(bs, cs))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
|
| 58 |
+
def get_pairs(word):
|
| 59 |
+
"""
|
| 60 |
+
Return set of symbol pairs in a word.
|
| 61 |
+
|
| 62 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 63 |
+
"""
|
| 64 |
+
pairs = set()
|
| 65 |
+
prev_char = word[0]
|
| 66 |
+
for char in word[1:]:
|
| 67 |
+
pairs.add((prev_char, char))
|
| 68 |
+
prev_char = char
|
| 69 |
+
return pairs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Qwen2Tokenizer(PreTrainedTokenizer):
|
| 73 |
+
"""
|
| 74 |
+
Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 75 |
+
|
| 76 |
+
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
|
| 77 |
+
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
>>> from transformers import Qwen2Tokenizer
|
| 81 |
+
|
| 82 |
+
>>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
|
| 83 |
+
>>> tokenizer("Hello world")["input_ids"]
|
| 84 |
+
[9707, 1879]
|
| 85 |
+
|
| 86 |
+
>>> tokenizer(" Hello world")["input_ids"]
|
| 87 |
+
[21927, 1879]
|
| 88 |
+
```
|
| 89 |
+
This is expected.
|
| 90 |
+
|
| 91 |
+
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
|
| 92 |
+
|
| 93 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 94 |
+
this superclass for more information regarding those methods.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
vocab_file (`str`):
|
| 98 |
+
Path to the vocabulary file.
|
| 99 |
+
merges_file (`str`):
|
| 100 |
+
Path to the merges file.
|
| 101 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
| 102 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
| 103 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
| 104 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 105 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 106 |
+
token instead.
|
| 107 |
+
bos_token (`str`, *optional*):
|
| 108 |
+
The beginning of sequence token. Not applicable for this tokenizer.
|
| 109 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 110 |
+
The end of sequence token.
|
| 111 |
+
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 112 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 113 |
+
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
| 114 |
+
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
|
| 115 |
+
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
|
| 116 |
+
split_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 117 |
+
Whether or not the special tokens should be split during the tokenization process. The default behavior is
|
| 118 |
+
to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
|
| 119 |
+
['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<',
|
| 120 |
+
'|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 124 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
vocab_file,
|
| 129 |
+
merges_file,
|
| 130 |
+
errors="replace",
|
| 131 |
+
unk_token="<|endoftext|>",
|
| 132 |
+
bos_token=None,
|
| 133 |
+
eos_token="<|endoftext|>",
|
| 134 |
+
pad_token="<|endoftext|>",
|
| 135 |
+
clean_up_tokenization_spaces=False,
|
| 136 |
+
split_special_tokens=False,
|
| 137 |
+
**kwargs,
|
| 138 |
+
):
|
| 139 |
+
# Qwen vocab does not contain control tokens; added tokens need to be special
|
| 140 |
+
bos_token = (
|
| 141 |
+
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 142 |
+
if isinstance(bos_token, str)
|
| 143 |
+
else bos_token
|
| 144 |
+
)
|
| 145 |
+
eos_token = (
|
| 146 |
+
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 147 |
+
if isinstance(eos_token, str)
|
| 148 |
+
else eos_token
|
| 149 |
+
)
|
| 150 |
+
unk_token = (
|
| 151 |
+
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 152 |
+
if isinstance(unk_token, str)
|
| 153 |
+
else unk_token
|
| 154 |
+
)
|
| 155 |
+
pad_token = (
|
| 156 |
+
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 157 |
+
if isinstance(pad_token, str)
|
| 158 |
+
else pad_token
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
| 162 |
+
self.encoder = json.load(vocab_handle)
|
| 163 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 164 |
+
self.errors = errors # how to handle errors in decoding
|
| 165 |
+
self.byte_encoder = bytes_to_unicode()
|
| 166 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 167 |
+
bpe_merges = []
|
| 168 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
| 169 |
+
for i, line in enumerate(merges_handle):
|
| 170 |
+
line = line.strip()
|
| 171 |
+
if (i == 0 and line.startswith("#version:")) or not line:
|
| 172 |
+
continue
|
| 173 |
+
bpe_merges.append(tuple(line.split()))
|
| 174 |
+
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
| 175 |
+
# NOTE: the cache can grow without bound and will get really large for long running processes
|
| 176 |
+
# (esp. for texts of language that do not use space between word, e.g. Chinese); technically
|
| 177 |
+
# not a memory leak but appears as one.
|
| 178 |
+
# GPT2Tokenizer has the same problem, so let's be consistent.
|
| 179 |
+
self.cache = {}
|
| 180 |
+
|
| 181 |
+
self.pat = re.compile(PRETOKENIZE_REGEX)
|
| 182 |
+
|
| 183 |
+
if kwargs.get("add_prefix_space", False):
|
| 184 |
+
logger.warning_once(
|
| 185 |
+
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
super().__init__(
|
| 189 |
+
errors=errors,
|
| 190 |
+
bos_token=bos_token,
|
| 191 |
+
eos_token=eos_token,
|
| 192 |
+
pad_token=pad_token,
|
| 193 |
+
unk_token=unk_token,
|
| 194 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 195 |
+
split_special_tokens=split_special_tokens,
|
| 196 |
+
**kwargs,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
@property
|
| 200 |
+
def vocab_size(self) -> int:
|
| 201 |
+
return len(self.encoder)
|
| 202 |
+
|
| 203 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
|
| 204 |
+
def get_vocab(self):
|
| 205 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 206 |
+
|
| 207 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
|
| 208 |
+
def bpe(self, token):
|
| 209 |
+
if token in self.cache:
|
| 210 |
+
return self.cache[token]
|
| 211 |
+
word = tuple(token)
|
| 212 |
+
pairs = get_pairs(word)
|
| 213 |
+
|
| 214 |
+
if not pairs:
|
| 215 |
+
return token
|
| 216 |
+
|
| 217 |
+
while True:
|
| 218 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 219 |
+
if bigram not in self.bpe_ranks:
|
| 220 |
+
break
|
| 221 |
+
first, second = bigram
|
| 222 |
+
new_word = []
|
| 223 |
+
i = 0
|
| 224 |
+
while i < len(word):
|
| 225 |
+
try:
|
| 226 |
+
j = word.index(first, i)
|
| 227 |
+
except ValueError:
|
| 228 |
+
new_word.extend(word[i:])
|
| 229 |
+
break
|
| 230 |
+
else:
|
| 231 |
+
new_word.extend(word[i:j])
|
| 232 |
+
i = j
|
| 233 |
+
|
| 234 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 235 |
+
new_word.append(first + second)
|
| 236 |
+
i += 2
|
| 237 |
+
else:
|
| 238 |
+
new_word.append(word[i])
|
| 239 |
+
i += 1
|
| 240 |
+
new_word = tuple(new_word)
|
| 241 |
+
word = new_word
|
| 242 |
+
if len(word) == 1:
|
| 243 |
+
break
|
| 244 |
+
else:
|
| 245 |
+
pairs = get_pairs(word)
|
| 246 |
+
word = " ".join(word)
|
| 247 |
+
self.cache[token] = word
|
| 248 |
+
return word
|
| 249 |
+
|
| 250 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
|
| 251 |
+
def _tokenize(self, text):
|
| 252 |
+
"""Tokenize a string."""
|
| 253 |
+
bpe_tokens = []
|
| 254 |
+
for token in re.findall(self.pat, text):
|
| 255 |
+
token = "".join(
|
| 256 |
+
self.byte_encoder[b] for b in token.encode("utf-8")
|
| 257 |
+
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
| 258 |
+
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
| 259 |
+
return bpe_tokens
|
| 260 |
+
|
| 261 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
|
| 262 |
+
def _convert_token_to_id(self, token):
|
| 263 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 264 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 265 |
+
|
| 266 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
|
| 267 |
+
def _convert_id_to_token(self, index):
|
| 268 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 269 |
+
return self.decoder.get(index)
|
| 270 |
+
|
| 271 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
|
| 272 |
+
def convert_tokens_to_string(self, tokens):
|
| 273 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 274 |
+
text = "".join(tokens)
|
| 275 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
| 276 |
+
return text
|
| 277 |
+
|
| 278 |
+
def decode(
|
| 279 |
+
self,
|
| 280 |
+
token_ids,
|
| 281 |
+
skip_special_tokens: bool = False,
|
| 282 |
+
clean_up_tokenization_spaces: Optional[bool] = False,
|
| 283 |
+
spaces_between_special_tokens: bool = False,
|
| 284 |
+
**kwargs,
|
| 285 |
+
) -> str:
|
| 286 |
+
# `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
|
| 287 |
+
# and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
|
| 288 |
+
return super().decode(
|
| 289 |
+
token_ids,
|
| 290 |
+
skip_special_tokens=skip_special_tokens,
|
| 291 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 292 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
| 293 |
+
**kwargs,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
|
| 297 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 298 |
+
if not os.path.isdir(save_directory):
|
| 299 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 300 |
+
return
|
| 301 |
+
vocab_file = os.path.join(
|
| 302 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 303 |
+
)
|
| 304 |
+
merge_file = os.path.join(
|
| 305 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 309 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 310 |
+
|
| 311 |
+
index = 0
|
| 312 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 313 |
+
writer.write("#version: 0.2\n")
|
| 314 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 315 |
+
if index != token_index:
|
| 316 |
+
logger.warning(
|
| 317 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
| 318 |
+
" Please check that the tokenizer is not corrupted!"
|
| 319 |
+
)
|
| 320 |
+
index = token_index
|
| 321 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 322 |
+
index += 1
|
| 323 |
+
|
| 324 |
+
return vocab_file, merge_file
|
| 325 |
+
|
| 326 |
+
def prepare_for_tokenization(self, text, **kwargs):
|
| 327 |
+
text = unicodedata.normalize("NFC", text)
|
| 328 |
+
return (text, kwargs)
|
modeling/qwen2/tokenization_qwen2_fast.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Tokenization classes for Qwen2."""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from transformers.tokenization_utils import AddedToken
|
| 9 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
from .tokenization_qwen2 import Qwen2Tokenizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = logging.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
VOCAB_FILES_NAMES = {
|
| 17 |
+
"vocab_file": "vocab.json",
|
| 18 |
+
"merges_file": "merges.txt",
|
| 19 |
+
"tokenizer_file": "tokenizer.json",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Qwen2TokenizerFast(PreTrainedTokenizerFast):
|
| 27 |
+
"""
|
| 28 |
+
Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
|
| 29 |
+
Byte-Pair-Encoding.
|
| 30 |
+
|
| 31 |
+
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
|
| 32 |
+
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
>>> from transformers import Qwen2TokenizerFast
|
| 36 |
+
|
| 37 |
+
>>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
|
| 38 |
+
>>> tokenizer("Hello world")["input_ids"]
|
| 39 |
+
[9707, 1879]
|
| 40 |
+
|
| 41 |
+
>>> tokenizer(" Hello world")["input_ids"]
|
| 42 |
+
[21927, 1879]
|
| 43 |
+
```
|
| 44 |
+
This is expected.
|
| 45 |
+
|
| 46 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
| 47 |
+
refer to this superclass for more information regarding those methods.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
vocab_file (`str`, *optional*):
|
| 51 |
+
Path to the vocabulary file.
|
| 52 |
+
merges_file (`str`, *optional*):
|
| 53 |
+
Path to the merges file.
|
| 54 |
+
tokenizer_file (`str`, *optional*):
|
| 55 |
+
Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
|
| 56 |
+
contains everything needed to load the tokenizer.
|
| 57 |
+
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 58 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 59 |
+
token instead. Not applicable to this tokenizer.
|
| 60 |
+
bos_token (`str`, *optional*):
|
| 61 |
+
The beginning of sequence token. Not applicable for this tokenizer.
|
| 62 |
+
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 63 |
+
The end of sequence token.
|
| 64 |
+
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
| 65 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 69 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 70 |
+
slow_tokenizer_class = Qwen2Tokenizer
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
vocab_file=None,
|
| 75 |
+
merges_file=None,
|
| 76 |
+
tokenizer_file=None,
|
| 77 |
+
unk_token="<|endoftext|>",
|
| 78 |
+
bos_token=None,
|
| 79 |
+
eos_token="<|endoftext|>",
|
| 80 |
+
pad_token="<|endoftext|>",
|
| 81 |
+
**kwargs,
|
| 82 |
+
):
|
| 83 |
+
# We need to at least pass vocab_file and merges_file to base class
|
| 84 |
+
# in case a slow tokenizer needs to be initialized; other can be
|
| 85 |
+
# configured through files.
|
| 86 |
+
# following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
|
| 87 |
+
|
| 88 |
+
bos_token = (
|
| 89 |
+
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 90 |
+
if isinstance(bos_token, str)
|
| 91 |
+
else bos_token
|
| 92 |
+
)
|
| 93 |
+
eos_token = (
|
| 94 |
+
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 95 |
+
if isinstance(eos_token, str)
|
| 96 |
+
else eos_token
|
| 97 |
+
)
|
| 98 |
+
unk_token = (
|
| 99 |
+
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 100 |
+
if isinstance(unk_token, str)
|
| 101 |
+
else unk_token
|
| 102 |
+
)
|
| 103 |
+
pad_token = (
|
| 104 |
+
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
| 105 |
+
if isinstance(pad_token, str)
|
| 106 |
+
else pad_token
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
super().__init__(
|
| 110 |
+
vocab_file=vocab_file,
|
| 111 |
+
merges_file=merges_file,
|
| 112 |
+
tokenizer_file=tokenizer_file,
|
| 113 |
+
unk_token=unk_token,
|
| 114 |
+
bos_token=bos_token,
|
| 115 |
+
eos_token=eos_token,
|
| 116 |
+
pad_token=pad_token,
|
| 117 |
+
**kwargs,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
|
| 121 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 122 |
+
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
| 123 |
+
return tuple(files)
|
modeling/siglip/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
from transformers.utils import (
|
| 7 |
+
OptionalDependencyNotAvailable,
|
| 8 |
+
_LazyModule,
|
| 9 |
+
is_sentencepiece_available,
|
| 10 |
+
is_torch_available,
|
| 11 |
+
is_vision_available,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_import_structure = {
|
| 16 |
+
"configuration_siglip": [
|
| 17 |
+
"SiglipConfig",
|
| 18 |
+
"SiglipTextConfig",
|
| 19 |
+
"SiglipVisionConfig",
|
| 20 |
+
],
|
| 21 |
+
"processing_siglip": ["SiglipProcessor"],
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
if not is_sentencepiece_available():
|
| 26 |
+
raise OptionalDependencyNotAvailable()
|
| 27 |
+
except OptionalDependencyNotAvailable:
|
| 28 |
+
pass
|
| 29 |
+
else:
|
| 30 |
+
_import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
if not is_vision_available():
|
| 35 |
+
raise OptionalDependencyNotAvailable()
|
| 36 |
+
except OptionalDependencyNotAvailable:
|
| 37 |
+
pass
|
| 38 |
+
else:
|
| 39 |
+
_import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
if not is_torch_available():
|
| 43 |
+
raise OptionalDependencyNotAvailable()
|
| 44 |
+
except OptionalDependencyNotAvailable:
|
| 45 |
+
pass
|
| 46 |
+
else:
|
| 47 |
+
_import_structure["modeling_siglip"] = [
|
| 48 |
+
"SiglipModel",
|
| 49 |
+
"SiglipPreTrainedModel",
|
| 50 |
+
"SiglipTextModel",
|
| 51 |
+
"SiglipVisionModel",
|
| 52 |
+
"SiglipForImageClassification",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if TYPE_CHECKING:
|
| 57 |
+
from .configuration_siglip import (
|
| 58 |
+
SiglipConfig,
|
| 59 |
+
SiglipTextConfig,
|
| 60 |
+
SiglipVisionConfig,
|
| 61 |
+
)
|
| 62 |
+
from .processing_siglip import SiglipProcessor
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
if not is_sentencepiece_available():
|
| 66 |
+
raise OptionalDependencyNotAvailable()
|
| 67 |
+
except OptionalDependencyNotAvailable:
|
| 68 |
+
pass
|
| 69 |
+
else:
|
| 70 |
+
from .tokenization_siglip import SiglipTokenizer
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
if not is_vision_available():
|
| 74 |
+
raise OptionalDependencyNotAvailable()
|
| 75 |
+
except OptionalDependencyNotAvailable:
|
| 76 |
+
pass
|
| 77 |
+
else:
|
| 78 |
+
from .image_processing_siglip import SiglipImageProcessor
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
if not is_torch_available():
|
| 82 |
+
raise OptionalDependencyNotAvailable()
|
| 83 |
+
except OptionalDependencyNotAvailable:
|
| 84 |
+
pass
|
| 85 |
+
else:
|
| 86 |
+
from .modeling_siglip import (
|
| 87 |
+
SiglipForImageClassification,
|
| 88 |
+
SiglipModel,
|
| 89 |
+
SiglipPreTrainedModel,
|
| 90 |
+
SiglipTextModel,
|
| 91 |
+
SiglipVisionModel,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
else:
|
| 96 |
+
import sys
|
| 97 |
+
|
| 98 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
modeling/siglip/configuration_siglip.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Siglip model configuration"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiglipTextConfig(PretrainedConfig):
|
| 17 |
+
r"""
|
| 18 |
+
This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
|
| 19 |
+
Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 20 |
+
configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
|
| 21 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 22 |
+
|
| 23 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 24 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 28 |
+
Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
|
| 29 |
+
the `inputs_ids` passed when calling [`SiglipModel`].
|
| 30 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 31 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 32 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 33 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 34 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 35 |
+
Number of hidden layers in the Transformer encoder.
|
| 36 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 37 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 38 |
+
max_position_embeddings (`int`, *optional*, defaults to 64):
|
| 39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 42 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 43 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 44 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 45 |
+
The epsilon used by the layer normalization layers.
|
| 46 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 47 |
+
The dropout ratio for the attention probabilities.
|
| 48 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 49 |
+
The id of the padding token in the vocabulary.
|
| 50 |
+
bos_token_id (`int`, *optional*, defaults to 49406):
|
| 51 |
+
The id of the beginning-of-sequence token in the vocabulary.
|
| 52 |
+
eos_token_id (`int`, *optional*, defaults to 49407):
|
| 53 |
+
The id of the end-of-sequence token in the vocabulary.
|
| 54 |
+
|
| 55 |
+
Example:
|
| 56 |
+
|
| 57 |
+
```python
|
| 58 |
+
>>> from transformers import SiglipTextConfig, SiglipTextModel
|
| 59 |
+
|
| 60 |
+
>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
|
| 61 |
+
>>> configuration = SiglipTextConfig()
|
| 62 |
+
|
| 63 |
+
>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 64 |
+
>>> model = SiglipTextModel(configuration)
|
| 65 |
+
|
| 66 |
+
>>> # Accessing the model configuration
|
| 67 |
+
>>> configuration = model.config
|
| 68 |
+
```"""
|
| 69 |
+
|
| 70 |
+
model_type = "siglip_text_model"
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
vocab_size=32000,
|
| 75 |
+
hidden_size=768,
|
| 76 |
+
intermediate_size=3072,
|
| 77 |
+
num_hidden_layers=12,
|
| 78 |
+
num_attention_heads=12,
|
| 79 |
+
max_position_embeddings=64,
|
| 80 |
+
hidden_act="gelu_pytorch_tanh",
|
| 81 |
+
layer_norm_eps=1e-6,
|
| 82 |
+
attention_dropout=0.0,
|
| 83 |
+
# This differs from `CLIPTokenizer`'s default and from openai/siglip
|
| 84 |
+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
|
| 85 |
+
pad_token_id=1,
|
| 86 |
+
bos_token_id=49406,
|
| 87 |
+
eos_token_id=49407,
|
| 88 |
+
**kwargs,
|
| 89 |
+
):
|
| 90 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 91 |
+
|
| 92 |
+
self.vocab_size = vocab_size
|
| 93 |
+
self.hidden_size = hidden_size
|
| 94 |
+
self.intermediate_size = intermediate_size
|
| 95 |
+
self.num_hidden_layers = num_hidden_layers
|
| 96 |
+
self.num_attention_heads = num_attention_heads
|
| 97 |
+
self.max_position_embeddings = max_position_embeddings
|
| 98 |
+
self.layer_norm_eps = layer_norm_eps
|
| 99 |
+
self.hidden_act = hidden_act
|
| 100 |
+
self.attention_dropout = attention_dropout
|
| 101 |
+
|
| 102 |
+
@classmethod
|
| 103 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 104 |
+
cls._set_token_in_kwargs(kwargs)
|
| 105 |
+
|
| 106 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 107 |
+
|
| 108 |
+
# get the text config dict if we are loading from SiglipConfig
|
| 109 |
+
if config_dict.get("model_type") == "siglip":
|
| 110 |
+
config_dict = config_dict["text_config"]
|
| 111 |
+
|
| 112 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 113 |
+
logger.warning(
|
| 114 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 115 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class SiglipVisionConfig(PretrainedConfig):
|
| 122 |
+
r"""
|
| 123 |
+
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
|
| 124 |
+
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 125 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
|
| 126 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 127 |
+
|
| 128 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 129 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 133 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 134 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 135 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 136 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 137 |
+
Number of hidden layers in the Transformer encoder.
|
| 138 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 139 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 140 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 141 |
+
Number of channels in the input images.
|
| 142 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 143 |
+
The size (resolution) of each image.
|
| 144 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 145 |
+
The size (resolution) of each patch.
|
| 146 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 147 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 148 |
+
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
| 149 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 150 |
+
The epsilon used by the layer normalization layers.
|
| 151 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 152 |
+
The dropout ratio for the attention probabilities.
|
| 153 |
+
|
| 154 |
+
Example:
|
| 155 |
+
|
| 156 |
+
```python
|
| 157 |
+
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
|
| 158 |
+
|
| 159 |
+
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
|
| 160 |
+
>>> configuration = SiglipVisionConfig()
|
| 161 |
+
|
| 162 |
+
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 163 |
+
>>> model = SiglipVisionModel(configuration)
|
| 164 |
+
|
| 165 |
+
>>> # Accessing the model configuration
|
| 166 |
+
>>> configuration = model.config
|
| 167 |
+
```"""
|
| 168 |
+
|
| 169 |
+
model_type = "siglip_vision_model"
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
hidden_size=768,
|
| 174 |
+
intermediate_size=3072,
|
| 175 |
+
num_hidden_layers=12,
|
| 176 |
+
num_attention_heads=12,
|
| 177 |
+
num_channels=3,
|
| 178 |
+
image_size=224,
|
| 179 |
+
patch_size=16,
|
| 180 |
+
hidden_act="gelu_pytorch_tanh",
|
| 181 |
+
layer_norm_eps=1e-6,
|
| 182 |
+
attention_dropout=0.0,
|
| 183 |
+
**kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__(**kwargs)
|
| 186 |
+
|
| 187 |
+
self.hidden_size = hidden_size
|
| 188 |
+
self.intermediate_size = intermediate_size
|
| 189 |
+
self.num_hidden_layers = num_hidden_layers
|
| 190 |
+
self.num_attention_heads = num_attention_heads
|
| 191 |
+
self.num_channels = num_channels
|
| 192 |
+
self.patch_size = patch_size
|
| 193 |
+
self.image_size = image_size
|
| 194 |
+
self.attention_dropout = attention_dropout
|
| 195 |
+
self.layer_norm_eps = layer_norm_eps
|
| 196 |
+
self.hidden_act = hidden_act
|
| 197 |
+
|
| 198 |
+
@classmethod
|
| 199 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 200 |
+
cls._set_token_in_kwargs(kwargs)
|
| 201 |
+
|
| 202 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 203 |
+
|
| 204 |
+
# get the vision config dict if we are loading from SiglipConfig
|
| 205 |
+
if config_dict.get("model_type") == "siglip":
|
| 206 |
+
config_dict = config_dict["vision_config"]
|
| 207 |
+
|
| 208 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 209 |
+
logger.warning(
|
| 210 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 211 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class SiglipConfig(PretrainedConfig):
|
| 218 |
+
r"""
|
| 219 |
+
[`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
|
| 220 |
+
instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
|
| 221 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
|
| 222 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 223 |
+
|
| 224 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 225 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
text_config (`dict`, *optional*):
|
| 229 |
+
Dictionary of configuration options used to initialize [`SiglipTextConfig`].
|
| 230 |
+
vision_config (`dict`, *optional*):
|
| 231 |
+
Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
|
| 232 |
+
kwargs (*optional*):
|
| 233 |
+
Dictionary of keyword arguments.
|
| 234 |
+
|
| 235 |
+
Example:
|
| 236 |
+
|
| 237 |
+
```python
|
| 238 |
+
>>> from transformers import SiglipConfig, SiglipModel
|
| 239 |
+
|
| 240 |
+
>>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
|
| 241 |
+
>>> configuration = SiglipConfig()
|
| 242 |
+
|
| 243 |
+
>>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 244 |
+
>>> model = SiglipModel(configuration)
|
| 245 |
+
|
| 246 |
+
>>> # Accessing the model configuration
|
| 247 |
+
>>> configuration = model.config
|
| 248 |
+
|
| 249 |
+
>>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
|
| 250 |
+
>>> from transformers import SiglipTextConfig, SiglipVisionConfig
|
| 251 |
+
|
| 252 |
+
>>> # Initializing a SiglipText and SiglipVision configuration
|
| 253 |
+
>>> config_text = SiglipTextConfig()
|
| 254 |
+
>>> config_vision = SiglipVisionConfig()
|
| 255 |
+
|
| 256 |
+
>>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
|
| 257 |
+
```"""
|
| 258 |
+
|
| 259 |
+
model_type = "siglip"
|
| 260 |
+
|
| 261 |
+
def __init__(self, text_config=None, vision_config=None, **kwargs):
|
| 262 |
+
super().__init__(**kwargs)
|
| 263 |
+
|
| 264 |
+
if text_config is None:
|
| 265 |
+
text_config = {}
|
| 266 |
+
logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
|
| 267 |
+
|
| 268 |
+
if vision_config is None:
|
| 269 |
+
vision_config = {}
|
| 270 |
+
logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
|
| 271 |
+
|
| 272 |
+
self.text_config = SiglipTextConfig(**text_config)
|
| 273 |
+
self.vision_config = SiglipVisionConfig(**vision_config)
|
| 274 |
+
|
| 275 |
+
self.initializer_factor = 1.0
|
| 276 |
+
|
| 277 |
+
@classmethod
|
| 278 |
+
def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
|
| 279 |
+
r"""
|
| 280 |
+
Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
|
| 281 |
+
model configuration.
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
[`SiglipConfig`]: An instance of a configuration object
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
modeling/siglip/convert_siglip_to_hf.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Convert SigLIP checkpoints from the original repository.
|
| 5 |
+
|
| 6 |
+
URL: https://github.com/google-research/big_vision/tree/main
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import collections
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import requests
|
| 15 |
+
import torch
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
from numpy import load
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
|
| 21 |
+
from transformers.utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logging.set_verbosity_info()
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
model_name_to_checkpoint = {
|
| 29 |
+
# base checkpoints
|
| 30 |
+
"siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
|
| 31 |
+
"siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
|
| 32 |
+
"siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
|
| 33 |
+
"siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
|
| 34 |
+
# large checkpoints
|
| 35 |
+
"siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
|
| 36 |
+
"siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
|
| 37 |
+
# multilingual checkpoint
|
| 38 |
+
"siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
|
| 39 |
+
# so400m checkpoints
|
| 40 |
+
"siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
model_name_to_image_size = {
|
| 44 |
+
"siglip-base-patch16-224": 224,
|
| 45 |
+
"siglip-base-patch16-256": 256,
|
| 46 |
+
"siglip-base-patch16-384": 384,
|
| 47 |
+
"siglip-base-patch16-512": 512,
|
| 48 |
+
"siglip-large-patch16-256": 256,
|
| 49 |
+
"siglip-large-patch16-384": 384,
|
| 50 |
+
"siglip-base-patch16-256-i18n": 256,
|
| 51 |
+
"siglip-so400m-patch14-384": 384,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_siglip_config(model_name):
|
| 56 |
+
config = SiglipConfig()
|
| 57 |
+
|
| 58 |
+
vocab_size = 250000 if "i18n" in model_name else 32000
|
| 59 |
+
image_size = model_name_to_image_size[model_name]
|
| 60 |
+
patch_size = 16 if "patch16" in model_name else 14
|
| 61 |
+
|
| 62 |
+
# size of the architecture
|
| 63 |
+
config.vision_config.image_size = image_size
|
| 64 |
+
config.vision_config.patch_size = patch_size
|
| 65 |
+
config.text_config.vocab_size = vocab_size
|
| 66 |
+
|
| 67 |
+
if "base" in model_name:
|
| 68 |
+
pass
|
| 69 |
+
elif "large" in model_name:
|
| 70 |
+
config.text_config.hidden_size = 1024
|
| 71 |
+
config.text_config.intermediate_size = 4096
|
| 72 |
+
config.text_config.num_hidden_layers = 24
|
| 73 |
+
config.text_config.num_attention_heads = 16
|
| 74 |
+
config.vision_config.hidden_size = 1024
|
| 75 |
+
config.vision_config.intermediate_size = 4096
|
| 76 |
+
config.vision_config.num_hidden_layers = 24
|
| 77 |
+
config.vision_config.num_attention_heads = 16
|
| 78 |
+
elif "so400m" in model_name:
|
| 79 |
+
config.text_config.hidden_size = 1152
|
| 80 |
+
config.text_config.intermediate_size = 4304
|
| 81 |
+
config.text_config.num_hidden_layers = 27
|
| 82 |
+
config.text_config.num_attention_heads = 16
|
| 83 |
+
config.vision_config.hidden_size = 1152
|
| 84 |
+
config.vision_config.intermediate_size = 4304
|
| 85 |
+
config.vision_config.num_hidden_layers = 27
|
| 86 |
+
config.vision_config.num_attention_heads = 16
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError("Model not supported")
|
| 89 |
+
|
| 90 |
+
return config
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def create_rename_keys(config):
|
| 94 |
+
rename_keys = []
|
| 95 |
+
# fmt: off
|
| 96 |
+
|
| 97 |
+
# vision encoder
|
| 98 |
+
|
| 99 |
+
rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
|
| 100 |
+
rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
|
| 101 |
+
rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
|
| 102 |
+
|
| 103 |
+
for i in range(config.vision_config.num_hidden_layers):
|
| 104 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
| 105 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
| 106 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
| 107 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
| 108 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
| 109 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
| 110 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
| 111 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
| 112 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
|
| 113 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
|
| 114 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
|
| 115 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
|
| 116 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
|
| 117 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
|
| 118 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
|
| 119 |
+
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
|
| 120 |
+
|
| 121 |
+
rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
|
| 122 |
+
rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
|
| 123 |
+
|
| 124 |
+
rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
|
| 125 |
+
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
|
| 126 |
+
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
|
| 127 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
|
| 128 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
|
| 129 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
|
| 130 |
+
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
|
| 131 |
+
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
|
| 132 |
+
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
|
| 133 |
+
|
| 134 |
+
# text encoder
|
| 135 |
+
|
| 136 |
+
rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
|
| 137 |
+
rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
|
| 138 |
+
|
| 139 |
+
for i in range(config.text_config.num_hidden_layers):
|
| 140 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
|
| 141 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
|
| 142 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
|
| 143 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
|
| 144 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
|
| 145 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
|
| 146 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
|
| 147 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
|
| 148 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
|
| 149 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
|
| 150 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
|
| 151 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
|
| 152 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
|
| 153 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
|
| 154 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
|
| 155 |
+
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
|
| 156 |
+
|
| 157 |
+
rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
|
| 158 |
+
rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
|
| 159 |
+
rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
|
| 160 |
+
rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
|
| 161 |
+
|
| 162 |
+
# learned temperature and bias
|
| 163 |
+
rename_keys.append(("params/t", "logit_scale"))
|
| 164 |
+
rename_keys.append(("params/b", "logit_bias"))
|
| 165 |
+
|
| 166 |
+
# fmt: on
|
| 167 |
+
return rename_keys
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def rename_key(dct, old, new, config):
|
| 171 |
+
val = dct.pop(old)
|
| 172 |
+
|
| 173 |
+
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
|
| 174 |
+
val = val.reshape(-1, config.vision_config.hidden_size)
|
| 175 |
+
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
|
| 176 |
+
val = val.reshape(-1, config.text_config.hidden_size)
|
| 177 |
+
|
| 178 |
+
if "patch_embedding.weight" in new:
|
| 179 |
+
val = val.transpose(3, 2, 0, 1)
|
| 180 |
+
elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
|
| 181 |
+
val = val.T
|
| 182 |
+
|
| 183 |
+
if "position_embedding" in new and "vision" in new:
|
| 184 |
+
val = val.reshape(-1, config.vision_config.hidden_size)
|
| 185 |
+
if "position_embedding" in new and "text" in new:
|
| 186 |
+
val = val.reshape(-1, config.text_config.hidden_size)
|
| 187 |
+
|
| 188 |
+
if new.endswith("bias"):
|
| 189 |
+
val = val.reshape(-1)
|
| 190 |
+
|
| 191 |
+
dct[new] = torch.from_numpy(val)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def read_in_q_k_v_head(state_dict, config):
|
| 195 |
+
# read in individual input projection layers
|
| 196 |
+
key_proj_weight = (
|
| 197 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
|
| 198 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 199 |
+
.T
|
| 200 |
+
)
|
| 201 |
+
key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
|
| 202 |
+
value_proj_weight = (
|
| 203 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
|
| 204 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 205 |
+
.T
|
| 206 |
+
)
|
| 207 |
+
value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
|
| 208 |
+
query_proj_weight = (
|
| 209 |
+
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
|
| 210 |
+
.reshape(-1, config.vision_config.hidden_size)
|
| 211 |
+
.T
|
| 212 |
+
)
|
| 213 |
+
query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
|
| 214 |
+
|
| 215 |
+
# next, add them to the state dict as a single matrix + vector
|
| 216 |
+
state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
|
| 217 |
+
np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
|
| 218 |
+
)
|
| 219 |
+
state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
|
| 220 |
+
np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# We will verify our results on an image of cute cats
|
| 225 |
+
def prepare_img():
|
| 226 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 227 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 228 |
+
return image
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def flatten_nested_dict(params, parent_key="", sep="/"):
|
| 232 |
+
items = []
|
| 233 |
+
|
| 234 |
+
for k, v in params.items():
|
| 235 |
+
new_key = parent_key + sep + k if parent_key else k
|
| 236 |
+
|
| 237 |
+
if isinstance(v, collections.abc.MutableMapping):
|
| 238 |
+
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
| 239 |
+
else:
|
| 240 |
+
items.append((new_key, v))
|
| 241 |
+
return dict(items)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@torch.no_grad()
|
| 245 |
+
def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
|
| 246 |
+
"""
|
| 247 |
+
Copy/paste/tweak model's weights to our SigLIP structure.
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
# define default SigLIP configuration
|
| 251 |
+
config = get_siglip_config(model_name)
|
| 252 |
+
|
| 253 |
+
# get checkpoint
|
| 254 |
+
checkpoint = model_name_to_checkpoint[model_name]
|
| 255 |
+
|
| 256 |
+
# get vocab file
|
| 257 |
+
if "i18n" in model_name:
|
| 258 |
+
vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
|
| 259 |
+
else:
|
| 260 |
+
vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
|
| 261 |
+
|
| 262 |
+
# load original state dict
|
| 263 |
+
data = load(checkpoint)
|
| 264 |
+
state_dict = flatten_nested_dict(data)
|
| 265 |
+
|
| 266 |
+
# remove and rename some keys
|
| 267 |
+
rename_keys = create_rename_keys(config)
|
| 268 |
+
for src, dest in rename_keys:
|
| 269 |
+
rename_key(state_dict, src, dest, config)
|
| 270 |
+
|
| 271 |
+
# qkv matrices of attention pooling head need special treatment
|
| 272 |
+
read_in_q_k_v_head(state_dict, config)
|
| 273 |
+
|
| 274 |
+
# load HuggingFace model
|
| 275 |
+
model = SiglipModel(config).eval()
|
| 276 |
+
model.load_state_dict(state_dict)
|
| 277 |
+
|
| 278 |
+
# create processor
|
| 279 |
+
# important: make tokenizer not return attention_mask since original one doesn't require it
|
| 280 |
+
image_size = config.vision_config.image_size
|
| 281 |
+
size = {"height": image_size, "width": image_size}
|
| 282 |
+
image_processor = SiglipImageProcessor(size=size)
|
| 283 |
+
tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
|
| 284 |
+
processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
| 285 |
+
|
| 286 |
+
# verify on dummy images and texts
|
| 287 |
+
url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
|
| 288 |
+
image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
|
| 289 |
+
url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
|
| 290 |
+
image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
|
| 291 |
+
texts = ["an apple", "a picture of an apple"]
|
| 292 |
+
|
| 293 |
+
inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
|
| 294 |
+
|
| 295 |
+
# verify input_ids against original ones
|
| 296 |
+
if image_size == 224:
|
| 297 |
+
filename = "siglip_pixel_values.pt"
|
| 298 |
+
elif image_size == 256:
|
| 299 |
+
filename = "siglip_pixel_values_256.pt"
|
| 300 |
+
elif image_size == 384:
|
| 301 |
+
filename = "siglip_pixel_values_384.pt"
|
| 302 |
+
elif image_size == 512:
|
| 303 |
+
filename = "siglip_pixel_values_512.pt"
|
| 304 |
+
else:
|
| 305 |
+
raise ValueError("Image size not supported")
|
| 306 |
+
|
| 307 |
+
filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
|
| 308 |
+
original_pixel_values = torch.load(filepath)
|
| 309 |
+
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
|
| 310 |
+
original_input_ids = torch.load(filepath)
|
| 311 |
+
|
| 312 |
+
if "i18n" not in model_name:
|
| 313 |
+
assert inputs.input_ids.tolist() == original_input_ids.tolist()
|
| 314 |
+
|
| 315 |
+
print("Mean of original pixel values:", original_pixel_values.mean())
|
| 316 |
+
print("Mean of new pixel values:", inputs.pixel_values.mean())
|
| 317 |
+
|
| 318 |
+
# note: we're testing with original pixel values here since we don't have exact pixel values
|
| 319 |
+
with torch.no_grad():
|
| 320 |
+
outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
|
| 321 |
+
|
| 322 |
+
# with torch.no_grad():
|
| 323 |
+
# outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
|
| 324 |
+
|
| 325 |
+
print(outputs.logits_per_image[:3, :3])
|
| 326 |
+
|
| 327 |
+
probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
|
| 328 |
+
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
|
| 329 |
+
print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
|
| 330 |
+
|
| 331 |
+
if verify_logits:
|
| 332 |
+
if model_name == "siglip-base-patch16-224":
|
| 333 |
+
expected_slice = torch.tensor(
|
| 334 |
+
[[-2.9621, -2.1672], [-0.2713, 0.2910]],
|
| 335 |
+
)
|
| 336 |
+
elif model_name == "siglip-base-patch16-256":
|
| 337 |
+
expected_slice = torch.tensor(
|
| 338 |
+
[[-3.1146, -1.9894], [-0.7312, 0.6387]],
|
| 339 |
+
)
|
| 340 |
+
elif model_name == "siglip-base-patch16-384":
|
| 341 |
+
expected_slice = torch.tensor(
|
| 342 |
+
[[-2.8098, -2.1891], [-0.4242, 0.4102]],
|
| 343 |
+
)
|
| 344 |
+
elif model_name == "siglip-base-patch16-512":
|
| 345 |
+
expected_slice = torch.tensor(
|
| 346 |
+
[[-2.7899, -2.2668], [-0.4295, -0.0735]],
|
| 347 |
+
)
|
| 348 |
+
elif model_name == "siglip-large-patch16-256":
|
| 349 |
+
expected_slice = torch.tensor(
|
| 350 |
+
[[-1.5827, -0.5801], [-0.9153, 0.1363]],
|
| 351 |
+
)
|
| 352 |
+
elif model_name == "siglip-large-patch16-384":
|
| 353 |
+
expected_slice = torch.tensor(
|
| 354 |
+
[[-2.1523, -0.2899], [-0.2959, 0.7884]],
|
| 355 |
+
)
|
| 356 |
+
elif model_name == "siglip-so400m-patch14-384":
|
| 357 |
+
expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
|
| 358 |
+
elif model_name == "siglip-base-patch16-256-i18n":
|
| 359 |
+
expected_slice = torch.tensor(
|
| 360 |
+
[[-0.9064, 0.1073], [-0.0299, 0.5304]],
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
|
| 364 |
+
print("Looks ok!")
|
| 365 |
+
|
| 366 |
+
if pytorch_dump_folder_path is not None:
|
| 367 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 368 |
+
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
| 369 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 370 |
+
print(f"Saving processor to {pytorch_dump_folder_path}")
|
| 371 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 372 |
+
|
| 373 |
+
if push_to_hub:
|
| 374 |
+
model.push_to_hub(f"nielsr/{model_name}")
|
| 375 |
+
processor.push_to_hub(f"nielsr/{model_name}")
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
if __name__ == "__main__":
|
| 379 |
+
parser = argparse.ArgumentParser()
|
| 380 |
+
# Required parameters
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--model_name",
|
| 383 |
+
default="siglip-base-patch16-224",
|
| 384 |
+
type=str,
|
| 385 |
+
choices=model_name_to_checkpoint.keys(),
|
| 386 |
+
help="Name of the model you'd like to convert.",
|
| 387 |
+
)
|
| 388 |
+
parser.add_argument(
|
| 389 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 390 |
+
)
|
| 391 |
+
parser.add_argument(
|
| 392 |
+
"--verify_logits",
|
| 393 |
+
action="store_false",
|
| 394 |
+
help="Whether to verify logits against the original implementation.",
|
| 395 |
+
)
|
| 396 |
+
parser.add_argument(
|
| 397 |
+
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
args = parser.parse_args()
|
| 401 |
+
convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
|
modeling/siglip/image_processing_siglip.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Image processor class for SigLIP."""
|
| 5 |
+
|
| 6 |
+
from typing import Dict, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 9 |
+
from transformers.image_transforms import (
|
| 10 |
+
convert_to_rgb,
|
| 11 |
+
resize,
|
| 12 |
+
to_channel_dimension_format,
|
| 13 |
+
)
|
| 14 |
+
from transformers.image_utils import (
|
| 15 |
+
IMAGENET_STANDARD_MEAN,
|
| 16 |
+
IMAGENET_STANDARD_STD,
|
| 17 |
+
ChannelDimension,
|
| 18 |
+
ImageInput,
|
| 19 |
+
PILImageResampling,
|
| 20 |
+
infer_channel_dimension_format,
|
| 21 |
+
is_scaled_image,
|
| 22 |
+
make_list_of_images,
|
| 23 |
+
to_numpy_array,
|
| 24 |
+
valid_images,
|
| 25 |
+
validate_preprocess_arguments,
|
| 26 |
+
)
|
| 27 |
+
from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if is_vision_available():
|
| 34 |
+
import PIL
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SiglipImageProcessor(BaseImageProcessor):
|
| 38 |
+
r"""
|
| 39 |
+
Constructs a SigLIP image processor.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 43 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
| 44 |
+
`do_resize` in the `preprocess` method.
|
| 45 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 46 |
+
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
| 47 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 48 |
+
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
| 49 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 50 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
| 51 |
+
the `preprocess` method.
|
| 52 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 53 |
+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
| 54 |
+
method.
|
| 55 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 56 |
+
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
| 57 |
+
`do_normalize` in the `preprocess` method.
|
| 58 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 59 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 60 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 61 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 62 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 63 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 64 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 65 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 66 |
+
Whether to convert the image to RGB.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
model_input_names = ["pixel_values"]
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
do_resize: bool = True,
|
| 74 |
+
size: Dict[str, int] = None,
|
| 75 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 76 |
+
do_rescale: bool = True,
|
| 77 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 78 |
+
do_normalize: bool = True,
|
| 79 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 80 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 81 |
+
do_convert_rgb: bool = None,
|
| 82 |
+
**kwargs,
|
| 83 |
+
) -> None:
|
| 84 |
+
super().__init__(**kwargs)
|
| 85 |
+
size = size if size is not None else {"height": 224, "width": 224}
|
| 86 |
+
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 87 |
+
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 88 |
+
|
| 89 |
+
self.do_resize = do_resize
|
| 90 |
+
self.size = size
|
| 91 |
+
self.resample = resample
|
| 92 |
+
self.do_rescale = do_rescale
|
| 93 |
+
self.rescale_factor = rescale_factor
|
| 94 |
+
self.do_normalize = do_normalize
|
| 95 |
+
self.image_mean = image_mean
|
| 96 |
+
self.image_std = image_std
|
| 97 |
+
self.do_convert_rgb = do_convert_rgb
|
| 98 |
+
|
| 99 |
+
@filter_out_non_signature_kwargs()
|
| 100 |
+
def preprocess(
|
| 101 |
+
self,
|
| 102 |
+
images: ImageInput,
|
| 103 |
+
do_resize: bool = None,
|
| 104 |
+
size: Dict[str, int] = None,
|
| 105 |
+
resample: PILImageResampling = None,
|
| 106 |
+
do_rescale: bool = None,
|
| 107 |
+
rescale_factor: float = None,
|
| 108 |
+
do_normalize: bool = None,
|
| 109 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 110 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 111 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 112 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 113 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 114 |
+
do_convert_rgb: bool = None,
|
| 115 |
+
) -> PIL.Image.Image:
|
| 116 |
+
"""
|
| 117 |
+
Preprocess an image or batch of images.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
images (`ImageInput`):
|
| 121 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 122 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 123 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 124 |
+
Whether to resize the image.
|
| 125 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 126 |
+
Size of the image after resizing.
|
| 127 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 128 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 129 |
+
has an effect if `do_resize` is set to `True`.
|
| 130 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 131 |
+
Whether to rescale the image.
|
| 132 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 133 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 134 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 135 |
+
Whether to normalize the image.
|
| 136 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 137 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 138 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 139 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 140 |
+
`True`.
|
| 141 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 142 |
+
The type of tensors to return. Can be one of:
|
| 143 |
+
- Unset: Return a list of `np.ndarray`.
|
| 144 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 145 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 146 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 147 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 148 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 149 |
+
The channel dimension format for the output image. Can be one of:
|
| 150 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 151 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 152 |
+
- Unset: Use the channel dimension format of the input image.
|
| 153 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 154 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 155 |
+
from the input image. Can be one of:
|
| 156 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 157 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 158 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 159 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 160 |
+
Whether to convert the image to RGB.
|
| 161 |
+
"""
|
| 162 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 163 |
+
size = size if size is not None else self.size
|
| 164 |
+
size = get_size_dict(size, param_name="size", default_to_square=False)
|
| 165 |
+
resample = resample if resample is not None else self.resample
|
| 166 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 167 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 168 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 169 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 170 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 171 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 172 |
+
|
| 173 |
+
images = make_list_of_images(images)
|
| 174 |
+
|
| 175 |
+
if not valid_images(images):
|
| 176 |
+
raise ValueError(
|
| 177 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 178 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 179 |
+
)
|
| 180 |
+
validate_preprocess_arguments(
|
| 181 |
+
do_rescale=do_rescale,
|
| 182 |
+
rescale_factor=rescale_factor,
|
| 183 |
+
do_normalize=do_normalize,
|
| 184 |
+
image_mean=image_mean,
|
| 185 |
+
image_std=image_std,
|
| 186 |
+
do_resize=do_resize,
|
| 187 |
+
size=size,
|
| 188 |
+
resample=resample,
|
| 189 |
+
)
|
| 190 |
+
# All transformations expect numpy arrays.
|
| 191 |
+
images = [to_numpy_array(image) for image in images]
|
| 192 |
+
|
| 193 |
+
if do_convert_rgb:
|
| 194 |
+
images = [convert_to_rgb(image) for image in images]
|
| 195 |
+
|
| 196 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 197 |
+
logger.warning_once(
|
| 198 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 199 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
if input_data_format is None:
|
| 203 |
+
# We assume that all images have the same channel dimension format.
|
| 204 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 205 |
+
|
| 206 |
+
if do_resize:
|
| 207 |
+
height, width = size["height"], size["width"]
|
| 208 |
+
images = [
|
| 209 |
+
resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
|
| 210 |
+
for image in images
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
if do_rescale:
|
| 214 |
+
images = [
|
| 215 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 216 |
+
for image in images
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
if do_normalize:
|
| 220 |
+
images = [
|
| 221 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 222 |
+
for image in images
|
| 223 |
+
]
|
| 224 |
+
|
| 225 |
+
images = [
|
| 226 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
data = {"pixel_values": images}
|
| 230 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
modeling/siglip/modeling_siglip.py
ADDED
|
@@ -0,0 +1,1557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""PyTorch Siglip model."""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import warnings
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 16 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
| 17 |
+
|
| 18 |
+
from transformers.activations import ACT2FN
|
| 19 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 20 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
| 21 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 22 |
+
from transformers.utils import (
|
| 23 |
+
ModelOutput,
|
| 24 |
+
add_start_docstrings,
|
| 25 |
+
add_start_docstrings_to_model_forward,
|
| 26 |
+
is_flash_attn_2_available,
|
| 27 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 28 |
+
logging,
|
| 29 |
+
replace_return_docstrings,
|
| 30 |
+
torch_int,
|
| 31 |
+
)
|
| 32 |
+
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_flash_attn_2_available():
|
| 36 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
# General docstring
|
| 42 |
+
_CONFIG_FOR_DOC = "SiglipConfig"
|
| 43 |
+
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
| 47 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 48 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 49 |
+
def norm_cdf(x):
|
| 50 |
+
# Computes standard normal cumulative distribution function
|
| 51 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 52 |
+
|
| 53 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 54 |
+
warnings.warn(
|
| 55 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 56 |
+
"The distribution of values may be incorrect.",
|
| 57 |
+
stacklevel=2,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Values are generated by using a truncated uniform distribution and
|
| 61 |
+
# then using the inverse CDF for the normal distribution.
|
| 62 |
+
# Get upper and lower cdf values
|
| 63 |
+
l = norm_cdf((a - mean) / std)
|
| 64 |
+
u = norm_cdf((b - mean) / std)
|
| 65 |
+
|
| 66 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 67 |
+
# [2l-1, 2u-1].
|
| 68 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 69 |
+
|
| 70 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 71 |
+
# standard normal
|
| 72 |
+
tensor.erfinv_()
|
| 73 |
+
|
| 74 |
+
# Transform to proper mean, std
|
| 75 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 76 |
+
tensor.add_(mean)
|
| 77 |
+
|
| 78 |
+
# Clamp to ensure it's in the proper range
|
| 79 |
+
tensor.clamp_(min=a, max=b)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def trunc_normal_tf_(
|
| 83 |
+
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Fills the input Tensor with values drawn from a truncated
|
| 86 |
+
normal distribution. The values are effectively drawn from the
|
| 87 |
+
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 88 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 89 |
+
the bounds. The method used for generating the random values works
|
| 90 |
+
best when :math:`a \\leq \text{mean} \\leq b`.
|
| 91 |
+
|
| 92 |
+
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
| 93 |
+
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
| 94 |
+
and the result is subsequently scaled and shifted by the mean and std args.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 98 |
+
mean: the mean of the normal distribution
|
| 99 |
+
std: the standard deviation of the normal distribution
|
| 100 |
+
a: the minimum cutoff value
|
| 101 |
+
b: the maximum cutoff value
|
| 102 |
+
"""
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
_trunc_normal_(tensor, 0, 1.0, a, b)
|
| 105 |
+
tensor.mul_(std).add_(mean)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
| 109 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
| 110 |
+
if mode == "fan_in":
|
| 111 |
+
denom = fan_in
|
| 112 |
+
elif mode == "fan_out":
|
| 113 |
+
denom = fan_out
|
| 114 |
+
elif mode == "fan_avg":
|
| 115 |
+
denom = (fan_in + fan_out) / 2
|
| 116 |
+
|
| 117 |
+
variance = scale / denom
|
| 118 |
+
|
| 119 |
+
if distribution == "truncated_normal":
|
| 120 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
| 121 |
+
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
| 122 |
+
elif distribution == "normal":
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
tensor.normal_(std=math.sqrt(variance))
|
| 125 |
+
elif distribution == "uniform":
|
| 126 |
+
bound = math.sqrt(3 * variance)
|
| 127 |
+
with torch.no_grad():
|
| 128 |
+
tensor.uniform_(-bound, bound)
|
| 129 |
+
else:
|
| 130 |
+
raise ValueError(f"invalid distribution {distribution}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def lecun_normal_(tensor):
|
| 134 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def default_flax_embed_init(tensor):
|
| 138 |
+
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass
|
| 142 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
|
| 143 |
+
class SiglipVisionModelOutput(ModelOutput):
|
| 144 |
+
"""
|
| 145 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
| 149 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
| 150 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 151 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 152 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 153 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 154 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 155 |
+
|
| 156 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 157 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 158 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 159 |
+
sequence_length)`.
|
| 160 |
+
|
| 161 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 162 |
+
heads.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
| 166 |
+
last_hidden_state: torch.FloatTensor = None
|
| 167 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 168 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@dataclass
|
| 172 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
|
| 173 |
+
class SiglipTextModelOutput(ModelOutput):
|
| 174 |
+
"""
|
| 175 |
+
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
| 179 |
+
The text embeddings obtained by applying the projection layer to the pooler_output.
|
| 180 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 181 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 182 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 183 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 184 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 185 |
+
|
| 186 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 187 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 188 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 189 |
+
sequence_length)`.
|
| 190 |
+
|
| 191 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 192 |
+
heads.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
text_embeds: Optional[torch.FloatTensor] = None
|
| 196 |
+
last_hidden_state: torch.FloatTensor = None
|
| 197 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 198 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dataclass
|
| 202 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
|
| 203 |
+
class SiglipOutput(ModelOutput):
|
| 204 |
+
"""
|
| 205 |
+
Args:
|
| 206 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
| 207 |
+
Contrastive loss for image-text similarity.
|
| 208 |
+
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
| 209 |
+
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
| 210 |
+
similarity scores.
|
| 211 |
+
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
| 212 |
+
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
| 213 |
+
similarity scores.
|
| 214 |
+
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
| 215 |
+
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
|
| 216 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
| 217 |
+
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
|
| 218 |
+
text_model_output (`BaseModelOutputWithPooling`):
|
| 219 |
+
The output of the [`SiglipTextModel`].
|
| 220 |
+
vision_model_output (`BaseModelOutputWithPooling`):
|
| 221 |
+
The output of the [`SiglipVisionModel`].
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
loss: Optional[torch.FloatTensor] = None
|
| 225 |
+
logits_per_image: torch.FloatTensor = None
|
| 226 |
+
logits_per_text: torch.FloatTensor = None
|
| 227 |
+
text_embeds: torch.FloatTensor = None
|
| 228 |
+
image_embeds: torch.FloatTensor = None
|
| 229 |
+
text_model_output: BaseModelOutputWithPooling = None
|
| 230 |
+
vision_model_output: BaseModelOutputWithPooling = None
|
| 231 |
+
|
| 232 |
+
def to_tuple(self) -> Tuple[Any]:
|
| 233 |
+
return tuple(
|
| 234 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
| 235 |
+
for k in self.keys()
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class SiglipVisionEmbeddings(nn.Module):
|
| 240 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 241 |
+
super().__init__()
|
| 242 |
+
self.config = config
|
| 243 |
+
self.embed_dim = config.hidden_size
|
| 244 |
+
self.image_size = config.image_size
|
| 245 |
+
self.patch_size = config.patch_size
|
| 246 |
+
|
| 247 |
+
self.patch_embedding = nn.Conv2d(
|
| 248 |
+
in_channels=config.num_channels,
|
| 249 |
+
out_channels=self.embed_dim,
|
| 250 |
+
kernel_size=self.patch_size,
|
| 251 |
+
stride=self.patch_size,
|
| 252 |
+
padding="valid",
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
| 256 |
+
self.num_positions = self.num_patches
|
| 257 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
| 258 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
| 259 |
+
|
| 260 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 261 |
+
"""
|
| 262 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
| 263 |
+
images. This method is also adapted to support torch.jit tracing and no class embeddings.
|
| 264 |
+
|
| 265 |
+
Adapted from:
|
| 266 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 267 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 268 |
+
"""
|
| 269 |
+
|
| 270 |
+
num_patches = embeddings.shape[1]
|
| 271 |
+
num_positions = self.position_embedding.weight.shape[0]
|
| 272 |
+
|
| 273 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 274 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 275 |
+
return self.position_embedding(self.position_ids)
|
| 276 |
+
|
| 277 |
+
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
|
| 278 |
+
|
| 279 |
+
dim = embeddings.shape[-1]
|
| 280 |
+
|
| 281 |
+
new_height = height // self.patch_size
|
| 282 |
+
new_width = width // self.patch_size
|
| 283 |
+
|
| 284 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 285 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 286 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 287 |
+
|
| 288 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 289 |
+
patch_pos_embed,
|
| 290 |
+
size=(new_height, new_width),
|
| 291 |
+
mode="bicubic",
|
| 292 |
+
align_corners=False,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 296 |
+
return patch_pos_embed
|
| 297 |
+
|
| 298 |
+
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
| 299 |
+
_, _, height, width = pixel_values.shape
|
| 300 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
| 301 |
+
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
| 302 |
+
|
| 303 |
+
if interpolate_pos_encoding:
|
| 304 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 305 |
+
else:
|
| 306 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
| 307 |
+
return embeddings
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
|
| 311 |
+
class SiglipTextEmbeddings(nn.Module):
|
| 312 |
+
def __init__(self, config: SiglipTextConfig):
|
| 313 |
+
super().__init__()
|
| 314 |
+
embed_dim = config.hidden_size
|
| 315 |
+
|
| 316 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
| 317 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
| 318 |
+
|
| 319 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 320 |
+
self.register_buffer(
|
| 321 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
def forward(
|
| 325 |
+
self,
|
| 326 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 327 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 328 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 329 |
+
) -> torch.Tensor:
|
| 330 |
+
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
| 331 |
+
|
| 332 |
+
if position_ids is None:
|
| 333 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 334 |
+
|
| 335 |
+
if inputs_embeds is None:
|
| 336 |
+
inputs_embeds = self.token_embedding(input_ids)
|
| 337 |
+
|
| 338 |
+
position_embeddings = self.position_embedding(position_ids)
|
| 339 |
+
embeddings = inputs_embeds + position_embeddings
|
| 340 |
+
|
| 341 |
+
return embeddings
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class SiglipAttention(nn.Module):
|
| 345 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 346 |
+
|
| 347 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
| 348 |
+
def __init__(self, config):
|
| 349 |
+
super().__init__()
|
| 350 |
+
self.config = config
|
| 351 |
+
self.embed_dim = config.hidden_size
|
| 352 |
+
self.num_heads = config.num_attention_heads
|
| 353 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 354 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 355 |
+
raise ValueError(
|
| 356 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 357 |
+
f" {self.num_heads})."
|
| 358 |
+
)
|
| 359 |
+
self.scale = self.head_dim**-0.5
|
| 360 |
+
self.dropout = config.attention_dropout
|
| 361 |
+
|
| 362 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 363 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 364 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 365 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 366 |
+
|
| 367 |
+
def forward(
|
| 368 |
+
self,
|
| 369 |
+
hidden_states: torch.Tensor,
|
| 370 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 371 |
+
output_attentions: Optional[bool] = False,
|
| 372 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 373 |
+
"""Input shape: Batch x Time x Channel"""
|
| 374 |
+
|
| 375 |
+
batch_size, q_len, _ = hidden_states.size()
|
| 376 |
+
|
| 377 |
+
query_states = self.q_proj(hidden_states)
|
| 378 |
+
key_states = self.k_proj(hidden_states)
|
| 379 |
+
value_states = self.v_proj(hidden_states)
|
| 380 |
+
|
| 381 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 382 |
+
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 383 |
+
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 384 |
+
|
| 385 |
+
k_v_seq_len = key_states.shape[-2]
|
| 386 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
| 387 |
+
|
| 388 |
+
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
| 389 |
+
raise ValueError(
|
| 390 |
+
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
| 391 |
+
f" {attn_weights.size()}"
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if attention_mask is not None:
|
| 395 |
+
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
| 396 |
+
raise ValueError(
|
| 397 |
+
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
| 398 |
+
)
|
| 399 |
+
attn_weights = attn_weights + attention_mask
|
| 400 |
+
|
| 401 |
+
# upcast attention to fp32
|
| 402 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 403 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 404 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 405 |
+
|
| 406 |
+
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
|
| 407 |
+
raise ValueError(
|
| 408 |
+
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
| 409 |
+
f" {attn_output.size()}"
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 413 |
+
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
| 414 |
+
|
| 415 |
+
attn_output = self.out_proj(attn_output)
|
| 416 |
+
|
| 417 |
+
return attn_output, attn_weights
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class SiglipFlashAttention2(SiglipAttention):
|
| 421 |
+
"""
|
| 422 |
+
SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
|
| 423 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 424 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
is_causal = False
|
| 428 |
+
|
| 429 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
| 430 |
+
def __init__(self, *args, **kwargs):
|
| 431 |
+
super().__init__(*args, **kwargs)
|
| 432 |
+
|
| 433 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 434 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 435 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 436 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 437 |
+
|
| 438 |
+
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
| 439 |
+
def forward(
|
| 440 |
+
self,
|
| 441 |
+
hidden_states: torch.Tensor,
|
| 442 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 443 |
+
output_attentions: bool = False,
|
| 444 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 445 |
+
output_attentions = False
|
| 446 |
+
|
| 447 |
+
batch_size, q_len, _ = hidden_states.size()
|
| 448 |
+
|
| 449 |
+
query_states = self.q_proj(hidden_states)
|
| 450 |
+
key_states = self.k_proj(hidden_states)
|
| 451 |
+
value_states = self.v_proj(hidden_states)
|
| 452 |
+
|
| 453 |
+
# Flash attention requires the input to have the shape
|
| 454 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 455 |
+
# therefore we just need to keep the original shape
|
| 456 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 457 |
+
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 458 |
+
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 459 |
+
|
| 460 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 461 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 462 |
+
query_states = query_states.transpose(1, 2)
|
| 463 |
+
key_states = key_states.transpose(1, 2)
|
| 464 |
+
value_states = value_states.transpose(1, 2)
|
| 465 |
+
|
| 466 |
+
dropout_rate = self.dropout if self.training else 0.0
|
| 467 |
+
|
| 468 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 469 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 470 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 471 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 472 |
+
# in fp32.
|
| 473 |
+
|
| 474 |
+
input_dtype = query_states.dtype
|
| 475 |
+
if input_dtype == torch.float32:
|
| 476 |
+
if torch.is_autocast_enabled():
|
| 477 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 478 |
+
# Handle the case where the model is quantized
|
| 479 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 480 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 481 |
+
else:
|
| 482 |
+
target_dtype = self.q_proj.weight.dtype
|
| 483 |
+
|
| 484 |
+
logger.warning_once(
|
| 485 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 486 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 487 |
+
f" {target_dtype}."
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
query_states = query_states.to(target_dtype)
|
| 491 |
+
key_states = key_states.to(target_dtype)
|
| 492 |
+
value_states = value_states.to(target_dtype)
|
| 493 |
+
|
| 494 |
+
attn_output = _flash_attention_forward(
|
| 495 |
+
query_states,
|
| 496 |
+
key_states,
|
| 497 |
+
value_states,
|
| 498 |
+
attention_mask,
|
| 499 |
+
q_len,
|
| 500 |
+
dropout=dropout_rate,
|
| 501 |
+
is_causal=self.is_causal,
|
| 502 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
|
| 506 |
+
attn_output = self.out_proj(attn_output)
|
| 507 |
+
|
| 508 |
+
if not output_attentions:
|
| 509 |
+
attn_weights = None
|
| 510 |
+
|
| 511 |
+
return attn_output, attn_weights
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class SiglipSdpaAttention(SiglipAttention):
|
| 515 |
+
"""
|
| 516 |
+
Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 517 |
+
`SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 518 |
+
SDPA API.
|
| 519 |
+
"""
|
| 520 |
+
|
| 521 |
+
is_causal = False
|
| 522 |
+
|
| 523 |
+
# Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
|
| 524 |
+
def forward(
|
| 525 |
+
self,
|
| 526 |
+
hidden_states: torch.Tensor,
|
| 527 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 528 |
+
output_attentions: Optional[bool] = False,
|
| 529 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 530 |
+
if output_attentions:
|
| 531 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 532 |
+
logger.warning_once(
|
| 533 |
+
"SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 534 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 535 |
+
)
|
| 536 |
+
return super().forward(
|
| 537 |
+
hidden_states=hidden_states,
|
| 538 |
+
attention_mask=attention_mask,
|
| 539 |
+
output_attentions=output_attentions,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
batch_size, q_len, _ = hidden_states.size()
|
| 543 |
+
|
| 544 |
+
query_states = self.q_proj(hidden_states)
|
| 545 |
+
key_states = self.k_proj(hidden_states)
|
| 546 |
+
value_states = self.v_proj(hidden_states)
|
| 547 |
+
|
| 548 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 549 |
+
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 550 |
+
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 551 |
+
|
| 552 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 553 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 554 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
| 555 |
+
query_states = query_states.contiguous()
|
| 556 |
+
key_states = key_states.contiguous()
|
| 557 |
+
value_states = value_states.contiguous()
|
| 558 |
+
|
| 559 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 560 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 561 |
+
is_causal = True if self.is_causal and q_len > 1 else False
|
| 562 |
+
|
| 563 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 564 |
+
query_states,
|
| 565 |
+
key_states,
|
| 566 |
+
value_states,
|
| 567 |
+
attn_mask=attention_mask,
|
| 568 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 569 |
+
is_causal=is_causal,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 573 |
+
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
|
| 574 |
+
|
| 575 |
+
attn_output = self.out_proj(attn_output)
|
| 576 |
+
|
| 577 |
+
return attn_output, None
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
SIGLIP_ATTENTION_CLASSES = {
|
| 581 |
+
"eager": SiglipAttention,
|
| 582 |
+
"flash_attention_2": SiglipFlashAttention2,
|
| 583 |
+
"sdpa": SiglipSdpaAttention,
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
|
| 588 |
+
class SiglipMLP(nn.Module):
|
| 589 |
+
def __init__(self, config):
|
| 590 |
+
super().__init__()
|
| 591 |
+
self.config = config
|
| 592 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 593 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 594 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 595 |
+
|
| 596 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 597 |
+
hidden_states = self.fc1(hidden_states)
|
| 598 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 599 |
+
hidden_states = self.fc2(hidden_states)
|
| 600 |
+
return hidden_states
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class SiglipEncoderLayer(nn.Module):
|
| 604 |
+
def __init__(self, config: SiglipConfig):
|
| 605 |
+
super().__init__()
|
| 606 |
+
self.embed_dim = config.hidden_size
|
| 607 |
+
self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
|
| 608 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 609 |
+
self.mlp = SiglipMLP(config)
|
| 610 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 611 |
+
|
| 612 |
+
# Ignore copy
|
| 613 |
+
def forward(
|
| 614 |
+
self,
|
| 615 |
+
hidden_states: torch.Tensor,
|
| 616 |
+
attention_mask: torch.Tensor,
|
| 617 |
+
output_attentions: Optional[bool] = False,
|
| 618 |
+
) -> Tuple[torch.FloatTensor]:
|
| 619 |
+
"""
|
| 620 |
+
Args:
|
| 621 |
+
hidden_states (`torch.FloatTensor`):
|
| 622 |
+
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
| 623 |
+
attention_mask (`torch.FloatTensor`):
|
| 624 |
+
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
| 625 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
| 626 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 627 |
+
returned tensors for more detail.
|
| 628 |
+
"""
|
| 629 |
+
residual = hidden_states
|
| 630 |
+
|
| 631 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 632 |
+
hidden_states, attn_weights = self.self_attn(
|
| 633 |
+
hidden_states=hidden_states,
|
| 634 |
+
attention_mask=attention_mask,
|
| 635 |
+
output_attentions=output_attentions,
|
| 636 |
+
)
|
| 637 |
+
hidden_states = residual + hidden_states
|
| 638 |
+
|
| 639 |
+
residual = hidden_states
|
| 640 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 641 |
+
hidden_states = self.mlp(hidden_states)
|
| 642 |
+
hidden_states = residual + hidden_states
|
| 643 |
+
|
| 644 |
+
outputs = (hidden_states,)
|
| 645 |
+
|
| 646 |
+
if output_attentions:
|
| 647 |
+
outputs += (attn_weights,)
|
| 648 |
+
|
| 649 |
+
return outputs
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class SiglipPreTrainedModel(PreTrainedModel):
|
| 653 |
+
"""
|
| 654 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 655 |
+
models.
|
| 656 |
+
"""
|
| 657 |
+
|
| 658 |
+
config_class = SiglipConfig
|
| 659 |
+
base_model_prefix = "siglip"
|
| 660 |
+
supports_gradient_checkpointing = True
|
| 661 |
+
|
| 662 |
+
_no_split_modules = [
|
| 663 |
+
"SiglipTextEmbeddings",
|
| 664 |
+
"SiglipEncoderLayer",
|
| 665 |
+
"SiglipVisionEmbeddings",
|
| 666 |
+
"SiglipEncoderLayer",
|
| 667 |
+
"SiglipMultiheadAttentionPoolingHead",
|
| 668 |
+
]
|
| 669 |
+
_supports_flash_attn_2 = True
|
| 670 |
+
_supports_sdpa = True
|
| 671 |
+
|
| 672 |
+
def _init_weights(self, module):
|
| 673 |
+
"""Initialize the weights"""
|
| 674 |
+
if isinstance(module, SiglipVisionEmbeddings):
|
| 675 |
+
width = (
|
| 676 |
+
self.config.vision_config.hidden_size
|
| 677 |
+
if isinstance(self.config, SiglipConfig)
|
| 678 |
+
else self.config.hidden_size
|
| 679 |
+
)
|
| 680 |
+
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
|
| 681 |
+
elif isinstance(module, nn.Embedding):
|
| 682 |
+
default_flax_embed_init(module.weight)
|
| 683 |
+
elif isinstance(module, SiglipAttention):
|
| 684 |
+
nn.init.xavier_uniform_(module.q_proj.weight)
|
| 685 |
+
nn.init.xavier_uniform_(module.k_proj.weight)
|
| 686 |
+
nn.init.xavier_uniform_(module.v_proj.weight)
|
| 687 |
+
nn.init.xavier_uniform_(module.out_proj.weight)
|
| 688 |
+
nn.init.zeros_(module.q_proj.bias)
|
| 689 |
+
nn.init.zeros_(module.k_proj.bias)
|
| 690 |
+
nn.init.zeros_(module.v_proj.bias)
|
| 691 |
+
nn.init.zeros_(module.out_proj.bias)
|
| 692 |
+
elif isinstance(module, SiglipMLP):
|
| 693 |
+
nn.init.xavier_uniform_(module.fc1.weight)
|
| 694 |
+
nn.init.xavier_uniform_(module.fc2.weight)
|
| 695 |
+
nn.init.normal_(module.fc1.bias, std=1e-6)
|
| 696 |
+
nn.init.normal_(module.fc2.bias, std=1e-6)
|
| 697 |
+
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
|
| 698 |
+
nn.init.xavier_uniform_(module.probe.data)
|
| 699 |
+
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
|
| 700 |
+
nn.init.zeros_(module.attention.in_proj_bias.data)
|
| 701 |
+
elif isinstance(module, SiglipModel):
|
| 702 |
+
logit_scale_init = torch.log(torch.tensor(1.0))
|
| 703 |
+
module.logit_scale.data.fill_(logit_scale_init)
|
| 704 |
+
module.logit_bias.data.zero_()
|
| 705 |
+
elif isinstance(module, SiglipForImageClassification):
|
| 706 |
+
nn.init.normal_(
|
| 707 |
+
module.classifier.weight,
|
| 708 |
+
std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
|
| 709 |
+
)
|
| 710 |
+
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 711 |
+
lecun_normal_(module.weight)
|
| 712 |
+
if module.bias is not None:
|
| 713 |
+
nn.init.zeros_(module.bias)
|
| 714 |
+
elif isinstance(module, nn.LayerNorm):
|
| 715 |
+
module.bias.data.zero_()
|
| 716 |
+
module.weight.data.fill_(1.0)
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
SIGLIP_START_DOCSTRING = r"""
|
| 720 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 721 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 722 |
+
etc.)
|
| 723 |
+
|
| 724 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 725 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 726 |
+
and behavior.
|
| 727 |
+
|
| 728 |
+
Parameters:
|
| 729 |
+
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
|
| 730 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 731 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 732 |
+
"""
|
| 733 |
+
|
| 734 |
+
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
|
| 735 |
+
Args:
|
| 736 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 737 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 738 |
+
it.
|
| 739 |
+
|
| 740 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 741 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 742 |
+
|
| 743 |
+
[What are input IDs?](../glossary#input-ids)
|
| 744 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 745 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 746 |
+
|
| 747 |
+
- 1 for tokens that are **not masked**,
|
| 748 |
+
- 0 for tokens that are **masked**.
|
| 749 |
+
|
| 750 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 751 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 752 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 753 |
+
config.max_position_embeddings - 1]`.
|
| 754 |
+
|
| 755 |
+
[What are position IDs?](../glossary#position-ids)
|
| 756 |
+
output_attentions (`bool`, *optional*):
|
| 757 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 758 |
+
tensors for more detail.
|
| 759 |
+
output_hidden_states (`bool`, *optional*):
|
| 760 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 761 |
+
more detail.
|
| 762 |
+
return_dict (`bool`, *optional*):
|
| 763 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 764 |
+
"""
|
| 765 |
+
|
| 766 |
+
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
|
| 767 |
+
Args:
|
| 768 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 769 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
| 770 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
| 771 |
+
output_attentions (`bool`, *optional*):
|
| 772 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 773 |
+
tensors for more detail.
|
| 774 |
+
output_hidden_states (`bool`, *optional*):
|
| 775 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 776 |
+
more detail.
|
| 777 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 778 |
+
Whether to interpolate the pre-trained position encodings.
|
| 779 |
+
return_dict (`bool`, *optional*):
|
| 780 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 781 |
+
"""
|
| 782 |
+
|
| 783 |
+
SIGLIP_INPUTS_DOCSTRING = r"""
|
| 784 |
+
Args:
|
| 785 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 786 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 787 |
+
it.
|
| 788 |
+
|
| 789 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 790 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 791 |
+
|
| 792 |
+
[What are input IDs?](../glossary#input-ids)
|
| 793 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 794 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 795 |
+
|
| 796 |
+
- 1 for tokens that are **not masked**,
|
| 797 |
+
- 0 for tokens that are **masked**.
|
| 798 |
+
|
| 799 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 800 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 801 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 802 |
+
config.max_position_embeddings - 1]`.
|
| 803 |
+
|
| 804 |
+
[What are position IDs?](../glossary#position-ids)
|
| 805 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 806 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
| 807 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
| 808 |
+
return_loss (`bool`, *optional*):
|
| 809 |
+
Whether or not to return the contrastive loss.
|
| 810 |
+
output_attentions (`bool`, *optional*):
|
| 811 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 812 |
+
tensors for more detail.
|
| 813 |
+
output_hidden_states (`bool`, *optional*):
|
| 814 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 815 |
+
more detail.
|
| 816 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 817 |
+
Whether to interpolate the pre-trained position encodings.
|
| 818 |
+
return_dict (`bool`, *optional*):
|
| 819 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 820 |
+
"""
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
|
| 824 |
+
class SiglipEncoder(nn.Module):
|
| 825 |
+
"""
|
| 826 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
| 827 |
+
[`SiglipEncoderLayer`].
|
| 828 |
+
|
| 829 |
+
Args:
|
| 830 |
+
config: SiglipConfig
|
| 831 |
+
"""
|
| 832 |
+
|
| 833 |
+
def __init__(self, config: SiglipConfig):
|
| 834 |
+
super().__init__()
|
| 835 |
+
self.config = config
|
| 836 |
+
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 837 |
+
self.gradient_checkpointing = False
|
| 838 |
+
|
| 839 |
+
# Ignore copy
|
| 840 |
+
def forward(
|
| 841 |
+
self,
|
| 842 |
+
inputs_embeds,
|
| 843 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 844 |
+
output_attentions: Optional[bool] = None,
|
| 845 |
+
output_hidden_states: Optional[bool] = None,
|
| 846 |
+
return_dict: Optional[bool] = None,
|
| 847 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 848 |
+
r"""
|
| 849 |
+
Args:
|
| 850 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 851 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
| 852 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
| 853 |
+
than the model's internal embedding lookup matrix.
|
| 854 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 855 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 856 |
+
|
| 857 |
+
- 1 for tokens that are **not masked**,
|
| 858 |
+
- 0 for tokens that are **masked**.
|
| 859 |
+
|
| 860 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 861 |
+
output_attentions (`bool`, *optional*):
|
| 862 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 863 |
+
returned tensors for more detail.
|
| 864 |
+
output_hidden_states (`bool`, *optional*):
|
| 865 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 866 |
+
for more detail.
|
| 867 |
+
return_dict (`bool`, *optional*):
|
| 868 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 869 |
+
"""
|
| 870 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 871 |
+
output_hidden_states = (
|
| 872 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 873 |
+
)
|
| 874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 875 |
+
|
| 876 |
+
encoder_states = () if output_hidden_states else None
|
| 877 |
+
all_attentions = () if output_attentions else None
|
| 878 |
+
|
| 879 |
+
hidden_states = inputs_embeds
|
| 880 |
+
for encoder_layer in self.layers:
|
| 881 |
+
if output_hidden_states:
|
| 882 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 883 |
+
if self.gradient_checkpointing and self.training:
|
| 884 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 885 |
+
encoder_layer.__call__,
|
| 886 |
+
hidden_states,
|
| 887 |
+
attention_mask,
|
| 888 |
+
output_attentions,
|
| 889 |
+
)
|
| 890 |
+
else:
|
| 891 |
+
layer_outputs = encoder_layer(
|
| 892 |
+
hidden_states,
|
| 893 |
+
attention_mask,
|
| 894 |
+
output_attentions=output_attentions,
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
hidden_states = layer_outputs[0]
|
| 898 |
+
|
| 899 |
+
if output_attentions:
|
| 900 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 901 |
+
|
| 902 |
+
if output_hidden_states:
|
| 903 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 904 |
+
|
| 905 |
+
if not return_dict:
|
| 906 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
| 907 |
+
return BaseModelOutput(
|
| 908 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class SiglipTextTransformer(nn.Module):
|
| 913 |
+
def __init__(self, config: SiglipTextConfig):
|
| 914 |
+
super().__init__()
|
| 915 |
+
self.config = config
|
| 916 |
+
embed_dim = config.hidden_size
|
| 917 |
+
self.embeddings = SiglipTextEmbeddings(config)
|
| 918 |
+
self.encoder = SiglipEncoder(config)
|
| 919 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 920 |
+
|
| 921 |
+
self.head = nn.Linear(embed_dim, embed_dim)
|
| 922 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 923 |
+
|
| 924 |
+
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
|
| 925 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
|
| 926 |
+
def forward(
|
| 927 |
+
self,
|
| 928 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 929 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 930 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 931 |
+
output_attentions: Optional[bool] = None,
|
| 932 |
+
output_hidden_states: Optional[bool] = None,
|
| 933 |
+
return_dict: Optional[bool] = None,
|
| 934 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 935 |
+
r"""
|
| 936 |
+
Returns:
|
| 937 |
+
|
| 938 |
+
"""
|
| 939 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 940 |
+
output_hidden_states = (
|
| 941 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 942 |
+
)
|
| 943 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 944 |
+
|
| 945 |
+
if input_ids is None:
|
| 946 |
+
raise ValueError("You have to specify input_ids")
|
| 947 |
+
|
| 948 |
+
input_shape = input_ids.size()
|
| 949 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 950 |
+
|
| 951 |
+
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
| 952 |
+
|
| 953 |
+
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
|
| 954 |
+
# expand attention_mask
|
| 955 |
+
if attention_mask is not None and not self._use_flash_attention_2:
|
| 956 |
+
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
|
| 957 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
|
| 958 |
+
|
| 959 |
+
encoder_outputs = self.encoder(
|
| 960 |
+
inputs_embeds=hidden_states,
|
| 961 |
+
attention_mask=attention_mask,
|
| 962 |
+
output_attentions=output_attentions,
|
| 963 |
+
output_hidden_states=output_hidden_states,
|
| 964 |
+
return_dict=return_dict,
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
last_hidden_state = encoder_outputs[0]
|
| 968 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
| 969 |
+
|
| 970 |
+
# Assuming "sticky" EOS tokenization, last token is always EOS.
|
| 971 |
+
pooled_output = last_hidden_state[:, -1, :]
|
| 972 |
+
pooled_output = self.head(pooled_output)
|
| 973 |
+
|
| 974 |
+
if not return_dict:
|
| 975 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
| 976 |
+
|
| 977 |
+
return BaseModelOutputWithPooling(
|
| 978 |
+
last_hidden_state=last_hidden_state,
|
| 979 |
+
pooler_output=pooled_output,
|
| 980 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 981 |
+
attentions=encoder_outputs.attentions,
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
@add_start_docstrings(
|
| 986 |
+
"""The text model from SigLIP without any head or projection on top.""",
|
| 987 |
+
SIGLIP_START_DOCSTRING,
|
| 988 |
+
)
|
| 989 |
+
class SiglipTextModel(SiglipPreTrainedModel):
|
| 990 |
+
config_class = SiglipTextConfig
|
| 991 |
+
|
| 992 |
+
def __init__(self, config: SiglipTextConfig):
|
| 993 |
+
super().__init__(config)
|
| 994 |
+
self.text_model = SiglipTextTransformer(config)
|
| 995 |
+
# Initialize weights and apply final processing
|
| 996 |
+
self.post_init()
|
| 997 |
+
|
| 998 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 999 |
+
return self.text_model.embeddings.token_embedding
|
| 1000 |
+
|
| 1001 |
+
def set_input_embeddings(self, value):
|
| 1002 |
+
self.text_model.embeddings.token_embedding = value
|
| 1003 |
+
|
| 1004 |
+
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
|
| 1005 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
|
| 1006 |
+
def forward(
|
| 1007 |
+
self,
|
| 1008 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1009 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1010 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1011 |
+
output_attentions: Optional[bool] = None,
|
| 1012 |
+
output_hidden_states: Optional[bool] = None,
|
| 1013 |
+
return_dict: Optional[bool] = None,
|
| 1014 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 1015 |
+
r"""
|
| 1016 |
+
Returns:
|
| 1017 |
+
|
| 1018 |
+
Examples:
|
| 1019 |
+
|
| 1020 |
+
```python
|
| 1021 |
+
>>> from transformers import AutoTokenizer, SiglipTextModel
|
| 1022 |
+
|
| 1023 |
+
>>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
|
| 1024 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
|
| 1025 |
+
|
| 1026 |
+
>>> # important: make sure to set padding="max_length" as that's how the model was trained
|
| 1027 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
|
| 1028 |
+
|
| 1029 |
+
>>> outputs = model(**inputs)
|
| 1030 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
| 1031 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
| 1032 |
+
```"""
|
| 1033 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1034 |
+
|
| 1035 |
+
return self.text_model(
|
| 1036 |
+
input_ids=input_ids,
|
| 1037 |
+
attention_mask=attention_mask,
|
| 1038 |
+
position_ids=position_ids,
|
| 1039 |
+
output_attentions=output_attentions,
|
| 1040 |
+
output_hidden_states=output_hidden_states,
|
| 1041 |
+
return_dict=return_dict,
|
| 1042 |
+
)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
class SiglipVisionTransformer(nn.Module):
|
| 1046 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 1047 |
+
super().__init__()
|
| 1048 |
+
self.config = config
|
| 1049 |
+
embed_dim = config.hidden_size
|
| 1050 |
+
|
| 1051 |
+
self.embeddings = SiglipVisionEmbeddings(config)
|
| 1052 |
+
self.encoder = SiglipEncoder(config)
|
| 1053 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 1054 |
+
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
|
| 1055 |
+
if self.use_head:
|
| 1056 |
+
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
| 1057 |
+
|
| 1058 |
+
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1059 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
|
| 1060 |
+
def forward(
|
| 1061 |
+
self,
|
| 1062 |
+
pixel_values,
|
| 1063 |
+
output_attentions: Optional[bool] = None,
|
| 1064 |
+
output_hidden_states: Optional[bool] = None,
|
| 1065 |
+
return_dict: Optional[bool] = None,
|
| 1066 |
+
interpolate_pos_encoding: Optional[bool] = False,
|
| 1067 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 1068 |
+
r"""
|
| 1069 |
+
Returns:
|
| 1070 |
+
|
| 1071 |
+
"""
|
| 1072 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1073 |
+
output_hidden_states = (
|
| 1074 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1075 |
+
)
|
| 1076 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1077 |
+
|
| 1078 |
+
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 1079 |
+
|
| 1080 |
+
encoder_outputs = self.encoder(
|
| 1081 |
+
inputs_embeds=hidden_states,
|
| 1082 |
+
output_attentions=output_attentions,
|
| 1083 |
+
output_hidden_states=output_hidden_states,
|
| 1084 |
+
return_dict=return_dict,
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
last_hidden_state = encoder_outputs[0]
|
| 1088 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
| 1089 |
+
|
| 1090 |
+
pooler_output = self.head(last_hidden_state) if self.use_head else None
|
| 1091 |
+
if not return_dict:
|
| 1092 |
+
return (last_hidden_state, pooler_output) + encoder_outputs[1:]
|
| 1093 |
+
|
| 1094 |
+
return BaseModelOutputWithPooling(
|
| 1095 |
+
last_hidden_state=last_hidden_state,
|
| 1096 |
+
pooler_output=pooler_output,
|
| 1097 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1098 |
+
attentions=encoder_outputs.attentions,
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
| 1103 |
+
"""Multihead Attention Pooling."""
|
| 1104 |
+
|
| 1105 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 1106 |
+
super().__init__()
|
| 1107 |
+
|
| 1108 |
+
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
| 1109 |
+
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
|
| 1110 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 1111 |
+
self.mlp = SiglipMLP(config)
|
| 1112 |
+
|
| 1113 |
+
def forward(self, hidden_state):
|
| 1114 |
+
batch_size = hidden_state.shape[0]
|
| 1115 |
+
probe = self.probe.repeat(batch_size, 1, 1)
|
| 1116 |
+
|
| 1117 |
+
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
| 1118 |
+
|
| 1119 |
+
residual = hidden_state
|
| 1120 |
+
hidden_state = self.layernorm(hidden_state)
|
| 1121 |
+
hidden_state = residual + self.mlp(hidden_state)
|
| 1122 |
+
|
| 1123 |
+
return hidden_state[:, 0]
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
@add_start_docstrings(
|
| 1127 |
+
"""The vision model from SigLIP without any head or projection on top.""",
|
| 1128 |
+
SIGLIP_START_DOCSTRING,
|
| 1129 |
+
)
|
| 1130 |
+
class SiglipVisionModel(SiglipPreTrainedModel):
|
| 1131 |
+
config_class = SiglipVisionConfig
|
| 1132 |
+
main_input_name = "pixel_values"
|
| 1133 |
+
|
| 1134 |
+
def __init__(self, config: SiglipVisionConfig):
|
| 1135 |
+
super().__init__(config)
|
| 1136 |
+
|
| 1137 |
+
self.vision_model = SiglipVisionTransformer(config)
|
| 1138 |
+
|
| 1139 |
+
# Initialize weights and apply final processing
|
| 1140 |
+
self.post_init()
|
| 1141 |
+
|
| 1142 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 1143 |
+
return self.vision_model.embeddings.patch_embedding
|
| 1144 |
+
|
| 1145 |
+
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1146 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
|
| 1147 |
+
def forward(
|
| 1148 |
+
self,
|
| 1149 |
+
pixel_values,
|
| 1150 |
+
output_attentions: Optional[bool] = None,
|
| 1151 |
+
output_hidden_states: Optional[bool] = None,
|
| 1152 |
+
return_dict: Optional[bool] = None,
|
| 1153 |
+
interpolate_pos_encoding: bool = False,
|
| 1154 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 1155 |
+
r"""
|
| 1156 |
+
Returns:
|
| 1157 |
+
|
| 1158 |
+
Examples:
|
| 1159 |
+
|
| 1160 |
+
```python
|
| 1161 |
+
>>> from PIL import Image
|
| 1162 |
+
>>> import requests
|
| 1163 |
+
>>> from transformers import AutoProcessor, SiglipVisionModel
|
| 1164 |
+
|
| 1165 |
+
>>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
|
| 1166 |
+
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 1167 |
+
|
| 1168 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1169 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1170 |
+
|
| 1171 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 1172 |
+
|
| 1173 |
+
>>> outputs = model(**inputs)
|
| 1174 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
| 1175 |
+
>>> pooled_output = outputs.pooler_output # pooled features
|
| 1176 |
+
```"""
|
| 1177 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1178 |
+
|
| 1179 |
+
return self.vision_model(
|
| 1180 |
+
pixel_values=pixel_values,
|
| 1181 |
+
output_attentions=output_attentions,
|
| 1182 |
+
output_hidden_states=output_hidden_states,
|
| 1183 |
+
return_dict=return_dict,
|
| 1184 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1185 |
+
)
|
| 1186 |
+
|
| 1187 |
+
|
| 1188 |
+
@add_start_docstrings(SIGLIP_START_DOCSTRING)
|
| 1189 |
+
class SiglipModel(SiglipPreTrainedModel):
|
| 1190 |
+
config_class = SiglipConfig
|
| 1191 |
+
|
| 1192 |
+
def __init__(self, config: SiglipConfig):
|
| 1193 |
+
super().__init__(config)
|
| 1194 |
+
|
| 1195 |
+
if not isinstance(config.text_config, SiglipTextConfig):
|
| 1196 |
+
raise TypeError(
|
| 1197 |
+
"config.text_config is expected to be of type SiglipTextConfig but is of type"
|
| 1198 |
+
f" {type(config.text_config)}."
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
if not isinstance(config.vision_config, SiglipVisionConfig):
|
| 1202 |
+
raise TypeError(
|
| 1203 |
+
"config.vision_config is expected to be of type SiglipVisionConfig but is of type"
|
| 1204 |
+
f" {type(config.vision_config)}."
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
text_config = config.text_config
|
| 1208 |
+
vision_config = config.vision_config
|
| 1209 |
+
|
| 1210 |
+
# First, initialize the text and vision models with proper attention implementation
|
| 1211 |
+
text_model = SiglipTextModel._from_config(text_config)
|
| 1212 |
+
vision_model = SiglipVisionModel._from_config(vision_config)
|
| 1213 |
+
|
| 1214 |
+
# Second, get the text and vision submodules (for backward compatibility)
|
| 1215 |
+
self.text_model = text_model.text_model
|
| 1216 |
+
self.vision_model = vision_model.vision_model
|
| 1217 |
+
|
| 1218 |
+
self.logit_scale = nn.Parameter(torch.randn(1))
|
| 1219 |
+
self.logit_bias = nn.Parameter(torch.randn(1))
|
| 1220 |
+
|
| 1221 |
+
# Initialize weights and apply final processing
|
| 1222 |
+
self.post_init()
|
| 1223 |
+
|
| 1224 |
+
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
|
| 1225 |
+
def get_text_features(
|
| 1226 |
+
self,
|
| 1227 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1228 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1229 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1230 |
+
output_attentions: Optional[bool] = None,
|
| 1231 |
+
output_hidden_states: Optional[bool] = None,
|
| 1232 |
+
return_dict: Optional[bool] = None,
|
| 1233 |
+
) -> torch.FloatTensor:
|
| 1234 |
+
r"""
|
| 1235 |
+
Returns:
|
| 1236 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
| 1237 |
+
applying the projection layer to the pooled output of [`SiglipTextModel`].
|
| 1238 |
+
|
| 1239 |
+
Examples:
|
| 1240 |
+
|
| 1241 |
+
```python
|
| 1242 |
+
>>> from transformers import AutoTokenizer, AutoModel
|
| 1243 |
+
>>> import torch
|
| 1244 |
+
|
| 1245 |
+
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
|
| 1246 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
|
| 1247 |
+
|
| 1248 |
+
>>> # important: make sure to set padding="max_length" as that's how the model was trained
|
| 1249 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
|
| 1250 |
+
>>> with torch.no_grad():
|
| 1251 |
+
... text_features = model.get_text_features(**inputs)
|
| 1252 |
+
```"""
|
| 1253 |
+
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
|
| 1254 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1255 |
+
output_hidden_states = (
|
| 1256 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1257 |
+
)
|
| 1258 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1259 |
+
|
| 1260 |
+
text_outputs = self.text_model(
|
| 1261 |
+
input_ids=input_ids,
|
| 1262 |
+
attention_mask=attention_mask,
|
| 1263 |
+
position_ids=position_ids,
|
| 1264 |
+
output_attentions=output_attentions,
|
| 1265 |
+
output_hidden_states=output_hidden_states,
|
| 1266 |
+
return_dict=return_dict,
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
pooled_output = text_outputs[1]
|
| 1270 |
+
|
| 1271 |
+
return pooled_output
|
| 1272 |
+
|
| 1273 |
+
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
|
| 1274 |
+
def get_image_features(
|
| 1275 |
+
self,
|
| 1276 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1277 |
+
output_attentions: Optional[bool] = None,
|
| 1278 |
+
output_hidden_states: Optional[bool] = None,
|
| 1279 |
+
return_dict: Optional[bool] = None,
|
| 1280 |
+
interpolate_pos_encoding: bool = False,
|
| 1281 |
+
) -> torch.FloatTensor:
|
| 1282 |
+
r"""
|
| 1283 |
+
Returns:
|
| 1284 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
| 1285 |
+
applying the projection layer to the pooled output of [`SiglipVisionModel`].
|
| 1286 |
+
|
| 1287 |
+
Examples:
|
| 1288 |
+
|
| 1289 |
+
```python
|
| 1290 |
+
>>> from PIL import Image
|
| 1291 |
+
>>> import requests
|
| 1292 |
+
>>> from transformers import AutoProcessor, AutoModel
|
| 1293 |
+
>>> import torch
|
| 1294 |
+
|
| 1295 |
+
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
|
| 1296 |
+
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 1297 |
+
|
| 1298 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1299 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1300 |
+
|
| 1301 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
| 1302 |
+
|
| 1303 |
+
>>> with torch.no_grad():
|
| 1304 |
+
... image_features = model.get_image_features(**inputs)
|
| 1305 |
+
```"""
|
| 1306 |
+
# Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
|
| 1307 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1308 |
+
output_hidden_states = (
|
| 1309 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1310 |
+
)
|
| 1311 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1312 |
+
|
| 1313 |
+
vision_outputs = self.vision_model(
|
| 1314 |
+
pixel_values=pixel_values,
|
| 1315 |
+
output_attentions=output_attentions,
|
| 1316 |
+
output_hidden_states=output_hidden_states,
|
| 1317 |
+
return_dict=return_dict,
|
| 1318 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
pooled_output = vision_outputs[1]
|
| 1322 |
+
|
| 1323 |
+
return pooled_output
|
| 1324 |
+
|
| 1325 |
+
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
|
| 1326 |
+
@replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
|
| 1327 |
+
def forward(
|
| 1328 |
+
self,
|
| 1329 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1330 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1331 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1332 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1333 |
+
return_loss: Optional[bool] = None,
|
| 1334 |
+
output_attentions: Optional[bool] = None,
|
| 1335 |
+
output_hidden_states: Optional[bool] = None,
|
| 1336 |
+
return_dict: Optional[bool] = None,
|
| 1337 |
+
interpolate_pos_encoding: bool = False,
|
| 1338 |
+
) -> Union[Tuple, SiglipOutput]:
|
| 1339 |
+
r"""
|
| 1340 |
+
Returns:
|
| 1341 |
+
|
| 1342 |
+
Examples:
|
| 1343 |
+
|
| 1344 |
+
```python
|
| 1345 |
+
>>> from PIL import Image
|
| 1346 |
+
>>> import requests
|
| 1347 |
+
>>> from transformers import AutoProcessor, AutoModel
|
| 1348 |
+
>>> import torch
|
| 1349 |
+
|
| 1350 |
+
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
|
| 1351 |
+
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 1352 |
+
|
| 1353 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1354 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1355 |
+
|
| 1356 |
+
>>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
|
| 1357 |
+
>>> # important: we pass `padding=max_length` since the model was trained with this
|
| 1358 |
+
>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
|
| 1359 |
+
|
| 1360 |
+
>>> with torch.no_grad():
|
| 1361 |
+
... outputs = model(**inputs)
|
| 1362 |
+
|
| 1363 |
+
>>> logits_per_image = outputs.logits_per_image
|
| 1364 |
+
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
| 1365 |
+
>>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
|
| 1366 |
+
31.9% that image 0 is 'a photo of 2 cats'
|
| 1367 |
+
```"""
|
| 1368 |
+
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
|
| 1369 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1370 |
+
output_hidden_states = (
|
| 1371 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1372 |
+
)
|
| 1373 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1374 |
+
|
| 1375 |
+
vision_outputs = self.vision_model(
|
| 1376 |
+
pixel_values=pixel_values,
|
| 1377 |
+
output_attentions=output_attentions,
|
| 1378 |
+
output_hidden_states=output_hidden_states,
|
| 1379 |
+
return_dict=return_dict,
|
| 1380 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1381 |
+
)
|
| 1382 |
+
|
| 1383 |
+
text_outputs = self.text_model(
|
| 1384 |
+
input_ids=input_ids,
|
| 1385 |
+
attention_mask=attention_mask,
|
| 1386 |
+
position_ids=position_ids,
|
| 1387 |
+
output_attentions=output_attentions,
|
| 1388 |
+
output_hidden_states=output_hidden_states,
|
| 1389 |
+
return_dict=return_dict,
|
| 1390 |
+
)
|
| 1391 |
+
|
| 1392 |
+
image_embeds = vision_outputs[1]
|
| 1393 |
+
text_embeds = text_outputs[1]
|
| 1394 |
+
|
| 1395 |
+
# normalized features
|
| 1396 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 1397 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 1398 |
+
|
| 1399 |
+
# cosine similarity as logits
|
| 1400 |
+
logits_per_text = (
|
| 1401 |
+
torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
|
| 1402 |
+
+ self.logit_bias
|
| 1403 |
+
)
|
| 1404 |
+
logits_per_image = logits_per_text.t()
|
| 1405 |
+
|
| 1406 |
+
loss = None
|
| 1407 |
+
if return_loss:
|
| 1408 |
+
# Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
|
| 1409 |
+
eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
|
| 1410 |
+
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
|
| 1411 |
+
loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
|
| 1412 |
+
nll = -torch.sum(loglik, dim=-1)
|
| 1413 |
+
loss = nll.mean()
|
| 1414 |
+
|
| 1415 |
+
if not return_dict:
|
| 1416 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
| 1417 |
+
return ((loss,) + output) if loss is not None else output
|
| 1418 |
+
|
| 1419 |
+
return SiglipOutput(
|
| 1420 |
+
loss=loss,
|
| 1421 |
+
logits_per_image=logits_per_image,
|
| 1422 |
+
logits_per_text=logits_per_text,
|
| 1423 |
+
text_embeds=text_embeds,
|
| 1424 |
+
image_embeds=image_embeds,
|
| 1425 |
+
text_model_output=text_outputs,
|
| 1426 |
+
vision_model_output=vision_outputs,
|
| 1427 |
+
)
|
| 1428 |
+
|
| 1429 |
+
|
| 1430 |
+
@add_start_docstrings(
|
| 1431 |
+
"""
|
| 1432 |
+
SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
|
| 1433 |
+
the patch tokens) e.g. for ImageNet.
|
| 1434 |
+
""",
|
| 1435 |
+
SIGLIP_START_DOCSTRING,
|
| 1436 |
+
)
|
| 1437 |
+
class SiglipForImageClassification(SiglipPreTrainedModel):
|
| 1438 |
+
main_input_name = "pixel_values"
|
| 1439 |
+
|
| 1440 |
+
def __init__(self, config: SiglipConfig) -> None:
|
| 1441 |
+
super().__init__(config)
|
| 1442 |
+
|
| 1443 |
+
self.num_labels = config.num_labels
|
| 1444 |
+
|
| 1445 |
+
# Create the vision model with proper attention
|
| 1446 |
+
# and take only vision_model submodule (for backward compatibility)
|
| 1447 |
+
vision_model = SiglipVisionModel._from_config(config.vision_config)
|
| 1448 |
+
self.vision_model = vision_model.vision_model
|
| 1449 |
+
|
| 1450 |
+
# Classifier head
|
| 1451 |
+
self.classifier = (
|
| 1452 |
+
nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
| 1453 |
+
)
|
| 1454 |
+
|
| 1455 |
+
# Initialize weights and apply final processing
|
| 1456 |
+
self.post_init()
|
| 1457 |
+
|
| 1458 |
+
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
|
| 1459 |
+
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
| 1460 |
+
def forward(
|
| 1461 |
+
self,
|
| 1462 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1463 |
+
labels: Optional[torch.Tensor] = None,
|
| 1464 |
+
output_attentions: Optional[bool] = None,
|
| 1465 |
+
output_hidden_states: Optional[bool] = None,
|
| 1466 |
+
return_dict: Optional[bool] = None,
|
| 1467 |
+
interpolate_pos_encoding: bool = False,
|
| 1468 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
| 1469 |
+
r"""
|
| 1470 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1471 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 1472 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1473 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1474 |
+
|
| 1475 |
+
Returns:
|
| 1476 |
+
|
| 1477 |
+
Examples:
|
| 1478 |
+
|
| 1479 |
+
```python
|
| 1480 |
+
>>> from transformers import AutoImageProcessor, SiglipForImageClassification
|
| 1481 |
+
>>> import torch
|
| 1482 |
+
>>> from PIL import Image
|
| 1483 |
+
>>> import requests
|
| 1484 |
+
|
| 1485 |
+
>>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
|
| 1486 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1487 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1488 |
+
|
| 1489 |
+
>>> # note: we are loading a `SiglipModel` from the hub here,
|
| 1490 |
+
>>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
|
| 1491 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
|
| 1492 |
+
>>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
|
| 1493 |
+
|
| 1494 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1495 |
+
>>> outputs = model(**inputs)
|
| 1496 |
+
>>> logits = outputs.logits
|
| 1497 |
+
>>> # model predicts one of the two classes
|
| 1498 |
+
>>> predicted_class_idx = logits.argmax(-1).item()
|
| 1499 |
+
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
| 1500 |
+
Predicted class: LABEL_1
|
| 1501 |
+
```"""
|
| 1502 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1503 |
+
output_hidden_states = (
|
| 1504 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1505 |
+
)
|
| 1506 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1507 |
+
|
| 1508 |
+
outputs = self.vision_model(
|
| 1509 |
+
pixel_values,
|
| 1510 |
+
output_attentions=output_attentions,
|
| 1511 |
+
output_hidden_states=output_hidden_states,
|
| 1512 |
+
return_dict=return_dict,
|
| 1513 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1514 |
+
)
|
| 1515 |
+
|
| 1516 |
+
sequence_output = outputs[0]
|
| 1517 |
+
|
| 1518 |
+
# average pool the patch tokens
|
| 1519 |
+
sequence_output = torch.mean(sequence_output, dim=1)
|
| 1520 |
+
# apply classifier
|
| 1521 |
+
logits = self.classifier(sequence_output)
|
| 1522 |
+
|
| 1523 |
+
loss = None
|
| 1524 |
+
if labels is not None:
|
| 1525 |
+
# move labels to correct device to enable model parallelism
|
| 1526 |
+
labels = labels.to(logits.device)
|
| 1527 |
+
if self.config.problem_type is None:
|
| 1528 |
+
if self.num_labels == 1:
|
| 1529 |
+
self.config.problem_type = "regression"
|
| 1530 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1531 |
+
self.config.problem_type = "single_label_classification"
|
| 1532 |
+
else:
|
| 1533 |
+
self.config.problem_type = "multi_label_classification"
|
| 1534 |
+
|
| 1535 |
+
if self.config.problem_type == "regression":
|
| 1536 |
+
loss_fct = MSELoss()
|
| 1537 |
+
if self.num_labels == 1:
|
| 1538 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1539 |
+
else:
|
| 1540 |
+
loss = loss_fct(logits, labels)
|
| 1541 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1542 |
+
loss_fct = CrossEntropyLoss()
|
| 1543 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1544 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1545 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1546 |
+
loss = loss_fct(logits, labels)
|
| 1547 |
+
|
| 1548 |
+
if not return_dict:
|
| 1549 |
+
output = (logits,) + outputs[2:]
|
| 1550 |
+
return ((loss,) + output) if loss is not None else output
|
| 1551 |
+
|
| 1552 |
+
return ImageClassifierOutput(
|
| 1553 |
+
loss=loss,
|
| 1554 |
+
logits=logits,
|
| 1555 |
+
hidden_states=outputs.hidden_states,
|
| 1556 |
+
attentions=outputs.attentions,
|
| 1557 |
+
)
|
modeling/siglip/processing_siglip.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Image/Text processor class for SigLIP.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Optional, Union
|
| 9 |
+
|
| 10 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 11 |
+
from transformers.image_utils import ImageInput
|
| 12 |
+
from transformers.processing_utils import ProcessorMixin
|
| 13 |
+
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
| 14 |
+
from transformers.utils import TensorType
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SiglipProcessor(ProcessorMixin):
|
| 18 |
+
r"""
|
| 19 |
+
Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
|
| 20 |
+
|
| 21 |
+
[`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
|
| 22 |
+
[`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
image_processor ([`SiglipImageProcessor`]):
|
| 26 |
+
The image processor is a required input.
|
| 27 |
+
tokenizer ([`SiglipTokenizer`]):
|
| 28 |
+
The tokenizer is a required input.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
attributes = ["image_processor", "tokenizer"]
|
| 32 |
+
image_processor_class = "SiglipImageProcessor"
|
| 33 |
+
tokenizer_class = "SiglipTokenizer"
|
| 34 |
+
|
| 35 |
+
def __init__(self, image_processor, tokenizer):
|
| 36 |
+
super().__init__(image_processor, tokenizer)
|
| 37 |
+
|
| 38 |
+
def __call__(
|
| 39 |
+
self,
|
| 40 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 41 |
+
images: ImageInput = None,
|
| 42 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
| 43 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
| 44 |
+
max_length: int = None,
|
| 45 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
| 46 |
+
) -> BatchFeature:
|
| 47 |
+
"""
|
| 48 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
| 49 |
+
and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
|
| 50 |
+
the text. To prepare the image(s), this method forwards the `images` argument to
|
| 51 |
+
SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
| 52 |
+
of the above two methods for more information.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 56 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 57 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 58 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 59 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 60 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 61 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 62 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
| 63 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 64 |
+
index) among:
|
| 65 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 66 |
+
sequence if provided).
|
| 67 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 68 |
+
acceptable input length for the model if that argument is not provided.
|
| 69 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 70 |
+
lengths).
|
| 71 |
+
max_length (`int`, *optional*):
|
| 72 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 73 |
+
truncation (`bool`, *optional*):
|
| 74 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
| 75 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 76 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 77 |
+
|
| 78 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 79 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 80 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 81 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 85 |
+
|
| 86 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 87 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 88 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 89 |
+
`None`).
|
| 90 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
if text is None and images is None:
|
| 94 |
+
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
| 95 |
+
|
| 96 |
+
if text is not None:
|
| 97 |
+
encoding = self.tokenizer(
|
| 98 |
+
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if images is not None:
|
| 102 |
+
image_features = self.image_processor(images, return_tensors=return_tensors)
|
| 103 |
+
|
| 104 |
+
if text is not None and images is not None:
|
| 105 |
+
encoding["pixel_values"] = image_features.pixel_values
|
| 106 |
+
return encoding
|
| 107 |
+
elif text is not None:
|
| 108 |
+
return encoding
|
| 109 |
+
else:
|
| 110 |
+
return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
|
| 111 |
+
|
| 112 |
+
def decode(self, *args, **kwargs):
|
| 113 |
+
"""
|
| 114 |
+
This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 115 |
+
the docstring of this method for more information.
|
| 116 |
+
"""
|
| 117 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 118 |
+
|
| 119 |
+
def batch_decode(self, *args, **kwargs):
|
| 120 |
+
"""
|
| 121 |
+
This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 122 |
+
refer to the docstring of this method for more information.
|
| 123 |
+
"""
|
| 124 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
|
| 128 |
+
def model_input_names(self):
|
| 129 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 130 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 131 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
modeling/siglip/tokenization_siglip.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Tokenization class for SigLIP model."""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import string
|
| 9 |
+
import warnings
|
| 10 |
+
from shutil import copyfile
|
| 11 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import sentencepiece as spm
|
| 14 |
+
|
| 15 |
+
from transformers.convert_slow_tokenizer import import_protobuf
|
| 16 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 17 |
+
from transformers.tokenization_utils_base import AddedToken
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from transformers.tokenization_utils_base import TextInput
|
| 22 |
+
from transformers.utils import logging, requires_backends
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
SPIECE_UNDERLINE = "▁"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SiglipTokenizer(PreTrainedTokenizer):
|
| 34 |
+
"""
|
| 35 |
+
Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
|
| 36 |
+
|
| 37 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 38 |
+
this superclass for more information regarding those methods.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
vocab_file (`str`):
|
| 42 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
|
| 43 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
| 44 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 45 |
+
The end of sequence token.
|
| 46 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 47 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 48 |
+
token instead.
|
| 49 |
+
pad_token (`str`, *optional*, defaults to `"</s>"`):
|
| 50 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 51 |
+
additional_special_tokens (`List[str]`, *optional*):
|
| 52 |
+
Additional special tokens used by the tokenizer.
|
| 53 |
+
sp_model_kwargs (`dict`, *optional*):
|
| 54 |
+
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
| 55 |
+
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
| 56 |
+
to set:
|
| 57 |
+
|
| 58 |
+
- `enable_sampling`: Enable subword regularization.
|
| 59 |
+
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
| 60 |
+
|
| 61 |
+
- `nbest_size = {0,1}`: No sampling is performed.
|
| 62 |
+
- `nbest_size > 1`: samples from the nbest_size results.
|
| 63 |
+
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
| 64 |
+
using forward-filtering-and-backward-sampling algorithm.
|
| 65 |
+
|
| 66 |
+
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
| 67 |
+
BPE-dropout.
|
| 68 |
+
model_max_length (`int`, *optional*, defaults to 64):
|
| 69 |
+
The maximum length (in number of tokens) for model inputs.
|
| 70 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 71 |
+
Whether or not to lowercase the input when tokenizing.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 75 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
vocab_file,
|
| 80 |
+
eos_token="</s>",
|
| 81 |
+
unk_token="<unk>",
|
| 82 |
+
pad_token="</s>",
|
| 83 |
+
additional_special_tokens=None,
|
| 84 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 85 |
+
model_max_length=64,
|
| 86 |
+
do_lower_case=True,
|
| 87 |
+
**kwargs,
|
| 88 |
+
) -> None:
|
| 89 |
+
requires_backends(self, "protobuf")
|
| 90 |
+
|
| 91 |
+
pad_token = (
|
| 92 |
+
AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
|
| 93 |
+
if isinstance(pad_token, str)
|
| 94 |
+
else pad_token
|
| 95 |
+
)
|
| 96 |
+
unk_token = (
|
| 97 |
+
AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
|
| 98 |
+
if isinstance(unk_token, str)
|
| 99 |
+
else unk_token
|
| 100 |
+
)
|
| 101 |
+
eos_token = (
|
| 102 |
+
AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
|
| 103 |
+
if isinstance(eos_token, str)
|
| 104 |
+
else eos_token
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 108 |
+
|
| 109 |
+
self.do_lower_case = do_lower_case
|
| 110 |
+
self.vocab_file = vocab_file
|
| 111 |
+
|
| 112 |
+
self.sp_model = self.get_spm_processor()
|
| 113 |
+
self.vocab_file = vocab_file
|
| 114 |
+
|
| 115 |
+
super().__init__(
|
| 116 |
+
eos_token=eos_token,
|
| 117 |
+
unk_token=unk_token,
|
| 118 |
+
pad_token=pad_token,
|
| 119 |
+
additional_special_tokens=additional_special_tokens,
|
| 120 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 121 |
+
model_max_length=model_max_length,
|
| 122 |
+
do_lower_case=do_lower_case,
|
| 123 |
+
**kwargs,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def get_spm_processor(self):
|
| 127 |
+
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 128 |
+
with open(self.vocab_file, "rb") as f:
|
| 129 |
+
sp_model = f.read()
|
| 130 |
+
model_pb2 = import_protobuf()
|
| 131 |
+
model = model_pb2.ModelProto.FromString(sp_model)
|
| 132 |
+
normalizer_spec = model_pb2.NormalizerSpec()
|
| 133 |
+
normalizer_spec.add_dummy_prefix = False
|
| 134 |
+
model.normalizer_spec.MergeFrom(normalizer_spec)
|
| 135 |
+
sp_model = model.SerializeToString()
|
| 136 |
+
tokenizer.LoadFromSerializedProto(sp_model)
|
| 137 |
+
return tokenizer
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
|
| 141 |
+
def vocab_size(self):
|
| 142 |
+
return self.sp_model.get_piece_size()
|
| 143 |
+
|
| 144 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
|
| 145 |
+
def get_vocab(self):
|
| 146 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 147 |
+
vocab.update(self.added_tokens_encoder)
|
| 148 |
+
return vocab
|
| 149 |
+
|
| 150 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
|
| 151 |
+
def get_special_tokens_mask(
|
| 152 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 153 |
+
) -> List[int]:
|
| 154 |
+
"""
|
| 155 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 156 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
token_ids_0 (`List[int]`):
|
| 160 |
+
List of IDs.
|
| 161 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 162 |
+
Optional second list of IDs for sequence pairs.
|
| 163 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 164 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 168 |
+
"""
|
| 169 |
+
if already_has_special_tokens:
|
| 170 |
+
return super().get_special_tokens_mask(
|
| 171 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# normal case: some special tokens
|
| 175 |
+
if token_ids_1 is None:
|
| 176 |
+
return ([0] * len(token_ids_0)) + [1]
|
| 177 |
+
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 178 |
+
|
| 179 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
|
| 180 |
+
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
|
| 181 |
+
"""Do not add eos again if user already added it."""
|
| 182 |
+
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
|
| 183 |
+
warnings.warn(
|
| 184 |
+
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
|
| 185 |
+
" eos tokens being added."
|
| 186 |
+
)
|
| 187 |
+
return token_ids
|
| 188 |
+
else:
|
| 189 |
+
return token_ids + [self.eos_token_id]
|
| 190 |
+
|
| 191 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
|
| 192 |
+
def create_token_type_ids_from_sequences(
|
| 193 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 194 |
+
) -> List[int]:
|
| 195 |
+
"""
|
| 196 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
| 197 |
+
use of token type ids, therefore a list of zeros is returned.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
token_ids_0 (`List[int]`):
|
| 201 |
+
List of IDs.
|
| 202 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 203 |
+
Optional second list of IDs for sequence pairs.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
`List[int]`: List of zeros.
|
| 207 |
+
"""
|
| 208 |
+
eos = [self.eos_token_id]
|
| 209 |
+
|
| 210 |
+
if token_ids_1 is None:
|
| 211 |
+
return len(token_ids_0 + eos) * [0]
|
| 212 |
+
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
| 213 |
+
|
| 214 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
|
| 215 |
+
def build_inputs_with_special_tokens(
|
| 216 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 217 |
+
) -> List[int]:
|
| 218 |
+
"""
|
| 219 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 220 |
+
adding special tokens. A sequence has the following format:
|
| 221 |
+
|
| 222 |
+
- single sequence: `X </s>`
|
| 223 |
+
- pair of sequences: `A </s> B </s>`
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
token_ids_0 (`List[int]`):
|
| 227 |
+
List of IDs to which the special tokens will be added.
|
| 228 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 229 |
+
Optional second list of IDs for sequence pairs.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 233 |
+
"""
|
| 234 |
+
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
|
| 235 |
+
if token_ids_1 is None:
|
| 236 |
+
return token_ids_0
|
| 237 |
+
else:
|
| 238 |
+
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
|
| 239 |
+
return token_ids_0 + token_ids_1
|
| 240 |
+
|
| 241 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
|
| 242 |
+
def __getstate__(self):
|
| 243 |
+
state = self.__dict__.copy()
|
| 244 |
+
state["sp_model"] = None
|
| 245 |
+
return state
|
| 246 |
+
|
| 247 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
|
| 248 |
+
def __setstate__(self, d):
|
| 249 |
+
self.__dict__ = d
|
| 250 |
+
|
| 251 |
+
# for backward compatibility
|
| 252 |
+
if not hasattr(self, "sp_model_kwargs"):
|
| 253 |
+
self.sp_model_kwargs = {}
|
| 254 |
+
|
| 255 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 256 |
+
self.sp_model.Load(self.vocab_file)
|
| 257 |
+
|
| 258 |
+
def remove_punctuation(self, text: str) -> str:
|
| 259 |
+
return text.translate(str.maketrans("", "", string.punctuation))
|
| 260 |
+
|
| 261 |
+
# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
|
| 262 |
+
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
|
| 263 |
+
"""Returns canonicalized `text` (puncuation removed).
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
text (`str`):
|
| 267 |
+
String to be canonicalized.
|
| 268 |
+
keep_punctuation_exact_string (`str`, *optional*):
|
| 269 |
+
If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
|
| 270 |
+
(but will still remove '{' and '}' that appear separately).
|
| 271 |
+
"""
|
| 272 |
+
if keep_punctuation_exact_string:
|
| 273 |
+
text = keep_punctuation_exact_string.join(
|
| 274 |
+
self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
text = self.remove_punctuation(text)
|
| 278 |
+
text = re.sub(r"\s+", " ", text)
|
| 279 |
+
text = text.strip()
|
| 280 |
+
|
| 281 |
+
return text
|
| 282 |
+
|
| 283 |
+
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
|
| 284 |
+
"""
|
| 285 |
+
Converts a string to a list of tokens.
|
| 286 |
+
"""
|
| 287 |
+
tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
|
| 288 |
+
|
| 289 |
+
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
| 290 |
+
tokens = tokens[1:]
|
| 291 |
+
return tokens
|
| 292 |
+
|
| 293 |
+
@property
|
| 294 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
|
| 295 |
+
def unk_token_length(self):
|
| 296 |
+
return len(self.sp_model.encode(str(self.unk_token)))
|
| 297 |
+
|
| 298 |
+
def _tokenize(self, text, **kwargs):
|
| 299 |
+
"""
|
| 300 |
+
Returns a tokenized string.
|
| 301 |
+
|
| 302 |
+
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
|
| 303 |
+
SPIECE_UNDERLINE.
|
| 304 |
+
|
| 305 |
+
For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
|
| 306 |
+
|
| 307 |
+
Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
| 308 |
+
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
| 309 |
+
"""
|
| 310 |
+
text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
|
| 311 |
+
tokens = self.sp_model.encode(text, out_type=str)
|
| 312 |
+
|
| 313 |
+
# 1. Encode string + prefix ex: "<unk> Hey"
|
| 314 |
+
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
| 315 |
+
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
| 316 |
+
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
| 317 |
+
|
| 318 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
|
| 319 |
+
def _convert_token_to_id(self, token):
|
| 320 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 321 |
+
return self.sp_model.piece_to_id(token)
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
|
| 324 |
+
def _convert_id_to_token(self, index):
|
| 325 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 326 |
+
token = self.sp_model.IdToPiece(index)
|
| 327 |
+
return token
|
| 328 |
+
|
| 329 |
+
def convert_tokens_to_string(self, tokens):
|
| 330 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 331 |
+
current_sub_tokens = []
|
| 332 |
+
out_string = ""
|
| 333 |
+
prev_is_special = False
|
| 334 |
+
for token in tokens:
|
| 335 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 336 |
+
if token in self.all_special_tokens:
|
| 337 |
+
if not prev_is_special:
|
| 338 |
+
out_string += " "
|
| 339 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 340 |
+
prev_is_special = True
|
| 341 |
+
current_sub_tokens = []
|
| 342 |
+
else:
|
| 343 |
+
current_sub_tokens.append(token)
|
| 344 |
+
prev_is_special = False
|
| 345 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 346 |
+
return out_string.strip()
|
| 347 |
+
|
| 348 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
|
| 349 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 350 |
+
if not os.path.isdir(save_directory):
|
| 351 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 352 |
+
return
|
| 353 |
+
out_vocab_file = os.path.join(
|
| 354 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 358 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 359 |
+
elif not os.path.isfile(self.vocab_file):
|
| 360 |
+
with open(out_vocab_file, "wb") as fi:
|
| 361 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 362 |
+
fi.write(content_spiece_model)
|
| 363 |
+
|
| 364 |
+
return (out_vocab_file,)
|
run.err
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id h200-zebra-cot-20251025_211359-run0.
|
| 6 |
+
[rank2]:[W1025 21:14:15.369652204 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 7 |
+
[rank7]:[W1025 21:14:15.502472578 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 7] using GPU 7 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 8 |
+
[rank5]:[W1025 21:14:15.521361526 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 5] using GPU 5 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 9 |
+
[rank4]:[W1025 21:14:15.539230512 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 4] using GPU 4 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 10 |
+
[rank1]:[W1025 21:14:15.559660446 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 11 |
+
[rank3]:[W1025 21:14:15.636618409 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 12 |
+
[rank6]:[W1025 21:14:15.814060558 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 6] using GPU 6 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 13 |
+
wandb: Tracking run with wandb version 0.22.2
|
| 14 |
+
wandb: W&B syncing is set to `offline` in this directory. Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
|
| 15 |
+
wandb: Run data is saved locally in /scratch/by2593/Bagel-Zebra-CoT-origin/wandb/offline-run-20251025_211414-h200-zebra-cot-20251025_211359-run0
|
| 16 |
+
wandb: Detected [huggingface_hub.inference] in use.
|
| 17 |
+
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
|
| 18 |
+
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/
|
| 19 |
+
[rank0]:[W1025 21:14:16.181889866 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 20 |
+
[[34m2025-10-25 21:14:20[0m] Training arguments TrainingArguments(visual_gen=True, visual_und=True, results_dir='results/', checkpoint_dir='results/checkpoints_smm_semantic_part1_v1_origin/', wandb_project='zebra-cot', wandb_name='h200-zebra-cot-20251025_211359', wandb_runid='0', wandb_resume='allow', wandb_offline=True, global_seed=4396, auto_resume=True, resume_from='/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8', resume_model_only=True, finetune_from_ema=False, finetune_from_hf=True, log_every=1, save_every=50, total_steps=5000, warmup_steps=50, lr_scheduler='cosine', lr=2e-05, min_lr=1e-06, beta1=0.9, beta2=0.95, eps=1e-08, ema=0.9999, max_grad_norm=1.0, timestep_shift=1.0, mse_weight=1.0, ce_weight=1.0, ce_loss_reweighting=False, expected_num_tokens=40000, num_replicate=1, num_shard=8, sharding_strategy='HYBRID_SHARD', backward_prefetch='BACKWARD_PRE', cpu_offload=True, freeze_llm=False, freeze_vit=False, freeze_vae=True, freeze_und=False, copy_init_moe=True, use_flex=False)
|
| 21 |
+
[[34m2025-10-25 21:14:20[0m] Model arguments ModelArguments(model_path='/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8', llm_path='hf/Qwen2.5-0.5B-Instruct/', llm_qk_norm=True, tie_word_embeddings=False, layer_module='Qwen2MoTDecoderLayer', vae_path='flux/vae/ae.safetensors', vit_path='hf/siglip-so400m-14-980-flash-attn2-navit/', max_latent_size=64, latent_patch_size=2, vit_patch_size=14, vit_max_num_patch_per_side=70, connector_act='gelu_pytorch_tanh', interpolate_pos=False, vit_select_layer=-2, vit_rope=False, text_cond_dropout_prob=0.1, vae_cond_dropout_prob=0.3, vit_cond_dropout_prob=0.3)
|
| 22 |
+
[[34m2025-10-25 21:14:20[0m] Data arguments DataArguments(dataset_config_file='./data/configs/example_smm_semantic.yaml', prefetch_factor=2, num_workers=1, max_num_tokens_per_sample=40000, max_num_tokens=40000, prefer_buffer_before=10000, max_buffer_size=50, data_seed=42)
|
| 23 |
+
[[34m2025-10-25 21:16:50[0m] Loading checkpoint from /scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8.
|
| 24 |
+
[[34m2025-10-25 21:18:10[0m] _IncompatibleKeys(missing_keys=['latent_pos_embed.pos_embed', 'vit_pos_embed.pos_embed'], unexpected_keys=[])
|
| 25 |
+
[[34m2025-10-25 21:18:10[0m] replicaing ema model from /scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8/model_bf16.safetensors.
|
| 26 |
+
[[34m2025-10-25 21:18:20[0m] _IncompatibleKeys(missing_keys=['latent_pos_embed.pos_embed', 'vit_pos_embed.pos_embed'], unexpected_keys=[])
|
| 27 |
+
[[34m2025-10-25 21:18:51[0m] Training for 5000 steps, starting at 0...
|
| 28 |
+
[[34m2025-10-25 21:20:20[0m] (step=0000000) Train Loss mse: 0.0185, Train Loss ce: 1.8625, Train Steps/Sec: 0.01,
|
| 29 |
+
[[34m2025-10-25 21:20:57[0m] (step=0000001) Train Loss mse: 0.0168, Train Loss ce: 1.8560, Train Steps/Sec: 0.03,
|
| 30 |
+
[[34m2025-10-25 21:21:32[0m] (step=0000002) Train Loss mse: 0.0208, Train Loss ce: 1.8139, Train Steps/Sec: 0.03,
|
| 31 |
+
[[34m2025-10-25 21:22:13[0m] (step=0000003) Train Loss mse: 0.0200, Train Loss ce: 1.6772, Train Steps/Sec: 0.02,
|
| 32 |
+
[[34m2025-10-25 21:22:49[0m] (step=0000004) Train Loss mse: 0.0164, Train Loss ce: 1.7684, Train Steps/Sec: 0.03,
|
| 33 |
+
[[34m2025-10-25 21:23:31[0m] (step=0000005) Train Loss mse: 0.0199, Train Loss ce: 1.8439, Train Steps/Sec: 0.02,
|
| 34 |
+
[[34m2025-10-25 21:24:04[0m] (step=0000006) Train Loss mse: 0.0166, Train Loss ce: 1.6152, Train Steps/Sec: 0.03,
|
| 35 |
+
[[34m2025-10-25 21:24:40[0m] (step=0000007) Train Loss mse: 0.0181, Train Loss ce: 1.7539, Train Steps/Sec: 0.03,
|
| 36 |
+
[[34m2025-10-25 21:25:15[0m] (step=0000008) Train Loss mse: 0.0164, Train Loss ce: 1.7400, Train Steps/Sec: 0.03,
|
| 37 |
+
[[34m2025-10-25 21:25:49[0m] (step=0000009) Train Loss mse: 0.0167, Train Loss ce: 1.8076, Train Steps/Sec: 0.03,
|
| 38 |
+
[[34m2025-10-25 21:26:25[0m] (step=0000010) Train Loss mse: 0.0233, Train Loss ce: 1.4616, Train Steps/Sec: 0.03,
|
| 39 |
+
[[34m2025-10-25 21:26:56[0m] (step=0000011) Train Loss mse: 0.0168, Train Loss ce: 1.6259, Train Steps/Sec: 0.03,
|
| 40 |
+
[[34m2025-10-25 21:27:37[0m] (step=0000012) Train Loss mse: 0.0170, Train Loss ce: 1.5824, Train Steps/Sec: 0.02,
|
| 41 |
+
[[34m2025-10-25 21:28:08[0m] (step=0000013) Train Loss mse: 0.0189, Train Loss ce: 1.5811, Train Steps/Sec: 0.03,
|
| 42 |
+
[[34m2025-10-25 21:28:42[0m] (step=0000014) Train Loss mse: 0.0221, Train Loss ce: 1.2260, Train Steps/Sec: 0.03,
|
| 43 |
+
[[34m2025-10-25 21:29:16[0m] (step=0000015) Train Loss mse: 0.0140, Train Loss ce: 1.1394, Train Steps/Sec: 0.03,
|
| 44 |
+
[[34m2025-10-25 21:29:49[0m] (step=0000016) Train Loss mse: 0.0163, Train Loss ce: 1.1381, Train Steps/Sec: 0.03,
|
| 45 |
+
[[34m2025-10-25 21:30:26[0m] (step=0000017) Train Loss mse: 0.0229, Train Loss ce: 1.0493, Train Steps/Sec: 0.03,
|
| 46 |
+
[[34m2025-10-25 21:31:02[0m] (step=0000018) Train Loss mse: 0.0169, Train Loss ce: 1.0484, Train Steps/Sec: 0.03,
|
| 47 |
+
[[34m2025-10-25 21:31:43[0m] (step=0000019) Train Loss mse: 0.0187, Train Loss ce: 0.5945, Train Steps/Sec: 0.02,
|
| 48 |
+
[[34m2025-10-25 21:32:19[0m] (step=0000020) Train Loss mse: 0.0158, Train Loss ce: 0.6128, Train Steps/Sec: 0.03,
|
| 49 |
+
[[34m2025-10-25 21:33:00[0m] (step=0000021) Train Loss mse: 0.0157, Train Loss ce: 0.4668, Train Steps/Sec: 0.02,
|
| 50 |
+
[[34m2025-10-25 21:33:33[0m] (step=0000022) Train Loss mse: 0.0181, Train Loss ce: 0.4042, Train Steps/Sec: 0.03,
|
| 51 |
+
[[34m2025-10-25 21:34:07[0m] (step=0000023) Train Loss mse: 0.0209, Train Loss ce: 0.2930, Train Steps/Sec: 0.03,
|
| 52 |
+
[[34m2025-10-25 21:34:40[0m] (step=0000024) Train Loss mse: 0.0190, Train Loss ce: 0.2934, Train Steps/Sec: 0.03,
|
| 53 |
+
[[34m2025-10-25 21:35:16[0m] (step=0000025) Train Loss mse: 0.0144, Train Loss ce: 0.2189, Train Steps/Sec: 0.03,
|
| 54 |
+
[[34m2025-10-25 21:35:49[0m] (step=0000026) Train Loss mse: 0.0185, Train Loss ce: 0.1414, Train Steps/Sec: 0.03,
|
| 55 |
+
[[34m2025-10-25 21:36:22[0m] (step=0000027) Train Loss mse: 0.0166, Train Loss ce: 0.1090, Train Steps/Sec: 0.03,
|
| 56 |
+
[[34m2025-10-25 21:36:59[0m] (step=0000028) Train Loss mse: 0.0202, Train Loss ce: 0.1350, Train Steps/Sec: 0.03,
|
| 57 |
+
[[34m2025-10-25 21:37:36[0m] (step=0000029) Train Loss mse: 0.0175, Train Loss ce: 0.1263, Train Steps/Sec: 0.03,
|
| 58 |
+
[[34m2025-10-25 21:38:11[0m] (step=0000030) Train Loss mse: 0.0165, Train Loss ce: 0.0860, Train Steps/Sec: 0.03,
|
| 59 |
+
[[34m2025-10-25 21:38:47[0m] (step=0000031) Train Loss mse: 0.0169, Train Loss ce: 0.0864, Train Steps/Sec: 0.03,
|
| 60 |
+
[[34m2025-10-25 21:39:20[0m] (step=0000032) Train Loss mse: 0.0218, Train Loss ce: 0.0792, Train Steps/Sec: 0.03,
|
| 61 |
+
[[34m2025-10-25 21:39:57[0m] (step=0000033) Train Loss mse: 0.0203, Train Loss ce: 0.0852, Train Steps/Sec: 0.03,
|
| 62 |
+
[[34m2025-10-25 21:40:30[0m] (step=0000034) Train Loss mse: 0.0200, Train Loss ce: 0.0734, Train Steps/Sec: 0.03,
|
| 63 |
+
[[34m2025-10-25 21:41:07[0m] (step=0000035) Train Loss mse: 0.0166, Train Loss ce: 0.0830, Train Steps/Sec: 0.03,
|
| 64 |
+
[[34m2025-10-25 21:41:42[0m] (step=0000036) Train Loss mse: 0.0167, Train Loss ce: 0.0776, Train Steps/Sec: 0.03,
|
| 65 |
+
[[34m2025-10-25 21:42:14[0m] (step=0000037) Train Loss mse: 0.0175, Train Loss ce: 0.0556, Train Steps/Sec: 0.03,
|
| 66 |
+
[[34m2025-10-25 21:42:51[0m] (step=0000038) Train Loss mse: 0.0176, Train Loss ce: 0.0520, Train Steps/Sec: 0.03,
|
| 67 |
+
[[34m2025-10-25 21:43:23[0m] (step=0000039) Train Loss mse: 0.0144, Train Loss ce: 0.0607, Train Steps/Sec: 0.03,
|
| 68 |
+
[[34m2025-10-25 21:43:59[0m] (step=0000040) Train Loss mse: 0.0151, Train Loss ce: 0.0683, Train Steps/Sec: 0.03,
|
| 69 |
+
[[34m2025-10-25 21:44:32[0m] (step=0000041) Train Loss mse: 0.0180, Train Loss ce: 0.0456, Train Steps/Sec: 0.03,
|
| 70 |
+
[[34m2025-10-25 21:45:08[0m] (step=0000042) Train Loss mse: 0.0157, Train Loss ce: 0.0620, Train Steps/Sec: 0.03,
|
| 71 |
+
[[34m2025-10-25 21:45:51[0m] (step=0000043) Train Loss mse: 0.0167, Train Loss ce: 0.0552, Train Steps/Sec: 0.02,
|
| 72 |
+
[[34m2025-10-25 21:46:28[0m] (step=0000044) Train Loss mse: 0.0143, Train Loss ce: 0.0522, Train Steps/Sec: 0.03,
|
| 73 |
+
[[34m2025-10-25 21:47:08[0m] (step=0000045) Train Loss mse: 0.0159, Train Loss ce: 0.0494, Train Steps/Sec: 0.02,
|
| 74 |
+
[[34m2025-10-25 21:47:41[0m] (step=0000046) Train Loss mse: 0.0160, Train Loss ce: 0.0484, Train Steps/Sec: 0.03,
|
| 75 |
+
[[34m2025-10-25 21:48:14[0m] (step=0000047) Train Loss mse: 0.0187, Train Loss ce: 0.0599, Train Steps/Sec: 0.03,
|
| 76 |
+
[[34m2025-10-25 21:48:52[0m] (step=0000048) Train Loss mse: 0.0173, Train Loss ce: 0.0629, Train Steps/Sec: 0.03,
|
| 77 |
+
[[34m2025-10-25 21:49:26[0m] (step=0000049) Train Loss mse: 0.0167, Train Loss ce: 0.0466, Train Steps/Sec: 0.03,
|
| 78 |
+
[[34m2025-10-25 21:50:00[0m] (step=0000050) Train Loss mse: 0.0150, Train Loss ce: 0.0540, Train Steps/Sec: 0.03,
|
| 79 |
+
[[34m2025-10-25 21:50:01[0m] Saving checkpoint to results/checkpoints_smm_semantic_part1_v1_origin/0000050.
|
| 80 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 81 |
+
warnings.warn(
|
| 82 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 83 |
+
warnings.warn(
|
| 84 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 85 |
+
warnings.warn(
|
| 86 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 87 |
+
warnings.warn(
|
| 88 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 89 |
+
warnings.warn(
|
| 90 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 91 |
+
warnings.warn(
|
| 92 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 93 |
+
warnings.warn(
|
| 94 |
+
/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
|
| 95 |
+
warnings.warn(
|
| 96 |
+
[[34m2025-10-25 21:55:05[0m] Sorted checkpoint directories: ['0000050']
|
| 97 |
+
[[34m2025-10-25 21:55:40[0m] (step=0000051) Train Loss mse: 0.0139, Train Loss ce: 0.0539, Train Steps/Sec: 0.00,
|
| 98 |
+
[[34m2025-10-25 21:56:13[0m] (step=0000052) Train Loss mse: 0.0176, Train Loss ce: 0.0495, Train Steps/Sec: 0.03,
|
| 99 |
+
[[34m2025-10-25 21:56:51[0m] (step=0000053) Train Loss mse: 0.0168, Train Loss ce: 0.0485, Train Steps/Sec: 0.03,
|
| 100 |
+
[[34m2025-10-25 21:57:23[0m] (step=0000054) Train Loss mse: 0.0151, Train Loss ce: 0.0446, Train Steps/Sec: 0.03,
|
| 101 |
+
[[34m2025-10-25 21:58:00[0m] (step=0000055) Train Loss mse: 0.0144, Train Loss ce: 0.0490, Train Steps/Sec: 0.03,
|
| 102 |
+
[[34m2025-10-25 21:58:37[0m] (step=0000056) Train Loss mse: 0.0143, Train Loss ce: 0.0461, Train Steps/Sec: 0.03,
|
| 103 |
+
[[34m2025-10-25 21:59:11[0m] (step=0000057) Train Loss mse: 0.0152, Train Loss ce: 0.0459, Train Steps/Sec: 0.03,
|
| 104 |
+
[[34m2025-10-25 21:59:48[0m] (step=0000058) Train Loss mse: 0.0152, Train Loss ce: 0.0402, Train Steps/Sec: 0.03,
|
| 105 |
+
[[34m2025-10-25 22:00:22[0m] (step=0000059) Train Loss mse: 0.0145, Train Loss ce: 0.0566, Train Steps/Sec: 0.03,
|
| 106 |
+
[[34m2025-10-25 22:00:59[0m] (step=0000060) Train Loss mse: 0.0174, Train Loss ce: 0.0509, Train Steps/Sec: 0.03,
|
| 107 |
+
[rank6]: Traceback (most recent call last):
|
| 108 |
+
[rank6]: File "/scratch/by2593/Bagel-Zebra-CoT-origin/train/pretrain_unified_navit.py", line 727, in <module>
|
| 109 |
+
[rank6]: main()
|
| 110 |
+
[rank6]: File "/scratch/by2593/Bagel-Zebra-CoT-origin/train/pretrain_unified_navit.py", line 609, in main
|
| 111 |
+
[rank6]: assert not training_args.visual_und
|
| 112 |
+
[rank6]: AssertionError
|
| 113 |
+
[rank6]:[W1025 22:01:04.973896433 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
|
| 114 |
+
W1025 22:01:11.227000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808294 closing signal SIGTERM
|
| 115 |
+
W1025 22:01:11.264000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808295 closing signal SIGTERM
|
| 116 |
+
W1025 22:01:11.265000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808296 closing signal SIGTERM
|
| 117 |
+
W1025 22:01:11.271000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808297 closing signal SIGTERM
|
| 118 |
+
W1025 22:01:11.314000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808298 closing signal SIGTERM
|
| 119 |
+
W1025 22:01:11.332000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808299 closing signal SIGTERM
|
| 120 |
+
W1025 22:01:11.357000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808301 closing signal SIGTERM
|
| 121 |
+
E1025 22:01:37.654000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 6 (pid: 2808300) of binary: /scratch/by2593/miniconda3/envs/bagel/bin/python3.10
|
| 122 |
+
Traceback (most recent call last):
|
| 123 |
+
File "/scratch/by2593/miniconda3/envs/bagel/bin/torchrun", line 7, in <module>
|
| 124 |
+
sys.exit(main())
|
| 125 |
+
File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 126 |
+
return f(*args, **kwargs)
|
| 127 |
+
File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
|
| 128 |
+
run(args)
|
| 129 |
+
File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
|
| 130 |
+
elastic_launch(
|
| 131 |
+
File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 132 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 133 |
+
File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
|
| 134 |
+
raise ChildFailedError(
|
| 135 |
+
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
| 136 |
+
============================================================
|
| 137 |
+
train/pretrain_unified_navit.py FAILED
|
| 138 |
+
------------------------------------------------------------
|
| 139 |
+
Failures:
|
| 140 |
+
<NO_OTHER_FAILURES>
|
| 141 |
+
------------------------------------------------------------
|
| 142 |
+
Root Cause (first observed failure):
|
| 143 |
+
[0]:
|
| 144 |
+
time : 2025-10-25_22:01:11
|
| 145 |
+
host : gh129.hpc.nyu.edu
|
| 146 |
+
rank : 6 (local_rank: 6)
|
| 147 |
+
exitcode : 1 (pid: 2808300)
|
| 148 |
+
error_file: <N/A>
|
| 149 |
+
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
|
| 150 |
+
============================================================
|
run.out
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 2 |
+
rank-3 worker-0 dataset-block_dataset: resuming data at row#0
|
| 3 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 4 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 5 |
+
rank-6 worker-0 dataset-block_dataset: resuming data at row#0
|
| 6 |
+
rank-4 worker-0 dataset-block_dataset: resuming data at row#0
|
| 7 |
+
FullyShardedDataParallel(
|
| 8 |
+
(_fsdp_wrapped_module): Bagel(
|
| 9 |
+
(language_model): Qwen2ForCausalLM(
|
| 10 |
+
(model): Qwen2Model(
|
| 11 |
+
(embed_tokens): Embedding(152064, 3584)
|
| 12 |
+
(layers): ModuleList(
|
| 13 |
+
(0-27): 28 x FullyShardedDataParallel(
|
| 14 |
+
(_fsdp_wrapped_module): CheckpointWrapper(
|
| 15 |
+
(_checkpoint_wrapped_module): Qwen2MoTDecoderLayer(
|
| 16 |
+
(self_attn): PackedAttentionMoT(
|
| 17 |
+
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
|
| 18 |
+
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
|
| 19 |
+
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
|
| 20 |
+
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
|
| 21 |
+
(q_norm): Qwen2RMSNorm((128,), eps=1e-06)
|
| 22 |
+
(k_norm): Qwen2RMSNorm((128,), eps=1e-06)
|
| 23 |
+
(q_norm_moe_gen): Qwen2RMSNorm((128,), eps=1e-06)
|
| 24 |
+
(k_norm_moe_gen): Qwen2RMSNorm((128,), eps=1e-06)
|
| 25 |
+
(q_proj_moe_gen): Linear(in_features=3584, out_features=3584, bias=True)
|
| 26 |
+
(k_proj_moe_gen): Linear(in_features=3584, out_features=512, bias=True)
|
| 27 |
+
(v_proj_moe_gen): Linear(in_features=3584, out_features=512, bias=True)
|
| 28 |
+
(o_proj_moe_gen): Linear(in_features=3584, out_features=3584, bias=False)
|
| 29 |
+
)
|
| 30 |
+
(mlp): Qwen2MLP(
|
| 31 |
+
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
|
| 32 |
+
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
|
| 33 |
+
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
|
| 34 |
+
(act_fn): SiLU()
|
| 35 |
+
)
|
| 36 |
+
(mlp_moe_gen): Qwen2MLP(
|
| 37 |
+
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
|
| 38 |
+
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
|
| 39 |
+
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
|
| 40 |
+
(act_fn): SiLU()
|
| 41 |
+
)
|
| 42 |
+
(input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 43 |
+
(input_layernorm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 44 |
+
(post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 45 |
+
(post_attention_layernorm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 46 |
+
)
|
| 47 |
+
)
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
(norm): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 51 |
+
(norm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
|
| 52 |
+
(rotary_emb): Qwen2RotaryEmbedding()
|
| 53 |
+
)
|
| 54 |
+
(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
|
| 55 |
+
)
|
| 56 |
+
(time_embedder): FullyShardedDataParallel(
|
| 57 |
+
(_fsdp_wrapped_module): TimestepEmbedder(
|
| 58 |
+
(mlp): Sequential(
|
| 59 |
+
(0): Linear(in_features=256, out_features=3584, bias=True)
|
| 60 |
+
(1): SiLU()
|
| 61 |
+
(2): Linear(in_features=3584, out_features=3584, bias=True)
|
| 62 |
+
)
|
| 63 |
+
)
|
| 64 |
+
)
|
| 65 |
+
(vae2llm): Linear(in_features=64, out_features=3584, bias=True)
|
| 66 |
+
(llm2vae): Linear(in_features=3584, out_features=64, bias=True)
|
| 67 |
+
(latent_pos_embed): FullyShardedDataParallel(
|
| 68 |
+
(_fsdp_wrapped_module): PositionEmbedding()
|
| 69 |
+
)
|
| 70 |
+
(vit_model): SiglipVisionModel(
|
| 71 |
+
(vision_model): FullyShardedDataParallel(
|
| 72 |
+
(_fsdp_wrapped_module): SiglipVisionTransformer(
|
| 73 |
+
(embeddings): SiglipVisionEmbeddings(
|
| 74 |
+
(position_embedding): Embedding(4900, 1152)
|
| 75 |
+
(patch_embedding): Linear(in_features=588, out_features=1152, bias=True)
|
| 76 |
+
)
|
| 77 |
+
(encoder): SiglipEncoder(
|
| 78 |
+
(layers): ModuleList(
|
| 79 |
+
(0-25): 26 x FullyShardedDataParallel(
|
| 80 |
+
(_fsdp_wrapped_module): CheckpointWrapper(
|
| 81 |
+
(_checkpoint_wrapped_module): SiglipEncoderLayer(
|
| 82 |
+
(self_attn): SiglipFlashAttention2(
|
| 83 |
+
(k_proj): Linear(in_features=1152, out_features=1152, bias=True)
|
| 84 |
+
(v_proj): Linear(in_features=1152, out_features=1152, bias=True)
|
| 85 |
+
(q_proj): Linear(in_features=1152, out_features=1152, bias=True)
|
| 86 |
+
(out_proj): Linear(in_features=1152, out_features=1152, bias=True)
|
| 87 |
+
)
|
| 88 |
+
(layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
|
| 89 |
+
(mlp): SiglipMLP(
|
| 90 |
+
(activation_fn): PytorchGELUTanh()
|
| 91 |
+
(fc1): Linear(in_features=1152, out_features=4304, bias=True)
|
| 92 |
+
(fc2): Linear(in_features=4304, out_features=1152, bias=True)
|
| 93 |
+
)
|
| 94 |
+
(layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
(post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
|
| 101 |
+
)
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
(connector): FullyShardedDataParallel(
|
| 105 |
+
(_fsdp_wrapped_module): CheckpointWrapper(
|
| 106 |
+
(_checkpoint_wrapped_module): MLPconnector(
|
| 107 |
+
(activation_fn): PytorchGELUTanh()
|
| 108 |
+
(fc1): Linear(in_features=1152, out_features=3584, bias=True)
|
| 109 |
+
(fc2): Linear(in_features=3584, out_features=3584, bias=True)
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
(vit_pos_embed): FullyShardedDataParallel(
|
| 114 |
+
(_fsdp_wrapped_module): PositionEmbedding()
|
| 115 |
+
)
|
| 116 |
+
)
|
| 117 |
+
)
|
| 118 |
+
_flat_param True
|
| 119 |
+
language_model.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 120 |
+
language_model.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 121 |
+
language_model.model.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 122 |
+
language_model.model.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 123 |
+
language_model.model.layers.4._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 124 |
+
language_model.model.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 125 |
+
language_model.model.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 126 |
+
language_model.model.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 127 |
+
language_model.model.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 128 |
+
language_model.model.layers.9._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 129 |
+
language_model.model.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 130 |
+
language_model.model.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 131 |
+
language_model.model.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 132 |
+
language_model.model.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 133 |
+
language_model.model.layers.14._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 134 |
+
language_model.model.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 135 |
+
language_model.model.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 136 |
+
language_model.model.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 137 |
+
language_model.model.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 138 |
+
language_model.model.layers.19._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 139 |
+
language_model.model.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 140 |
+
language_model.model.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 141 |
+
language_model.model.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 142 |
+
language_model.model.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 143 |
+
language_model.model.layers.24._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 144 |
+
language_model.model.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 145 |
+
language_model.model.layers.26._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 146 |
+
language_model.model.layers.27._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 147 |
+
time_embedder._fsdp_wrapped_module._flat_param True
|
| 148 |
+
latent_pos_embed._fsdp_wrapped_module._flat_param False
|
| 149 |
+
vit_model.vision_model._fsdp_wrapped_module._flat_param True
|
| 150 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 151 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 152 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 153 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 154 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.4._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 155 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 156 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 157 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 158 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 159 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.9._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 160 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 161 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 162 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 163 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 164 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.14._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 165 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 166 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 167 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 168 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 169 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.19._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 170 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 171 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 172 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 173 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 174 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.24._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 175 |
+
vit_model.vision_model._fsdp_wrapped_module.encoder.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 176 |
+
connector._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
|
| 177 |
+
vit_pos_embed._fsdp_wrapped_module._flat_param False
|
| 178 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 179 |
+
Preparing Dataset block_dataset/block_dataset
|
| 180 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 181 |
+
rank-0 worker-0 dataset-block_dataset: resuming data at row#0
|
| 182 |
+
rank-7 worker-0 dataset-block_dataset: resuming data at row#0
|
| 183 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 184 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 185 |
+
rank-2 worker-0 dataset-block_dataset: resuming data at row#0
|
| 186 |
+
rank-5 worker-0 dataset-block_dataset: resuming data at row#0
|
| 187 |
+
{'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
|
| 188 |
+
rank-1 worker-0 dataset-block_dataset: resuming data at row#0
|
| 189 |
+
skip a sample with length 43202
|
| 190 |
+
skip a sample with length 48060
|
| 191 |
+
skip a sample with length 41094
|
| 192 |
+
skip a sample with length 43245
|
| 193 |
+
skip a sample with length 57756
|
| 194 |
+
skip a sample with length 41160
|
| 195 |
+
skip a sample with length 44611
|
| 196 |
+
skip a sample with length 41094
|
| 197 |
+
skip a sample with length 48060
|
| 198 |
+
skip a sample with length 50787
|
| 199 |
+
skip a sample with length 44611
|
| 200 |
+
skip a sample with length 43245
|
| 201 |
+
skip a sample with length 41106
|
| 202 |
+
skip a sample with length 41160
|
| 203 |
+
skip a sample with length 57756
|
| 204 |
+
skip a sample with length 42480
|
| 205 |
+
skip a sample with length 42486
|
| 206 |
+
skip a sample with length 42486
|
| 207 |
+
skip a sample with length 50787
|
| 208 |
+
skip a sample with length 43202
|
| 209 |
+
skip a sample with length 42480
|
| 210 |
+
block_dataset repeat in rank-3 worker-0
|
| 211 |
+
block_dataset repeat in rank-4 worker-0
|
| 212 |
+
block_dataset repeat in rank-6 worker-0
|
| 213 |
+
block_dataset repeat in rank-7 worker-0
|
| 214 |
+
block_dataset repeat in rank-0 worker-0
|
| 215 |
+
block_dataset repeat in rank-5 worker-0
|
| 216 |
+
block_dataset repeat in rank-2 worker-0
|
| 217 |
+
skip a sample with length 41106
|
| 218 |
+
skip a sample with length 48060
|
| 219 |
+
skip a sample with length 43202
|
| 220 |
+
block_dataset repeat in rank-1 worker-0
|
| 221 |
+
skip a sample with length 41094
|
| 222 |
+
skip a sample with length 57756
|
| 223 |
+
Yielding data with length 31517
|
| 224 |
+
skip a sample with length 43245
|
| 225 |
+
skip a sample with length 41160
|
| 226 |
+
skip a sample with length 44611
|
| 227 |
+
Yielding data with length 33637
|
| 228 |
+
Yielding data with length 33154
|
| 229 |
+
Yielding data with length 15542
|
| 230 |
+
skip a sample with length 50787
|
| 231 |
+
Yielding data with length 35486
|
| 232 |
+
Yielding data with length 12716
|
| 233 |
+
skip a sample with length 48060
|
| 234 |
+
skip a sample with length 41094
|
| 235 |
+
skip a sample with length 43245
|
| 236 |
+
skip a sample with length 41160
|
| 237 |
+
skip a sample with length 44611
|
| 238 |
+
Yielding data with length 26172
|
| 239 |
+
Yielding data with length 23933
|
| 240 |
+
skip a sample with length 41106
|
| 241 |
+
skip a sample with length 57756
|
| 242 |
+
Yielding data with length 32737
|
| 243 |
+
Yielding data with length 27691
|
| 244 |
+
Yielding data with length 31628
|
| 245 |
+
Yielding data with length 36149
|
| 246 |
+
skip a sample with length 42486
|
| 247 |
+
Yielding data with length 30708
|
| 248 |
+
Yielding data with length 13411
|
| 249 |
+
Yielding data with length 18973
|
| 250 |
+
Yielding data with length 27959
|
| 251 |
+
Yielding data with length 23821
|
| 252 |
+
skip a sample with length 50787
|
| 253 |
+
Yielding data with length 27474
|
| 254 |
+
Yielding data with length 7870
|
| 255 |
+
skip a sample with length 42486
|
| 256 |
+
Yielding data with length 37241
|
| 257 |
+
Yielding data with length 27998
|
| 258 |
+
Yielding data with length 13811
|
| 259 |
+
Yielding data with length 20795
|
| 260 |
+
Yielding data with length 32169
|
| 261 |
+
Yielding data with length 16921
|
| 262 |
+
Yielding data with length 16202
|
| 263 |
+
Yielding data with length 21081
|
| 264 |
+
Yielding data with length 21217
|
| 265 |
+
Yielding data with length 26994
|
| 266 |
+
Yielding data with length 17856
|
| 267 |
+
Yielding data with length 33309
|
| 268 |
+
Yielding data with length 31064
|
| 269 |
+
Yielding data with length 23492
|
| 270 |
+
Yielding data with length 20761
|
| 271 |
+
Yielding data with length 31378
|
| 272 |
+
Yielding data with length 23451
|
| 273 |
+
Yielding data with length 25220
|
| 274 |
+
Yielding data with length 26611
|
| 275 |
+
Yielding data with length 27250
|
| 276 |
+
Yielding data with length 35216
|
| 277 |
+
skip a sample with length 42480
|
| 278 |
+
Yielding data with length 13720
|
| 279 |
+
Yielding data with length 19578
|
| 280 |
+
Yielding data with length 25498
|
| 281 |
+
Yielding data with length 22109
|
| 282 |
+
Yielding data with length 19619
|
| 283 |
+
Yielding data with length 23415
|
| 284 |
+
Yielding data with length 30332
|
| 285 |
+
Yielding data with length 34858
|
| 286 |
+
block_dataset repeat in rank-0 worker-0
|
| 287 |
+
block_dataset repeat in rank-6 worker-0
|
| 288 |
+
block_dataset repeat in rank-5 worker-0
|
| 289 |
+
Yielding data with length 19720
|
| 290 |
+
Yielding data with length 25991
|
| 291 |
+
Yielding data with length 29387
|
| 292 |
+
Yielding data with length 21979
|
| 293 |
+
skip a sample with length 43202
|
| 294 |
+
skip a sample with length 41106
|
| 295 |
+
Yielding data with length 23402
|
| 296 |
+
Yielding data with length 22465
|
| 297 |
+
Yielding data with length 21998
|
| 298 |
+
Yielding data with length 25679
|
| 299 |
+
block_dataset repeat in rank-3 worker-0
|
| 300 |
+
Yielding data with length 17957
|
| 301 |
+
Yielding data with length 22013
|
| 302 |
+
Yielding data with length 20711
|
| 303 |
+
Yielding data with length 23461
|
| 304 |
+
Yielding data with length 24469
|
| 305 |
+
Yielding data with length 24915
|
| 306 |
+
Yielding data with length 27691
|
| 307 |
+
Yielding data with length 37262
|
| 308 |
+
skip a sample with length 43202
|
| 309 |
+
block_dataset repeat in rank-7 worker-0
|
| 310 |
+
block_dataset repeat in rank-1 worker-0
|
| 311 |
+
Yielding data with length 17288
|
| 312 |
+
Yielding data with length 20687
|
| 313 |
+
Yielding data with length 20361
|
| 314 |
+
Yielding data with length 28560
|
| 315 |
+
Yielding data with length 31247
|
| 316 |
+
Yielding data with length 17983
|
| 317 |
+
block_dataset repeat in rank-2 worker-0
|
| 318 |
+
Yielding data with length 27946
|
| 319 |
+
Yielding data with length 27631
|
| 320 |
+
skip a sample with length 41094
|
| 321 |
+
skip a sample with length 42480
|
| 322 |
+
Yielding data with length 10650
|
| 323 |
+
Yielding data with length 14641
|
| 324 |
+
Yielding data with length 23037
|
| 325 |
+
skip a sample with length 43245
|
| 326 |
+
Yielding data with length 16219
|
| 327 |
+
Yielding data with length 35530
|
| 328 |
+
Yielding data with length 16208
|
| 329 |
+
Yielding data with length 26188
|
| 330 |
+
Yielding data with length 27937
|
| 331 |
+
block_dataset repeat in rank-4 worker-0
|
| 332 |
+
Yielding data with length 11424
|
| 333 |
+
Yielding data with length 12453
|
| 334 |
+
Yielding data with length 16146
|
| 335 |
+
Yielding data with length 18287
|
| 336 |
+
Yielding data with length 20791
|
| 337 |
+
Yielding data with length 24236
|
| 338 |
+
Yielding data with length 25579
|
| 339 |
+
Yielding data with length 28956
|
| 340 |
+
Yielding data with length 14121
|
| 341 |
+
Yielding data with length 14781
|
| 342 |
+
Yielding data with length 15221
|
| 343 |
+
Yielding data with length 15921
|
| 344 |
+
skip a sample with length 41160
|
| 345 |
+
Yielding data with length 28466
|
| 346 |
+
skip a sample with length 48060
|
| 347 |
+
Yielding data with length 17646
|
| 348 |
+
Yielding data with length 31256
|
| 349 |
+
Yielding data with length 26792
|
| 350 |
+
Yielding data with length 17122
|
| 351 |
+
Yielding data with length 20057
|
| 352 |
+
skip a sample with length 48060
|
| 353 |
+
Yielding data with length 31691
|
| 354 |
+
Yielding data with length 32761
|
| 355 |
+
Yielding data with length 23701
|
| 356 |
+
Yielding data with length 23722
|
| 357 |
+
Yielding data with length 27340
|
| 358 |
+
Yielding data with length 33869
|
| 359 |
+
skip a sample with length 44611
|
| 360 |
+
skip a sample with length 57756
|
| 361 |
+
skip a sample with length 41094
|
| 362 |
+
Yielding data with length 15091Yielding data with length 16206
|
| 363 |
+
|
| 364 |
+
Yielding data with length 13157
|
| 365 |
+
Yielding data with length 26843
|
| 366 |
+
Yielding data with length 21094
|
| 367 |
+
Yielding data with length 24549
|
| 368 |
+
Yielding data with length 20404
|
| 369 |
+
Yielding data with length 25400
|
| 370 |
+
skip a sample with length 41106
|
| 371 |
+
Yielding data with length 18332
|
| 372 |
+
Yielding data with length 20708
|
| 373 |
+
Yielding data with length 21310
|
| 374 |
+
skip a sample with length 50787
|
| 375 |
+
Yielding data with length 27881
|
| 376 |
+
Yielding data with length 25557
|
| 377 |
+
skip a sample with length 57756
|
| 378 |
+
Yielding data with length 24894
|
| 379 |
+
Yielding data with length 28219
|
| 380 |
+
Yielding data with length 24140
|
| 381 |
+
skip a sample with length 43245
|
| 382 |
+
skip a sample with length 44611
|
| 383 |
+
skip a sample with length 41160
|
| 384 |
+
Yielding data with length 27592
|
| 385 |
+
Yielding data with length 26168
|
| 386 |
+
Yielding data with length 20709
|
| 387 |
+
Yielding data with length 23581
|
| 388 |
+
skip a sample with length 42486
|
| 389 |
+
Yielding data with length 29274
|
| 390 |
+
Yielding data with length 24805
|
| 391 |
+
Yielding data with length 31112
|
| 392 |
+
Yielding data with length 36407
|
| 393 |
+
skip a sample with length 50787
|
| 394 |
+
Yielding data with length 18262
|
| 395 |
+
Yielding data with length 26439
|
| 396 |
+
Yielding data with length 18322
|
| 397 |
+
Yielding data with length 33505
|
| 398 |
+
Yielding data with length 29023
|
| 399 |
+
Yielding data with length 25487
|
| 400 |
+
Yielding data with length 31643
|
| 401 |
+
Yielding data with length 27712
|
| 402 |
+
skip a sample with length 42486
|
| 403 |
+
Yielding data with length 15735
|
| 404 |
+
Yielding data with length 17616
|
| 405 |
+
Yielding data with length 13811
|
| 406 |
+
Yielding data with length 19365
|
| 407 |
+
Yielding data with length 19566
|
| 408 |
+
Yielding data with length 24227
|
| 409 |
+
Yielding data with length 28214
|
| 410 |
+
Yielding data with length 30026
|
| 411 |
+
Yielding data with length 18195
|
| 412 |
+
Yielding data with length 18206
|
| 413 |
+
Yielding data with length 19699
|
| 414 |
+
Yielding data with length 23103
|
| 415 |
+
Yielding data with length 33474
|
| 416 |
+
Yielding data with length 29109
|
| 417 |
+
Yielding data with length 36518
|
| 418 |
+
Yielding data with length 27659
|
| 419 |
+
Yielding data with length 21031
|
| 420 |
+
Yielding data with length 27532
|
| 421 |
+
Yielding data with length 21080
|
| 422 |
+
Yielding data with length 20740
|
| 423 |
+
Yielding data with length 24066
|
| 424 |
+
Yielding data with length 26959
|
| 425 |
+
Yielding data with length 32162
|
| 426 |
+
skip a sample with length 42480
|
| 427 |
+
Yielding data with length 31373
|
| 428 |
+
block_dataset repeat in rank-0 worker-0
|
| 429 |
+
Yielding data with length 9629
|
| 430 |
+
block_dataset repeat in rank-6 worker-0
|
| 431 |
+
Yielding data with length 12734
|
| 432 |
+
Yielding data with length 20622
|
| 433 |
+
Yielding data with length 31650
|
| 434 |
+
Yielding data with length 23291
|
| 435 |
+
Yielding data with length 25245
|
| 436 |
+
Yielding data with length 27515
|
| 437 |
+
Yielding data with length 28296
|
| 438 |
+
Yielding data with length 20698
|
| 439 |
+
Yielding data with length 21726
|
| 440 |
+
skip a sample with length 43202
|
| 441 |
+
Yielding data with length 21768
|
| 442 |
+
Yielding data with length 18011
|
| 443 |
+
Yielding data with length 23070
|
| 444 |
+
Yielding data with length 19691
|
| 445 |
+
Yielding data with length 25171
|
| 446 |
+
Yielding data with length 33860
|
| 447 |
+
block_dataset repeat in rank-5 worker-0
|
| 448 |
+
Yielding data with length 8964
|
| 449 |
+
skip a sample with length 43202
|
| 450 |
+
block_dataset repeat in rank-3 worker-0
|
| 451 |
+
Yielding data with length 19248
|
| 452 |
+
Yielding data with length 16262
|
| 453 |
+
Yielding data with length 29186
|
| 454 |
+
skip a sample with length 41106
|
| 455 |
+
Yielding data with length 19245
|
| 456 |
+
Yielding data with length 24191
|
| 457 |
+
Yielding data with length 23133
|
| 458 |
+
Yielding data with length 35614
|
| 459 |
+
block_dataset repeat in rank-2 worker-0
|
| 460 |
+
Yielding data with length 13769
|
| 461 |
+
Yielding data with length 24400
|
| 462 |
+
Yielding data with length 31113
|
| 463 |
+
Yielding data with length 25652
|
| 464 |
+
Yielding data with length 25500
|
| 465 |
+
Yielding data with length 26979
|
| 466 |
+
block_dataset repeat in rank-7 worker-0
|
| 467 |
+
Yielding data with length 24263
|
| 468 |
+
Yielding data with length 27393
|
| 469 |
+
skip a sample with length 41094
|
| 470 |
+
Yielding data with length 23188
|
| 471 |
+
Yielding data with length 19658
|
| 472 |
+
block_dataset repeat in rank-1 worker-0
|
| 473 |
+
Yielding data with length 24787
|
| 474 |
+
Yielding data with length 26221
|
| 475 |
+
Yielding data with length 21409
|
| 476 |
+
Yielding data with length 32059
|
| 477 |
+
skip a sample with length 43245
|
| 478 |
+
Yielding data with length 26058
|
| 479 |
+
Yielding data with length 24507
|
| 480 |
+
Yielding data with length 8292
|
| 481 |
+
skip a sample with length 42480
|
| 482 |
+
Yielding data with length 12746
|
| 483 |
+
Yielding data with length 17288
|
| 484 |
+
Yielding data with length 20793
|
| 485 |
+
Yielding data with length 17252
|
| 486 |
+
Yielding data with length 25240
|
| 487 |
+
Yielding data with length 25304
|
| 488 |
+
Yielding data with length 30376
|
| 489 |
+
block_dataset repeat in rank-4 worker-0
|
| 490 |
+
Yielding data with length 14138
|
| 491 |
+
skip a sample with length 48060
|
| 492 |
+
skip a sample with length 41160
|
| 493 |
+
skip a sample with length 48060
|
| 494 |
+
Yielding data with length 19684
|
| 495 |
+
Yielding data with length 14748
|
| 496 |
+
Yielding data with length 21158
|
| 497 |
+
Yielding data with length 21425
|
| 498 |
+
Yielding data with length 30781
|
| 499 |
+
Yielding data with length 33027
|
| 500 |
+
Yielding data with length 33537
|
| 501 |
+
Yielding data with length 14837
|
| 502 |
+
Yielding data with length 12766
|
| 503 |
+
Yielding data with length 14115
|
| 504 |
+
Yielding data with length 15474
|
| 505 |
+
Yielding data with length 21749
|
| 506 |
+
Yielding data with length 33147
|
| 507 |
+
Yielding data with length 25621
|
| 508 |
+
skip a sample with length 57756
|
| 509 |
+
Yielding data with length 22466
|
| 510 |
+
skip a sample with length 41094
|
| 511 |
+
Yielding data with length 22413
|
| 512 |
+
skip a sample with length 50787
|
| 513 |
+
Yielding data with length 27913
|
| 514 |
+
Yielding data with length 25090
|
| 515 |
+
Yielding data with length 25551
|
| 516 |
+
Yielding data with length 25335
|
| 517 |
+
skip a sample with length 57756
|
| 518 |
+
skip a sample with length 43245
|
| 519 |
+
Yielding data with length 25947
|
| 520 |
+
skip a sample with length 41160
|
| 521 |
+
Yielding data with length 31872
|
| 522 |
+
Yielding data with length 36109
|
| 523 |
+
skip a sample with length 44611
|
| 524 |
+
Yielding data with length 16234
|
| 525 |
+
Yielding data with length 19945
|
| 526 |
+
Yielding data with length 19685
|
| 527 |
+
Yielding data with length 34186
|
| 528 |
+
Yielding data with length 36943
|
| 529 |
+
Yielding data with length 23090
|
| 530 |
+
Yielding data with length 29034
|
| 531 |
+
Yielding data with length 30067
|
| 532 |
+
Yielding data with length 8489
|
| 533 |
+
skip a sample with length 41106
|
| 534 |
+
skip a sample with length 44611
|
| 535 |
+
skip a sample with length 50787
|
| 536 |
+
Yielding data with length 10008
|
| 537 |
+
Yielding data with length 32829
|
| 538 |
+
Yielding data with length 23593
|
| 539 |
+
Yielding data with length 29907
|
| 540 |
+
skip a sample with length 42486
|
| 541 |
+
Yielding data with length 25500
|
| 542 |
+
Yielding data with length 34717
|
| 543 |
+
Yielding data with length 29714
|
| 544 |
+
Yielding data with length 16266
|
| 545 |
+
Yielding data with length 17271
|
| 546 |
+
Yielding data with length 20547
|
| 547 |
+
Yielding data with length 22351
|
| 548 |
+
Yielding data with length 26637
|
| 549 |
+
Yielding data with length 32390
|
| 550 |
+
Yielding data with length 30503
|
| 551 |
+
Yielding data with length 29728
|
| 552 |
+
skip a sample with length 42486
|
| 553 |
+
Yielding data with length 21561
|
| 554 |
+
Yielding data with length 16923
|
| 555 |
+
Yielding data with length 19642
|
| 556 |
+
Yielding data with length 20198
|
| 557 |
+
Yielding data with length 22735
|
| 558 |
+
Yielding data with length 32930
|
| 559 |
+
Yielding data with length 24262
|
| 560 |
+
Yielding data with length 34823
|
| 561 |
+
Yielding data with length 28608
|
| 562 |
+
Yielding data with length 28122
|
| 563 |
+
Yielding data with length 24532
|
| 564 |
+
Yielding data with length 26210
|
| 565 |
+
Yielding data with length 36308
|
| 566 |
+
Yielding data with length 27414
|
| 567 |
+
Yielding data with length 30425
|
| 568 |
+
Yielding data with length 30774
|
| 569 |
+
block_dataset repeat in rank-6 worker-0
|
| 570 |
+
block_dataset repeat in rank-0 worker-0
|
| 571 |
+
skip a sample with length 42480
|
| 572 |
+
Yielding data with length 15870
|
| 573 |
+
Yielding data with length 15590
|
| 574 |
+
Yielding data with length 18509
|
| 575 |
+
Yielding data with length 23812
|
| 576 |
+
Yielding data with length 18170
|
| 577 |
+
Yielding data with length 32514
|
| 578 |
+
Yielding data with length 24814
|
| 579 |
+
Yielding data with length 28298
|
| 580 |
+
Yielding data with length 9988
|
| 581 |
+
Yielding data with length 18332
|
| 582 |
+
Yielding data with length 21420
|
| 583 |
+
Yielding data with length 23903
|
| 584 |
+
Yielding data with length 25120
|
| 585 |
+
Yielding data with length 28991
|
| 586 |
+
Yielding data with length 30114
|
| 587 |
+
Yielding data with length 30128
|
| 588 |
+
skip a sample with length 43202
|
| 589 |
+
skip a sample with length 43202
|
| 590 |
+
Yielding data with length 17850
|
| 591 |
+
Yielding data with length 18166
|
| 592 |
+
Yielding data with length 22663
|
| 593 |
+
Yielding data with length 20751
|
| 594 |
+
Yielding data with length 19273
|
| 595 |
+
Yielding data with length 17552
|
| 596 |
+
Yielding data with length 26616
|
| 597 |
+
Yielding data with length 28527
|
| 598 |
+
block_dataset repeat in rank-3 worker-0
|
| 599 |
+
block_dataset repeat in rank-5 worker-0
|
| 600 |
+
Yielding data with length 13466
|
| 601 |
+
Yielding data with length 14852
|
| 602 |
+
block_dataset repeat in rank-7 worker-0
|
| 603 |
+
Yielding data with length 20760
|
| 604 |
+
Yielding data with length 22448
|
| 605 |
+
Yielding data with length 20269
|
| 606 |
+
Yielding data with length 27307
|
| 607 |
+
Yielding data with length 31128
|
| 608 |
+
Yielding data with length 23848
|
| 609 |
+
block_dataset repeat in rank-2 worker-0
|
| 610 |
+
Yielding data with length 8948
|
| 611 |
+
skip a sample with length 41106
|
| 612 |
+
Yielding data with length 10367
|
| 613 |
+
Yielding data with length 12612
|
| 614 |
+
Yielding data with length 18632
|
| 615 |
+
Yielding data with length 32428
|
| 616 |
+
Yielding data with length 25651
|
| 617 |
+
Yielding data with length 22117
|
| 618 |
+
Yielding data with length 30468
|
| 619 |
+
Yielding data with length 12051
|
| 620 |
+
Yielding data with length 13346
|
| 621 |
+
Yielding data with length 15726
|
| 622 |
+
Yielding data with length 11383
|
| 623 |
+
skip a sample with length 41094
|
| 624 |
+
Yielding data with length 19358
|
| 625 |
+
Yielding data with length 31964
|
| 626 |
+
skip a sample with length 43245
|
| 627 |
+
Yielding data with length 34359
|
| 628 |
+
Yielding data with length 25146
|
| 629 |
+
block_dataset repeat in rank-1 worker-0
|
| 630 |
+
Yielding data with length 12528
|
| 631 |
+
Yielding data with length 14445
|
| 632 |
+
Yielding data with length 21808
|
| 633 |
+
skip a sample with length 48060
|
| 634 |
+
Yielding data with length 24973
|
| 635 |
+
Yielding data with length 24141
|
| 636 |
+
Yielding data with length 35965
|
| 637 |
+
Yielding data with length 29665
|
| 638 |
+
Yielding data with length 28975
|
| 639 |
+
skip a sample with length 42480
|
| 640 |
+
skip a sample with length 48060
|
| 641 |
+
Yielding data with length 15182
|
| 642 |
+
Yielding data with length 19712
|
| 643 |
+
skip a sample with length 41160
|
| 644 |
+
Yielding data with length 19698
|
| 645 |
+
Yielding data with length 18255
|
| 646 |
+
Yielding data with length 30749
|
| 647 |
+
Yielding data with length 34841
|
| 648 |
+
Yielding data with length 22848
|
| 649 |
+
Yielding data with length 28618
|
| 650 |
+
block_dataset repeat in rank-4 worker-0
|
| 651 |
+
Yielding data with length 12071
|
| 652 |
+
Yielding data with length 15527
|
| 653 |
+
Yielding data with length 19227
|
| 654 |
+
Yielding data with length 19199
|
| 655 |
+
skip a sample with length 50787
|
| 656 |
+
Yielding data with length 25207
|
| 657 |
+
Yielding data with length 26500
|
| 658 |
+
skip a sample with length 57756
|
| 659 |
+
Yielding data with length 25915
|
| 660 |
+
skip a sample with length 57756
|
| 661 |
+
Yielding data with length 29886
|
| 662 |
+
skip a sample with length 43245
|
| 663 |
+
skip a sample with length 41160
|
| 664 |
+
skip a sample with length 41094
|
| 665 |
+
Yielding data with length 18659
|
| 666 |
+
Yielding data with length 23460
|
| 667 |
+
Yielding data with length 29942
|
| 668 |
+
Yielding data with length 30289
|
| 669 |
+
Yielding data with length 27297
|
| 670 |
+
Yielding data with length 28034
|
| 671 |
+
Yielding data with length 29025
|
| 672 |
+
Yielding data with length 36590
|
| 673 |
+
skip a sample with length 44611
|
| 674 |
+
skip a sample with length 50787
|
| 675 |
+
skip a sample with length 44611
|
| 676 |
+
Yielding data with length 24792
|
| 677 |
+
Yielding data with length 20748
|
| 678 |
+
Yielding data with length 23187
|
| 679 |
+
Yielding data with length 19037
|
| 680 |
+
Yielding data with length 31561
|
| 681 |
+
Yielding data with length 34200
|
| 682 |
+
Yielding data with length 26330
|
| 683 |
+
Yielding data with length 30027
|
| 684 |
+
skip a sample with length 41106
|
| 685 |
+
Yielding data with length 12095
|
| 686 |
+
Yielding data with length 15214
|
| 687 |
+
Yielding data with length 17243
|
| 688 |
+
Yielding data with length 23097
|
| 689 |
+
Yielding data with length 24142
|
| 690 |
+
Yielding data with length 28934
|
| 691 |
+
Yielding data with length 29052
|
| 692 |
+
skip a sample with length 42486
|
| 693 |
+
Yielding data with length 34556
|
| 694 |
+
Yielding data with length 14895
|
| 695 |
+
Yielding data with length 19552
|
| 696 |
+
Yielding data with length 22053
|
| 697 |
+
Yielding data with length 29467
|
| 698 |
+
Yielding data with length 23444
|
| 699 |
+
Yielding data with length 26636
|
| 700 |
+
Yielding data with length 33801
|
| 701 |
+
Yielding data with length 34191
|
| 702 |
+
skip a sample with length 42486
|
| 703 |
+
Yielding data with length 20414
|
| 704 |
+
Yielding data with length 21739
|
| 705 |
+
Yielding data with length 23877
|
| 706 |
+
Yielding data with length 26520
|
| 707 |
+
Yielding data with length 24877
|
| 708 |
+
Yielding data with length 27696
|
| 709 |
+
Yielding data with length 27597
|
| 710 |
+
Yielding data with length 32703
|
| 711 |
+
block_dataset repeat in rank-0 worker-0
|
| 712 |
+
block_dataset repeat in rank-6 worker-0
|
| 713 |
+
Yielding data with length 18005
|
| 714 |
+
Yielding data with length 26527
|
| 715 |
+
Yielding data with length 20791
|
| 716 |
+
Yielding data with length 20719
|
| 717 |
+
Yielding data with length 22114
|
| 718 |
+
Yielding data with length 22512
|
| 719 |
+
Yielding data with length 29336
|
| 720 |
+
Yielding data with length 31527
|
| 721 |
+
Yielding data with length 9284
|
| 722 |
+
skip a sample with length 42480
|
| 723 |
+
Yielding data with length 17316
|
| 724 |
+
Yielding data with length 19314
|
| 725 |
+
Yielding data with length 25239
|
| 726 |
+
Yielding data with length 19703
|
| 727 |
+
Yielding data with length 21232
|
| 728 |
+
Yielding data with length 17268
|
| 729 |
+
Yielding data with length 26931
|
| 730 |
+
skip a sample with length 43202
|
| 731 |
+
Yielding data with length 16230
|
| 732 |
+
Yielding data with length 19692
|
| 733 |
+
Yielding data with length 23196
|
| 734 |
+
Yielding data with length 22444
|
| 735 |
+
Yielding data with length 29708
|
| 736 |
+
Yielding data with length 20680
|
| 737 |
+
Yielding data with length 30765
|
| 738 |
+
Yielding data with length 27917
|
| 739 |
+
skip a sample with length 43202
|
| 740 |
+
Yielding data with length 17207
|
| 741 |
+
Yielding data with length 17853
|
| 742 |
+
Yielding data with length 23427
|
| 743 |
+
block_dataset repeat in rank-5 worker-0
|
| 744 |
+
Yielding data with length 27646
|
| 745 |
+
Yielding data with length 25169
|
| 746 |
+
Yielding data with length 26475
|
| 747 |
+
Yielding data with length 25127
|
| 748 |
+
Yielding data with length 27339
|
| 749 |
+
block_dataset repeat in rank-3 worker-0
|
| 750 |
+
block_dataset repeat in rank-7 worker-0
|
| 751 |
+
Yielding data with length 16546
|
| 752 |
+
Yielding data with length 16256
|
| 753 |
+
Yielding data with length 22339
|
| 754 |
+
Yielding data with length 17919
|
| 755 |
+
Yielding data with length 23138
|
| 756 |
+
Yielding data with length 19676
|
| 757 |
+
Yielding data with length 24070
|
| 758 |
+
Yielding data with length 25924
|
| 759 |
+
block_dataset repeat in rank-2 worker-0
|
| 760 |
+
Yielding data with length 14569
|
| 761 |
+
Yielding data with length 31705
|
| 762 |
+
Yielding data with length 24120
|
| 763 |
+
Yielding data with length 33709
|
| 764 |
+
Yielding data with length 26245
|
| 765 |
+
Yielding data with length 39397
|
| 766 |
+
Yielding data with length 31035
|
| 767 |
+
Yielding data with length 15921
|
| 768 |
+
skip a sample with length 41094
|
| 769 |
+
skip a sample with length 41106
|
| 770 |
+
Yielding data with length 11323
|
| 771 |
+
Yielding data with length 20758
|
| 772 |
+
Yielding data with length 24109
|
| 773 |
+
skip a sample with length 43245
|
| 774 |
+
skip a sample with length 48060
|
| 775 |
+
Yielding data with length 21739
|
| 776 |
+
Yielding data with length 22062
|
| 777 |
+
Yielding data with length 11069
|
| 778 |
+
Yielding data with length 33774
|
| 779 |
+
Yielding data with length 24783
|
| 780 |
+
Yielding data with length 13348
|
| 781 |
+
Yielding data with length 13218
|
| 782 |
+
Yielding data with length 17288
|
| 783 |
+
Yielding data with length 26493
|
| 784 |
+
Yielding data with length 24246
|
| 785 |
+
Yielding data with length 26920
|
| 786 |
+
Yielding data with length 28599
|
| 787 |
+
Yielding data with length 31042
|
| 788 |
+
block_dataset repeat in rank-1 worker-0
|
| 789 |
+
skip a sample with length 48060
|
| 790 |
+
skip a sample with length 41160
|
| 791 |
+
Yielding data with length 25722
|
| 792 |
+
Yielding data with length 33186
|
| 793 |
+
skip a sample with length 50787
|
| 794 |
+
Yielding data with length 19367
|
| 795 |
+
Yielding data with length 26598
|
| 796 |
+
Yielding data with length 18672
|
| 797 |
+
Yielding data with length 27291
|
| 798 |
+
Yielding data with length 33105
|
| 799 |
+
skip a sample with length 57756
|
| 800 |
+
Yielding data with length 31380
|
| 801 |
+
skip a sample with length 43245
|
| 802 |
+
skip a sample with length 42480
|
| 803 |
+
skip a sample with length 41160
|
| 804 |
+
Yielding data with length 22996
|
| 805 |
+
Yielding data with length 18896
|
| 806 |
+
Yielding data with length 19621
|
| 807 |
+
Yielding data with length 24453
|
| 808 |
+
Yielding data with length 37227
|
| 809 |
+
Yielding data with length 28758
|
| 810 |
+
skip a sample with length 57756
|
| 811 |
+
Yielding data with length 31736
|
| 812 |
+
Yielding data with length 26241
|
| 813 |
+
block_dataset repeat in rank-4 worker-0
|
| 814 |
+
skip a sample with length 44611
|
| 815 |
+
skip a sample with length 50787
|
| 816 |
+
Yielding data with length 8502
|
| 817 |
+
skip a sample with length 41094
|
| 818 |
+
Yielding data with length 23339
|
| 819 |
+
Yielding data with length 26828
|
| 820 |
+
Yielding data with length 22141
|
| 821 |
+
Yielding data with length 27917
|
| 822 |
+
Yielding data with length 30731
|
| 823 |
+
Yielding data with length 35152
|
| 824 |
+
Yielding data with length 32504
|
| 825 |
+
Yielding data with length 14524
|
| 826 |
+
Yielding data with length 21770
|
| 827 |
+
Yielding data with length 23021
|
| 828 |
+
Yielding data with length 31645
|
| 829 |
+
Yielding data with length 34056
|
| 830 |
+
skip a sample with length 44611
|
| 831 |
+
Yielding data with length 24506
|
| 832 |
+
Yielding data with length 27457
|
| 833 |
+
Yielding data with length 28513
|
| 834 |
+
Yielding data with length 15147
|
| 835 |
+
skip a sample with length 41106
|
| 836 |
+
Yielding data with length 16968
|
| 837 |
+
Yielding data with length 13491
|
| 838 |
+
Yielding data with length 22125
|
| 839 |
+
Yielding data with length 21138
|
| 840 |
+
Yielding data with length 24903
|
| 841 |
+
Yielding data with length 28043
|
| 842 |
+
skip a sample with length 42486
|
| 843 |
+
Yielding data with length 31782
|
| 844 |
+
Yielding data with length 6701
|
| 845 |
+
Yielding data with length 13494
|
| 846 |
+
Yielding data with length 15875
|
| 847 |
+
Yielding data with length 17545
|
| 848 |
+
Yielding data with length 21060
|
| 849 |
+
Yielding data with length 22115
|
| 850 |
+
Yielding data with length 29729
|
| 851 |
+
Yielding data with length 31752
|
| 852 |
+
skip a sample with length 42486
|
| 853 |
+
Yielding data with length 17316
|
| 854 |
+
block_dataset repeat in rank-0 worker-0
|
| 855 |
+
Yielding data with length 24478
|
| 856 |
+
Yielding data with length 24714
|
| 857 |
+
skip a sample with length 42480
|
| 858 |
+
Yielding data with length 24145
|
| 859 |
+
Yielding data with length 25188
|
| 860 |
+
Yielding data with length 21724
|
| 861 |
+
block_dataset repeat in rank-6 worker-0
|
| 862 |
+
Yielding data with length 28652
|
| 863 |
+
Yielding data with length 31606
|
| 864 |
+
Yielding data with length 8619
|
| 865 |
+
Yielding data with length 16608
|
| 866 |
+
Yielding data with length 21134
|
| 867 |
+
Yielding data with length 28671
|
| 868 |
+
Yielding data with length 24139
|
| 869 |
+
Yielding data with length 34737
|
| 870 |
+
Yielding data with length 28959
|
| 871 |
+
Yielding data with length 30967
|
scripts/eval/eval_vlm.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# Check if enough arguments are provided
|
| 5 |
+
if [ $# -lt 2 ]; then
|
| 6 |
+
echo "Error: PREFIX_DIR and MODEL_PATH are required as the first and second arguments respectively."
|
| 7 |
+
exit 1
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
LOG_PATH=$1
|
| 11 |
+
if [ ! -d "$LOG_PATH" ]; then
|
| 12 |
+
mkdir -p "$LOG_PATH"
|
| 13 |
+
fi
|
| 14 |
+
shift 1
|
| 15 |
+
ARGS=("$@")
|
| 16 |
+
export MASTER_PORT=10042
|
| 17 |
+
|
| 18 |
+
FULL_MODEL_PATH="$PREFIX_DIR/$MODEL_PATH"
|
| 19 |
+
|
| 20 |
+
IFS=' ' read -r -a DATASETS <<< "$DATASETS_STR"
|
| 21 |
+
|
| 22 |
+
for DATASET in "${DATASETS[@]}"; do
|
| 23 |
+
bash eval/vlm/evaluate.sh \
|
| 24 |
+
"$DATASET" \
|
| 25 |
+
--out-dir "$LOG_PATH/$DATASET" \
|
| 26 |
+
"${ARGS[@]}"
|
| 27 |
+
done
|
scripts/eval/run_eval_vlm.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
# Set proxy and API key
|
| 7 |
+
export OPENAI_API_KEY=$openai_api_key
|
| 8 |
+
|
| 9 |
+
export GPUS=1
|
| 10 |
+
|
| 11 |
+
DATASETS=("mme" "mmbench-dev-en" "mmvet" "mmmu-val" "mathvista-testmini" "mmvp")
|
| 12 |
+
# DATASETS=("mmmu-val_cot")
|
| 13 |
+
|
| 14 |
+
DATASETS_STR="${DATASETS[*]}"
|
| 15 |
+
export DATASETS_STR
|
| 16 |
+
|
| 17 |
+
bash scripts/eval/eval_vlm.sh \
|
| 18 |
+
$output_path \
|
| 19 |
+
--model-path $model_path
|
scripts/eval/run_gedit.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# run this script at the root of the project folder
|
| 5 |
+
pip install httpx==0.23.0
|
| 6 |
+
pip install openai==1.87.0
|
| 7 |
+
pip install datasets
|
| 8 |
+
pip install megfile
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
N_GPU=8 # Number of GPU used in for the evaluation
|
| 12 |
+
MODEL_PATH="/Path/to/BAGEL-7B-MoT"
|
| 13 |
+
OUTPUT_DIR="/Path/to/save/results"
|
| 14 |
+
GEN_DIR="$OUTPUT_DIR/gen_image"
|
| 15 |
+
LOG_DIR="$OUTPUT_DIR/logs"
|
| 16 |
+
|
| 17 |
+
AZURE_ENDPOINT="https://azure_endpoint_url_you_use" # set up the azure openai endpoint url
|
| 18 |
+
AZURE_OPENAI_KEY="" # set up the azure openai key
|
| 19 |
+
N_GPT_PARALLEL=10
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
mkdir -p "$OUTPUT_DIR"
|
| 23 |
+
mkdir -p "$GEN_DIR"
|
| 24 |
+
mkdir -p "$LOG_DIR"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# # ----------------------------
|
| 28 |
+
# # Download GEdit Dataset
|
| 29 |
+
# # ----------------------------
|
| 30 |
+
python -c "from datasets import load_dataset; dataset = load_dataset('stepfun-ai/GEdit-Bench')"
|
| 31 |
+
echo "Dataset Downloaded"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# # ---------------------
|
| 35 |
+
# # Generate Images
|
| 36 |
+
# # ---------------------
|
| 37 |
+
for ((i=0; i<$N_GPU; i++)); do
|
| 38 |
+
nohup python3 eval/gen/gedit/gen_images_gedit.py --model_path "$MODEL_PATH" --output_dir "$GEN_DIR" --shard_id $i --total_shards "$N_GPU" --device $i 2>&1 | tee "$LOG_DIR"/request_$(($N_GPU + i)).log &
|
| 39 |
+
done
|
| 40 |
+
|
| 41 |
+
wait
|
| 42 |
+
echo "Image Generation Done"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# # ---------------------
|
| 46 |
+
# # GPT Evaluation
|
| 47 |
+
# # ---------------------
|
| 48 |
+
cd eval/gen/gedit
|
| 49 |
+
python test_gedit_score.py --save_path "$OUTPUT_DIR" --azure_endpoint "$AZURE_ENDPOINT" --gpt_keys "$AZURE_OPENAI_KEY" --max_workers "$N_GPT_PARALLEL"
|
| 50 |
+
echo "Evaluation Done"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# # --------------------
|
| 54 |
+
# # Print Results
|
| 55 |
+
# # --------------------
|
| 56 |
+
python calculate_statistics.py --save_path "$OUTPUT_DIR" --language en
|
| 57 |
+
|
scripts/eval/run_geneval.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
GPUS=8
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# generate images
|
| 10 |
+
torchrun \
|
| 11 |
+
--nnodes=1 \
|
| 12 |
+
--node_rank=0 \
|
| 13 |
+
--nproc_per_node=$GPUS \
|
| 14 |
+
--master_addr=127.0.0.1 \
|
| 15 |
+
--master_port=12345 \
|
| 16 |
+
./eval/gen/gen_images_mp.py \
|
| 17 |
+
--output_dir $output_path/images \
|
| 18 |
+
--metadata_file ./eval/gen/geneval/prompts/evaluation_metadata_long.jsonl \
|
| 19 |
+
--batch_size 1 \
|
| 20 |
+
--num_images 4 \
|
| 21 |
+
--resolution 1024 \
|
| 22 |
+
--max_latent_size 64 \
|
| 23 |
+
--model-path $model_path \
|
| 24 |
+
# --metadata_file ./eval/gen/geneval/prompts/evaluation_metadata.jsonl \
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# calculate score
|
| 28 |
+
torchrun \
|
| 29 |
+
--nnodes=1 \
|
| 30 |
+
--node_rank=0 \
|
| 31 |
+
--nproc_per_node=$GPUS \
|
| 32 |
+
--master_addr=127.0.0.1 \
|
| 33 |
+
--master_port=12345 \
|
| 34 |
+
./eval/gen/geneval/evaluation/evaluate_images_mp.py \
|
| 35 |
+
$output_path/images \
|
| 36 |
+
--outfile $output_path/results.jsonl \
|
| 37 |
+
--model-path ./eval/gen/geneval/model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# summarize score
|
| 41 |
+
python ./eval/gen/geneval/evaluation/summary_scores.py $output_path/results.jsonl
|
scripts/eval/run_imgedit.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
export OPENAI_API_KEY=$openai_api_key
|
| 7 |
+
|
| 8 |
+
GPUS=8
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# generate images
|
| 12 |
+
torchrun \
|
| 13 |
+
--nnodes=1 \
|
| 14 |
+
--node_rank=0 \
|
| 15 |
+
--nproc_per_node=$GPUS \
|
| 16 |
+
--master_addr=127.0.0.1 \
|
| 17 |
+
--master_port=12345 \
|
| 18 |
+
./eval/gen/gen_images_mp_imgedit.py \
|
| 19 |
+
--output_dir $output_path/bagel \
|
| 20 |
+
--metadata_file ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
|
| 21 |
+
--max_latent_size 64 \
|
| 22 |
+
--model-path $model_path
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# calculate score
|
| 26 |
+
python ./eval/gen/imgedit/basic_bench.py \
|
| 27 |
+
--result_img_folder $output_path/bagel \
|
| 28 |
+
--edit_json ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
|
| 29 |
+
--origin_img_root ./eval/gen/imgedit/Benchmark/singleturn \
|
| 30 |
+
--num_processes 4 \
|
| 31 |
+
--prompts_json ./eval/gen/imgedit/Benchmark/singleturn/judge_prompt.json
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# summarize score
|
| 35 |
+
python ./eval/gen/imgedit/step1_get_avgscore.py \
|
| 36 |
+
--result_json $output_path/bagel/result.json \
|
| 37 |
+
--average_score_json $output_path/bagel/average_score.json
|
| 38 |
+
|
| 39 |
+
python ./eval/gen/imgedit/step2_typescore.py \
|
| 40 |
+
--average_score_json $output_path/bagel/average_score.json \
|
| 41 |
+
--edit_json ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
|
| 42 |
+
--typescore_json $output_path/bagel/typescore.json
|
scripts/eval/run_kris.sh
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
export OPENAI_API_KEY=$openai_api_key
|
| 7 |
+
|
| 8 |
+
GPUS=8
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# generate images
|
| 12 |
+
torchrun \
|
| 13 |
+
--nnodes=1 \
|
| 14 |
+
--node_rank=0 \
|
| 15 |
+
--nproc_per_node=$GPUS \
|
| 16 |
+
--master_addr=127.0.0.1 \
|
| 17 |
+
--master_port=12345 \
|
| 18 |
+
./eval/gen/gen_images_mp_kris.py \
|
| 19 |
+
--output_dir $output_path/bagel \
|
| 20 |
+
--metadata_file ./eval/gen/kris/final_data.json \
|
| 21 |
+
--max_latent_size 64 \
|
| 22 |
+
--model-path $model_path \
|
| 23 |
+
--think
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# calculate score
|
| 27 |
+
python ./eval/gen/kris/metrics_common.py \
|
| 28 |
+
--results_dir $output_path \
|
| 29 |
+
--max_workers 8
|
| 30 |
+
|
| 31 |
+
python ./eval/gen/kris/metrics_knowledge.py \
|
| 32 |
+
--results_dir $output_path \
|
| 33 |
+
--max_workers 8
|
| 34 |
+
|
| 35 |
+
python ./eval/gen/kris/metrics_multi_element.py \
|
| 36 |
+
--results_dir $output_path \
|
| 37 |
+
--max_workers 8
|
| 38 |
+
|
| 39 |
+
python ./eval/gen/kris/metrics_temporal_prediction.py \
|
| 40 |
+
--results_dir $output_path \
|
| 41 |
+
--max_workers 8
|
| 42 |
+
|
| 43 |
+
python ./eval/gen/kris/metrics_view_change.py \
|
| 44 |
+
--results_dir $output_path \
|
| 45 |
+
--max_workers 8
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# summarize score
|
| 49 |
+
python ./eval/gen/kris/summarize.py \
|
| 50 |
+
--results_dir $output_path/bagel \
|
scripts/eval/run_rise.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
export OPENAI_API_KEY=$openai_api_key
|
| 7 |
+
|
| 8 |
+
GPUS=8
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# generate images
|
| 12 |
+
torchrun \
|
| 13 |
+
--nnodes=1 \
|
| 14 |
+
--node_rank=0 \
|
| 15 |
+
--nproc_per_node=$GPUS \
|
| 16 |
+
--master_addr=127.0.0.1 \
|
| 17 |
+
--master_port=12345 \
|
| 18 |
+
./eval/gen/gen_images_mp_rise.py \
|
| 19 |
+
--output_dir $output_path/bagel \
|
| 20 |
+
--metadata_file ./eval/gen/rise/data/datav2_total_w_subtask.json \
|
| 21 |
+
--max_latent_size 64 \
|
| 22 |
+
--model-path $model_path \
|
| 23 |
+
--think
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# calculate score
|
| 27 |
+
python ./eval/gen/rise/gpt_eval.py \
|
| 28 |
+
--data ./eval/gen/rise/data/datav2_total_w_subtask.json \
|
| 29 |
+
--input ./eval/gen/rise/data \
|
| 30 |
+
--output $output_path/bagel
|
scripts/eval/run_wise.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
set -x
|
| 5 |
+
|
| 6 |
+
export OPENAI_API_KEY=$openai_api_key
|
| 7 |
+
|
| 8 |
+
GPUS=8
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# generate images
|
| 12 |
+
torchrun \
|
| 13 |
+
--nnodes=1 \
|
| 14 |
+
--node_rank=0 \
|
| 15 |
+
--nproc_per_node=$GPUS \
|
| 16 |
+
--master_addr=127.0.0.1 \
|
| 17 |
+
--master_port=12345 \
|
| 18 |
+
./eval/gen/gen_images_mp_wise.py \
|
| 19 |
+
--output_dir $output_path/images \
|
| 20 |
+
--metadata-file ./eval/gen/wise/final_data.json \
|
| 21 |
+
--resolution 1024 \
|
| 22 |
+
--max-latent_size 64 \
|
| 23 |
+
--model-path $model_path \
|
| 24 |
+
--think
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# calculate score
|
| 28 |
+
python3 eval/gen/wise/gpt_eval_mp.py \
|
| 29 |
+
--json_path eval/gen/wise/data/cultural_common_sense.json \
|
| 30 |
+
--image_dir $output_path/images \
|
| 31 |
+
--output_dir $output_path
|
| 32 |
+
|
| 33 |
+
python3 eval/gen/wise/gpt_eval_mp.py \
|
| 34 |
+
--json_path eval/gen/wise/data/spatio-temporal_reasoning.json \
|
| 35 |
+
--image_dir $output_path/images \
|
| 36 |
+
--output_dir $output_path
|
| 37 |
+
|
| 38 |
+
python3 eval/gen/wise/gpt_eval_mp.py \
|
| 39 |
+
--json_path eval/gen/wise/data/natural_science.json \
|
| 40 |
+
--image_dir $output_path/images \
|
| 41 |
+
--output_dir $output_path
|
| 42 |
+
|
| 43 |
+
python3 eval/gen/wise/cal_score.py \
|
| 44 |
+
--output_dir $output_path
|
scripts/train.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
export HF_HOME=/dev/shm/
|
| 5 |
+
NUM_NODES=1
|
| 6 |
+
NODE_RANK=0
|
| 7 |
+
MASTER_ADDR=localhost
|
| 8 |
+
MASTER_PORT=29500
|
| 9 |
+
NPROC_PER_NODE=8
|
| 10 |
+
MODEL_PATH=/dev/shm/models/BAGEL-7B-MoT
|
| 11 |
+
|
| 12 |
+
# replace the variables with your own
|
| 13 |
+
torchrun \
|
| 14 |
+
--nnodes=$NUM_NODES \
|
| 15 |
+
--node_rank=$NODE_RANK \
|
| 16 |
+
--nproc_per_node=$NPROC_PER_NODE \
|
| 17 |
+
--master_addr=$MASTER_ADDR \
|
| 18 |
+
--master_port=$MASTER_PORT \
|
| 19 |
+
train/pretrain_unified_navit.py \
|
| 20 |
+
--dataset_config_file ./data/configs/example.yaml \
|
| 21 |
+
--model_path $MODEL_PATH \
|
| 22 |
+
--layer_module Qwen2MoTDecoderLayer \
|
| 23 |
+
--max_latent_size 64 \
|
| 24 |
+
--resume-from $MODEL_PATH \
|
| 25 |
+
--finetune_from_hf True \
|
| 26 |
+
--auto_resume True \
|
| 27 |
+
--resume-model-only True \
|
| 28 |
+
--finetune-from-ema True \
|
| 29 |
+
--log_every 1 \
|
| 30 |
+
--lr 2e-5 \
|
| 31 |
+
--lr_scheduler cosine \
|
| 32 |
+
--min_lr 1e-6 \
|
| 33 |
+
--num_worker 1 \
|
| 34 |
+
--expected_num_tokens 60000 \
|
| 35 |
+
--max_num_tokens 60000 \
|
| 36 |
+
--max_num_tokens_per_sample 60000 \
|
| 37 |
+
--prefer_buffer_before 30000 \
|
| 38 |
+
--num_shard=$NPROC_PER_NODE \
|
| 39 |
+
--sharding_strategy="HYBRID_SHARD" \
|
| 40 |
+
--wandb_project "zebra-cot" \
|
| 41 |
+
--wandb_name "h200-zebra-cot-$(date +%Y%m%d_%H%M%S)" \
|
| 42 |
+
--save_every 50 \
|
| 43 |
+
--warmup_steps 50 \
|
| 44 |
+
--total_steps 5000 \
|
| 45 |
+
--results_dir results/ \
|
| 46 |
+
--checkpoint_dir results/checkpoints/ > run.out 2> run.err
|
| 47 |
+
|
| 48 |
+
# --cpu_offload True \
|
scripts/train_smm.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
# Change to the project directory
|
| 6 |
+
cd /scratch/by2593/Bagel-Zebra-CoT-origin
|
| 7 |
+
|
| 8 |
+
export HF_HOME=/dev/shm/
|
| 9 |
+
export PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin:$PYTHONPATH
|
| 10 |
+
export WANDB_MODE=offline
|
| 11 |
+
export WANDB_ANONYMOUS=must
|
| 12 |
+
NUM_NODES=1
|
| 13 |
+
NODE_RANK=0
|
| 14 |
+
MASTER_ADDR=localhost
|
| 15 |
+
MASTER_PORT=29500
|
| 16 |
+
NPROC_PER_NODE=8
|
| 17 |
+
MODEL_PATH=/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8
|
| 18 |
+
|
| 19 |
+
# replace the variables with your own
|
| 20 |
+
torchrun \
|
| 21 |
+
--nnodes=$NUM_NODES \
|
| 22 |
+
--node_rank=$NODE_RANK \
|
| 23 |
+
--nproc_per_node=$NPROC_PER_NODE \
|
| 24 |
+
--master_addr=$MASTER_ADDR \
|
| 25 |
+
--master_port=$MASTER_PORT \
|
| 26 |
+
train/pretrain_unified_navit.py \
|
| 27 |
+
--dataset_config_file ./data/configs/example_smm_semantic.yaml \
|
| 28 |
+
--model_path $MODEL_PATH \
|
| 29 |
+
--layer_module Qwen2MoTDecoderLayer \
|
| 30 |
+
--max_latent_size 64 \
|
| 31 |
+
--resume-from $MODEL_PATH \
|
| 32 |
+
--finetune_from_hf True \
|
| 33 |
+
--auto_resume True \
|
| 34 |
+
--resume-model-only True \
|
| 35 |
+
--finetune-from-ema False \
|
| 36 |
+
--log_every 1 \
|
| 37 |
+
--lr 2e-5 \
|
| 38 |
+
--lr_scheduler cosine \
|
| 39 |
+
--min_lr 1e-6 \
|
| 40 |
+
--num_worker 1 \
|
| 41 |
+
--expected_num_tokens 40000 \
|
| 42 |
+
--max_num_tokens 40000 \
|
| 43 |
+
--max_num_tokens_per_sample 40000 \
|
| 44 |
+
--prefer_buffer_before 10000 \
|
| 45 |
+
--num_shard=$NPROC_PER_NODE \
|
| 46 |
+
--sharding_strategy="HYBRID_SHARD" \
|
| 47 |
+
--wandb_project "zebra-cot" \
|
| 48 |
+
--wandb_name "h200-zebra-cot-$(date +%Y%m%d_%H%M%S)" \
|
| 49 |
+
--save_every 100 \
|
| 50 |
+
--warmup_steps 50 \
|
| 51 |
+
--total_steps 5000 \
|
| 52 |
+
--results_dir results/ \
|
| 53 |
+
--checkpoint_dir results/checkpoints_smm_semantic_part1_v1_origin/ > run.out 2> run.err \
|
| 54 |
+
--cpu_offload True \
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# bash scripts/train_smm.sh
|
scripts/train_smm_sbatch.sh
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=bagel-zebra-cot-smm
|
| 3 |
+
#SBATCH --partition=h200_tandon
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks-per-node=1
|
| 6 |
+
#SBATCH --cpus-per-task=32
|
| 7 |
+
#SBATCH --gres=gpu:h200:8
|
| 8 |
+
#SBATCH --mem=1600G
|
| 9 |
+
#SBATCH --time=48:00:00
|
| 10 |
+
#SBATCH --output=slurm_logs/train_smm_%j.out
|
| 11 |
+
#SBATCH --error=slurm_logs/train_smm_%j.err
|
| 12 |
+
|
| 13 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 14 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 15 |
+
|
| 16 |
+
# Load any necessary modules (adjust as needed for your cluster)
|
| 17 |
+
# module load cuda/12.1
|
| 18 |
+
# module load conda
|
| 19 |
+
|
| 20 |
+
# Activate conda environment
|
| 21 |
+
source /scratch/by2593/miniconda3/etc/profile.d/conda.sh
|
| 22 |
+
conda activate bagel
|
| 23 |
+
|
| 24 |
+
# Change to the project directory
|
| 25 |
+
cd /scratch/by2593/Bagel-Zebra-CoT-origin
|
| 26 |
+
|
| 27 |
+
# Set environment variables
|
| 28 |
+
export HF_HOME=/dev/shm/
|
| 29 |
+
export PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin:$PYTHONPATH
|
| 30 |
+
export WANDB_MODE=offline
|
| 31 |
+
export WANDB_ANONYMOUS=must
|
| 32 |
+
|
| 33 |
+
# SLURM variables
|
| 34 |
+
NUM_NODES=1
|
| 35 |
+
NODE_RANK=0
|
| 36 |
+
MASTER_ADDR=$(hostname)
|
| 37 |
+
MASTER_PORT=29500
|
| 38 |
+
NPROC_PER_NODE=8
|
| 39 |
+
MODEL_PATH=/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8
|
| 40 |
+
|
| 41 |
+
echo "Starting SMM training on node: $SLURM_JOB_NODELIST"
|
| 42 |
+
echo "Job ID: $SLURM_JOB_ID"
|
| 43 |
+
echo "Number of GPUs: $NPROC_PER_NODE"
|
| 44 |
+
|
| 45 |
+
# Run training
|
| 46 |
+
torchrun \
|
| 47 |
+
--nnodes=$NUM_NODES \
|
| 48 |
+
--node_rank=$NODE_RANK \
|
| 49 |
+
--nproc_per_node=$NPROC_PER_NODE \
|
| 50 |
+
--master_addr=$MASTER_ADDR \
|
| 51 |
+
--master_port=$MASTER_PORT \
|
| 52 |
+
train/pretrain_unified_navit.py \
|
| 53 |
+
--dataset_config_file ./data/configs/example_smm_random.yaml \
|
| 54 |
+
--model_path $MODEL_PATH \
|
| 55 |
+
--layer_module Qwen2MoTDecoderLayer \
|
| 56 |
+
--max_latent_size 64 \
|
| 57 |
+
--visual_und True \
|
| 58 |
+
--finetune_from_hf True \
|
| 59 |
+
--auto_resume True \
|
| 60 |
+
--resume-model-only False \
|
| 61 |
+
--finetune-from-ema False \
|
| 62 |
+
--log_every 1 \
|
| 63 |
+
--lr 2e-5 \
|
| 64 |
+
--lr_scheduler cosine \
|
| 65 |
+
--min_lr 1e-6 \
|
| 66 |
+
--num_worker 1 \
|
| 67 |
+
--expected_num_tokens 50000 \
|
| 68 |
+
--max_num_tokens 50000 \
|
| 69 |
+
--max_num_tokens_per_sample 50000 \
|
| 70 |
+
--prefer_buffer_before 10000 \
|
| 71 |
+
--num_shard=$NPROC_PER_NODE \
|
| 72 |
+
--sharding_strategy="HYBRID_SHARD" \
|
| 73 |
+
--wandb_project "smm" \
|
| 74 |
+
--wandb_name "h200-zebra-cot-smm-sbatch-$(date +%Y%m%d_%H%M%S)" \
|
| 75 |
+
--save_every 100 \
|
| 76 |
+
--warmup_steps 50 \
|
| 77 |
+
--total_steps 5000 \
|
| 78 |
+
--results_dir results/ \
|
| 79 |
+
--checkpoint_dir /scratch/by2593/Bagel-Zebra-CoT-origin/results/checkpoints_smm_random_20251026_033448/ \
|
| 80 |
+
--cpu_offload True \
|
| 81 |
+
--max_checkpoints 2
|
| 82 |
+
|
| 83 |
+
echo "SMM training completed on $(date)"
|
| 84 |
+
|
| 85 |
+
# sbatch scripts/train_smm_sbatch.sh
|
test_images/image.png
ADDED
|
Git LFS Details
|
test_images/meme.jpg
ADDED
|
test_images/octupusy.jpg
ADDED
|
test_images/women.jpg
ADDED
|