yinbq committed
Commit 8834223 · verified · 1 Parent(s): 0341b51

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +6 -0
  2. assets/arch.png +3 -0
  3. assets/bagel-cot-example.png +3 -0
  4. assets/emerging_curves.png +3 -0
  5. assets/teaser.webp +3 -0
  6. assets/zebra_cot_datacard.png +3 -0
  7. data/__init__.py +2 -0
  8. data/configs/example.yaml +50 -0
  9. data/configs/example_smm_random.yaml +50 -0
  10. data/dataset_base.py +768 -0
  11. data/dataset_info.py +46 -0
  12. data/distributed_iterable_dataset.py +58 -0
  13. data/interleave_datasets/edit_dataset.py +72 -0
  14. data/interleave_datasets/interleave_t2i_dataset.py +218 -0
  15. data/interleave_datasets/think_trace_dataset.py +289 -0
  16. modeling/__init__.py +4 -0
  17. modeling/autoencoder.py +360 -0
  18. modeling/bagel/bagel.py +1068 -0
  19. modeling/bagel/modeling_utils.py +144 -0
  20. modeling/bagel/qwen2_navit.py +1157 -0
  21. modeling/bagel/siglip_navit.py +402 -0
  22. modeling/qwen2/__init__.py +68 -0
  23. modeling/qwen2/configuration_qwen2.py +179 -0
  24. modeling/qwen2/modeling_qwen2.py +929 -0
  25. modeling/qwen2/tokenization_qwen2.py +328 -0
  26. modeling/qwen2/tokenization_qwen2_fast.py +123 -0
  27. modeling/siglip/__init__.py +98 -0
  28. modeling/siglip/configuration_siglip.py +287 -0
  29. modeling/siglip/convert_siglip_to_hf.py +401 -0
  30. modeling/siglip/image_processing_siglip.py +230 -0
  31. modeling/siglip/modeling_siglip.py +1557 -0
  32. modeling/siglip/processing_siglip.py +131 -0
  33. modeling/siglip/tokenization_siglip.py +364 -0
  34. run.err +150 -0
  35. run.out +871 -0
  36. scripts/eval/eval_vlm.sh +27 -0
  37. scripts/eval/run_eval_vlm.sh +19 -0
  38. scripts/eval/run_gedit.sh +57 -0
  39. scripts/eval/run_geneval.sh +41 -0
  40. scripts/eval/run_imgedit.sh +42 -0
  41. scripts/eval/run_kris.sh +50 -0
  42. scripts/eval/run_rise.sh +30 -0
  43. scripts/eval/run_wise.sh +44 -0
  44. scripts/train.sh +48 -0
  45. scripts/train_smm.sh +57 -0
  46. scripts/train_smm_sbatch.sh +85 -0
  47. test_images/image.png +3 -0
  48. test_images/meme.jpg +0 -0
  49. test_images/octupusy.jpg +0 -0
  50. test_images/women.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/arch.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/bagel-cot-example.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/emerging_curves.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
40
+ assets/zebra_cot_datacard.png filter=lfs diff=lfs merge=lfs -text
41
+ test_images/image.png filter=lfs diff=lfs merge=lfs -text
assets/arch.png ADDED

Git LFS Details

  • SHA256: 28affbbfede911a75884bae4e8e1d5b897b8b450fa4c7d9b68818d05492b0967
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB
assets/bagel-cot-example.png ADDED

Git LFS Details

  • SHA256: e6852144610280fec76591276f090d163479cb54b7e1064e9d9ab77f9fa5e582
  • Pointer size: 132 Bytes
  • Size of remote file: 4.43 MB
assets/emerging_curves.png ADDED

Git LFS Details

  • SHA256: 0c1ddd355742cddb52045ee59098305cc5de8174cb09afa019bb9afefd868733
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
assets/teaser.webp ADDED

Git LFS Details

  • SHA256: d679e69a1fbdb7f9abceb59d9bc3d29ab65b7e871ba48b59aec0a7f35defa558
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
assets/zebra_cot_datacard.png ADDED

Git LFS Details

  • SHA256: 13a0df1dd68f77d535d41b2dfcb092c1f015289c4ca326d74322a9c7e98b5b17
  • Pointer size: 132 Bytes
  • Size of remote file: 3.86 MB
data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
data/configs/example.yaml ADDED
@@ -0,0 +1,50 @@
1
+ think_trace:
2
+ dataset_names:
3
+ - think_trace_dataset
4
+ jsonl_path_list: ["/dev/shm/data/Zebra-CoT/zebra_cot.jsonl"]
5
+ num_used_data: None
6
+ image_prefix_dir: "/dev/shm/data/Zebra-CoT"
7
+ image_transform_args:
8
+ image_stride: 16
9
+ max_image_size: 512
10
+ min_image_size: 512
11
+ vit_image_transform_args:
12
+ image_stride: 14
13
+ max_image_size: 512
14
+ min_image_size: 512
15
+ weight: 1.0
16
+ is_mandatory: true
17
+
18
+ # unified_edit:
19
+ # dataset_names:
20
+ # - seedxedit_multi
21
+ # image_transform_args:
22
+ # image_stride: 16
23
+ # max_image_size: 1024
24
+ # min_image_size: 512
25
+ # vit_image_transform_args:
26
+ # image_stride: 14
27
+ # max_image_size: 518
28
+ # min_image_size: 224
29
+ # is_mandatory: true
30
+ # num_used_data:
31
+ # - 10
32
+ # weight: 1
33
+
34
+ # vlm_sft:
35
+ # dataset_names:
36
+ # - llava_ov
37
+ # image_transform_args:
38
+ # image_stride: 14
39
+ # max_image_size: 980
40
+ # min_image_size: 378
41
+ # max_pixels: 2_007_040
42
+ # frame_sampler_args:
43
+ # max_num_frames: 12
44
+ # min_num_frames: 8
45
+ # is_mandatory: true
46
+ # shuffle_lines: True
47
+ # shuffle_seed: 0
48
+ # num_used_data:
49
+ # - 1000
50
+ # weight: 1
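
For orientation (an editorial sketch, not part of this commit): the top-level keys of a config like the one above select classes in data.dataset_info.DATASET_REGISTRY, and the whole mapping is handed to DataConfig/PackedDataset in data/dataset_base.py. A minimal loading sketch, assuming it is run from the repository root:

import yaml

from data.dataset_base import DataConfig

with open("data/configs/example.yaml") as f:
    grouped_datasets = yaml.safe_load(f)

# Note: `num_used_data: None` is parsed by YAML as the *string* "None"; the
# think_trace dataset converts that string back to "use all lines" on its own.
data_config = DataConfig(grouped_datasets=grouped_datasets)
print(data_config.grouped_datasets["think_trace"]["weight"])  # -> 1.0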
data/configs/example_smm_random.yaml ADDED
@@ -0,0 +1,50 @@
1
+ block_dataset_random:
2
+ dataset_names:
3
+ - block_dataset_random
4
+ jsonl_path_list: ["/scratch/by2593/project/SMM/SMM_data/random_block.jsonl"]
5
+ num_used_data: None
6
+ image_prefix_dir: "/scratch/by2593/project/SMM/random_pipeline/random_blocks"
7
+ image_transform_args:
8
+ image_stride: 16
9
+ max_image_size: 512 # VAE uses stride=16, so 512/16 = 32 patches
10
+ min_image_size: 512
11
+ vit_image_transform_args:
12
+ image_stride: 14
13
+ max_image_size: 512 # ViT uses stride=14, so 512/14 ≈ 36 patches (matches the model's capacity)
14
+ min_image_size: 512
15
+ weight: 1.0
16
+ is_mandatory: true
17
+
18
+ # unified_edit:
19
+ # dataset_names:
20
+ # - seedxedit_multi
21
+ # image_transform_args:
22
+ # image_stride: 16
23
+ # max_image_size: 1024
24
+ # min_image_size: 512
25
+ # vit_image_transform_args:
26
+ # image_stride: 14
27
+ # max_image_size: 518
28
+ # min_image_size: 224
29
+ # is_mandatory: true
30
+ # num_used_data:
31
+ # - 10
32
+ # weight: 1
33
+
34
+ # vlm_sft:
35
+ # dataset_names:
36
+ # - llava_ov
37
+ # image_transform_args:
38
+ # image_stride: 14
39
+ # max_image_size: 980
40
+ # min_image_size: 378
41
+ # max_pixels: 2_007_040
42
+ # frame_sampler_args:
43
+ # max_num_frames: 12
44
+ # min_num_frames: 8
45
+ # is_mandatory: true
46
+ # shuffle_lines: True
47
+ # shuffle_seed: 0
48
+ # num_used_data:
49
+ # - 1000
50
+ # weight: 1
data/dataset_base.py ADDED
@@ -0,0 +1,768 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import random
6
+ import json
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from .data_utils import (
12
+ get_flattened_position_ids_interpolate,
13
+ get_flattened_position_ids_extrapolate,
14
+ len2weight,
15
+ patchify,
16
+ prepare_attention_mask_per_sample,
17
+ )
18
+ from .dataset_info import DATASET_INFO, DATASET_REGISTRY
19
+ from .transforms import ImageTransform
20
+ from .video_utils import FrameSampler
21
+
22
+
23
+ class DataConfig:
24
+ def __init__(
25
+ self,
26
+ grouped_datasets,
27
+ text_cond_dropout_prob=0.1,
28
+ vit_cond_dropout_prob=0.4,
29
+ vae_cond_dropout_prob=0.1,
30
+ vae_image_downsample=16,
31
+ max_latent_size=32,
32
+ vit_patch_size=14,
33
+ max_num_patch_per_side=70,
34
+ ):
35
+ self.grouped_datasets = grouped_datasets
36
+ self.text_cond_dropout_prob = text_cond_dropout_prob
37
+ self.vit_cond_dropout_prob = vit_cond_dropout_prob
38
+ self.vit_patch_size = vit_patch_size
39
+ self.max_num_patch_per_side = max_num_patch_per_side
40
+ self.vae_cond_dropout_prob = vae_cond_dropout_prob
41
+ self.vae_image_downsample = vae_image_downsample
42
+ self.max_latent_size = max_latent_size
43
+
44
+
45
+ class PackedDataset(torch.utils.data.IterableDataset):
46
+ def __init__(
47
+ self,
48
+ data_config,
49
+ tokenizer,
50
+ special_tokens,
51
+ local_rank,
52
+ world_size,
53
+ num_workers,
54
+ expected_num_tokens=32768,
55
+ max_num_tokens_per_sample=16384,
56
+ max_num_tokens=36864,
57
+ prefer_buffer_before=16384,
58
+ max_buffer_size=50,
59
+ interpolate_pos=False,
60
+ use_flex=False,
61
+ data_status=None,
62
+ ):
63
+ super().__init__()
64
+ self.expected_num_tokens = expected_num_tokens
65
+ self.max_num_tokens_per_sample = max_num_tokens_per_sample
66
+ self.prefer_buffer_before = prefer_buffer_before
67
+ self.max_num_tokens = max_num_tokens
68
+ self.max_buffer_size = max_buffer_size
69
+ self.tokenizer = tokenizer
70
+ self.local_rank = local_rank
71
+ self.world_size = world_size
72
+ self.num_workers = num_workers
73
+ self.use_flex = use_flex
74
+ for k, v in special_tokens.items():
75
+ setattr(self, k, v)
76
+
77
+ grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
78
+ data_config.grouped_datasets, data_status
79
+ )
80
+ self.grouped_datasets = grouped_datasets
81
+ self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
82
+ self.is_mandatory = is_mandatory
83
+ self.grouped_weights = grouped_weights
84
+ self.data_config = data_config
85
+ self.interpolate_pos = interpolate_pos
86
+ if self.interpolate_pos:
87
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
88
+ else:
89
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
90
+
91
+ def build_datasets(self, datasets_metainfo, data_status):
92
+ datasets = []
93
+ is_mandatory = []
94
+ grouped_weights = []
95
+ for grouped_dataset_name, dataset_args in datasets_metainfo.items():
96
+ is_mandatory.append(dataset_args.pop('is_mandatory', False))
97
+ grouped_weights.append(dataset_args.pop('weight', 0.0))
98
+
99
+ if 'frame_sampler_args' in dataset_args.keys():
100
+ frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
101
+ dataset_args['frame_sampler'] = frame_sampler
102
+ if 'image_transform_args' in dataset_args.keys():
103
+ transform = ImageTransform(**dataset_args.pop('image_transform_args'))
104
+ dataset_args['transform'] = transform
105
+ if 'vit_image_transform_args' in dataset_args.keys():
106
+ vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
107
+ dataset_args['vit_transform'] = vit_transform
108
+
109
+ assert 'dataset_names' in dataset_args.keys()
110
+ dataset_names = dataset_args.pop('dataset_names')
111
+ dataset_args['data_dir_list'] = []
112
+ for item in dataset_names:
113
+ if self.local_rank == 0:
114
+ print(f'Preparing Dataset {grouped_dataset_name}/{item}')
115
+ meta_info = DATASET_INFO[grouped_dataset_name][item]
116
+ dataset_args['data_dir_list'].append(meta_info['data_dir'])
117
+
118
+ if "parquet_info_path" in meta_info.keys():
119
+ if 'parquet_info' not in dataset_args.keys():
120
+ dataset_args['parquet_info'] = {}
121
+ with open(meta_info['parquet_info_path'], 'r') as f:
122
+ parquet_info = json.load(f)
123
+ dataset_args['parquet_info'].update(parquet_info)
124
+
125
+ if 'json_dir' in meta_info.keys():
126
+ # parquet/tar with json
127
+ if 'json_dir_list' not in dataset_args.keys():
128
+ dataset_args['json_dir_list'] = [meta_info['json_dir']]
129
+ else:
130
+ dataset_args['json_dir_list'].append(meta_info['json_dir'])
131
+
132
+ if 'jsonl_path' in meta_info.keys():
133
+ # jsonl with jpeg
134
+ if 'jsonl_path_list' not in dataset_args.keys():
135
+ dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
136
+ else:
137
+ dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
138
+
139
+ if 'image_prefix_dir' in meta_info.keys():
140
+ dataset_args['image_prefix_dir'] = meta_info['image_prefix_dir']
141
+
142
+ resume_data_status = dataset_args.pop('resume_data_status', True)
143
+ if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
144
+ data_status_per_group = data_status[grouped_dataset_name]
145
+ else:
146
+ data_status_per_group = None
147
+ dataset = DATASET_REGISTRY[grouped_dataset_name](
148
+ dataset_name=grouped_dataset_name,
149
+ tokenizer=self.tokenizer,
150
+ local_rank=self.local_rank,
151
+ world_size=self.world_size,
152
+ num_workers=self.num_workers,
153
+ data_status=data_status_per_group,
154
+ **dataset_args
155
+ )
156
+ datasets.append(dataset)
157
+
158
+ return datasets, is_mandatory, grouped_weights
159
+
160
+ def set_epoch(self, seed):
161
+ for dataset in self.grouped_datasets:
162
+ dataset.set_epoch(seed)
163
+
164
+ def set_sequence_status(self):
165
+ sequence_status = dict(
166
+ curr = 0,
167
+ sample_lens = list(),
168
+ packed_position_ids = list(),
169
+ nested_attention_masks = list(),
170
+ split_lens = list(),
171
+ attn_modes = list(),
172
+ packed_text_ids = list(),
173
+ packed_text_indexes = list(),
174
+ packed_label_ids = list(),
175
+ ce_loss_indexes = list(),
176
+ ce_loss_weights = list(),
177
+ vae_image_tensors = list(),
178
+ packed_latent_position_ids = list(),
179
+ vae_latent_shapes = list(),
180
+ packed_vae_token_indexes = list(),
181
+ packed_timesteps = list(),
182
+ mse_loss_indexes = list(),
183
+ packed_vit_tokens = list(),
184
+ vit_token_seqlens = list(),
185
+ packed_vit_position_ids = list(),
186
+ packed_vit_token_indexes = list(),
187
+ )
188
+ return sequence_status
189
+
190
+ def to_tensor(self, sequence_status):
191
+ data = dict(
192
+ sequence_length=sum(sequence_status['sample_lens']),
193
+ sample_lens=sequence_status['sample_lens'],
194
+ packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
195
+ packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
196
+ packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
197
+ )
198
+ if not self.use_flex:
199
+ data['nested_attention_masks'] = sequence_status['nested_attention_masks']
200
+ else:
201
+ sequence_len = data['sequence_length']
202
+ pad_len = self.max_num_tokens - sequence_len
203
+ data['split_lens'] = sequence_status['split_lens'] + [pad_len]
204
+ data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
205
+ data['sample_lens'] += [pad_len]
206
+
207
+ # if the model has a convnet vae (e.g., as visual tokenizer)
208
+ if len(sequence_status['vae_image_tensors']) > 0:
209
+ image_tensors = sequence_status.pop('vae_image_tensors')
210
+ image_sizes = [item.shape for item in image_tensors]
211
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
212
+ padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
213
+ for i, image_tensor in enumerate(image_tensors):
214
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
215
+
216
+ data['padded_images'] = padded_images
217
+ data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
218
+ data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
219
+ data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
220
+
221
+ # if the model has a vit (e.g., as visual tokenizer)
222
+ if len(sequence_status['packed_vit_tokens']) > 0:
223
+ data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
224
+ data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
225
+ data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
226
+ data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
227
+
228
+ # if the model is required to perform visual generation
229
+ if len(sequence_status['packed_timesteps']) > 0:
230
+ data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
231
+ data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
232
+
233
+ # if the model is required to perform text generation
234
+ if len(sequence_status['packed_label_ids']) > 0:
235
+ data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
236
+ data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
237
+ data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
238
+
239
+ # Debug printing for rank 0
240
+ # if self.local_rank == 0:
241
+ # self.print_debug_info(data, sequence_status)
242
+
243
+ return data
244
+
245
+ def print_debug_info(self, data, sequence_status):
246
+ """Print detailed debug information in an intuitive table format"""
247
+ print("\n" + "="*120)
248
+ print("DEBUG: Complete Sequence Analysis")
249
+ print("="*120)
250
+
251
+ # Basic info
252
+ print(f"Sequence Length: {data['sequence_length']}")
253
+ print(f"Sample Lengths: {data['sample_lens']}")
254
+
255
+ # Get all data
256
+ packed_text_ids = data['packed_text_ids'].tolist()
257
+ packed_text_indexes = data['packed_text_indexes'].tolist()
258
+
259
+ # Build loss mappings
260
+ ce_loss_indexes = set(data.get('ce_loss_indexes', []).tolist())
261
+ mse_loss_indexes = set(data.get('mse_loss_indexes', []).tolist())
262
+ vit_token_indexes = set(data.get('packed_vit_token_indexes', []).tolist())
263
+ vae_token_indexes = set(data.get('packed_vae_token_indexes', []).tolist())
264
+
265
+ # Build label mapping
266
+ label_mapping = {}
267
+ if 'ce_loss_indexes' in data:
268
+ ce_indexes = data['ce_loss_indexes'].tolist()
269
+ ce_labels = data['packed_label_ids'].tolist()
270
+ for i, pos in enumerate(ce_indexes):
271
+ label_mapping[pos] = ce_labels[i]
272
+
273
+ # Print raw token sequence
274
+ print(f"\n1. Raw Token IDs: {packed_text_ids}")
275
+
276
+ # Print decoded token sequence
277
+ try:
278
+ decoded_text_tokens = []
279
+ for token_id in packed_text_ids:
280
+ decoded = self.tokenizer.decode([token_id])
281
+ decoded_text_tokens.append(decoded)
282
+ print(f"2. Decoded Tokens: {decoded_text_tokens}")
283
+ except Exception as e:
284
+ print(f"2. Error decoding tokens: {e}")
285
+ decoded_text_tokens = ["<ERROR>"] * len(packed_text_ids)
286
+
287
+ # Create comprehensive sequence table
288
+ print(f"\n3. Complete Sequence Table:")
289
+ print("-" * 120)
290
+ print(f"{'Order':<6} | {'Token Type':<12} | {'Token/Content':<30} | {'Loss Type':<10} | {'Label':<30} | {'Notes':<20}")
291
+ print("-" * 120)
292
+
293
+ # Track text token index
294
+ text_token_idx = 0
295
+
296
+ for pos in range(data['sequence_length']):
297
+ # Determine token type and content
298
+ if pos in packed_text_indexes:
299
+ # This is a text token position
300
+ token_id = packed_text_ids[text_token_idx]
301
+ try:
302
+ decoded_token = self.tokenizer.decode([token_id])
303
+ token_content = f"ID:{token_id} '{decoded_token}'"
304
+ except:
305
+ token_content = f"ID:{token_id} '<ERROR>'"
306
+ token_type = "TEXT"
307
+ text_token_idx += 1
308
+
309
+ elif pos in vit_token_indexes:
310
+ token_type = "VIT_IMAGE"
311
+ token_content = "[VIT Image Patch]"
312
+
313
+ elif pos in vae_token_indexes:
314
+ token_type = "VAE_IMAGE"
315
+ token_content = "[VAE Image Latent]"
316
+
317
+ else:
318
+ token_type = "UNKNOWN"
319
+ token_content = "[Unknown Position]"
320
+
321
+ # Determine loss type
322
+ if pos in ce_loss_indexes:
323
+ loss_type = "CE"
324
+ elif pos in mse_loss_indexes:
325
+ loss_type = "MSE"
326
+ else:
327
+ loss_type = "None"
328
+
329
+ # Determine label
330
+ if pos in label_mapping:
331
+ label_id = label_mapping[pos]
332
+ try:
333
+ decoded_label = self.tokenizer.decode([label_id])
334
+ label_content = f"ID:{label_id} '{decoded_label}'"
335
+ except:
336
+ label_content = f"ID:{label_id} '<ERROR>'"
337
+ elif pos in mse_loss_indexes:
338
+ label_content = "[Image Generation Target]"
339
+ else:
340
+ label_content = "N/A"
341
+
342
+ # Additional notes
343
+ notes = ""
344
+ if pos in mse_loss_indexes and 'packed_timesteps' in data:
345
+ timestep_idx = list(mse_loss_indexes).index(pos) if pos in mse_loss_indexes else -1
346
+ if timestep_idx >= 0 and timestep_idx < len(data['packed_timesteps']):
347
+ timestep = data['packed_timesteps'][timestep_idx].item()
348
+ if timestep == float('-inf'):
349
+ notes = "No noise"
350
+ else:
351
+ notes = f"t={timestep:.3f}"
352
+
353
+ print(f"{pos:<6} | {token_type:<12} | {token_content:<30} | {loss_type:<10} | {label_content:<30} | {notes:<20}")
354
+
355
+ print("-" * 120)
356
+
357
+ # Summary statistics
358
+ total_positions = data['sequence_length']
359
+ ce_positions = len(ce_loss_indexes)
360
+ mse_positions = len(mse_loss_indexes)
361
+ vit_positions = len(vit_token_indexes)
362
+ vae_positions = len(vae_token_indexes)
363
+ text_positions = len(packed_text_indexes)
364
+ no_loss_positions = total_positions - ce_positions - mse_positions
365
+
366
+ print(f"\nSummary Statistics:")
367
+ print(f" Total positions: {total_positions}")
368
+ print(f" Text tokens: {text_positions} ({text_positions/total_positions*100:.1f}%)")
369
+ print(f" VIT image tokens: {vit_positions} ({vit_positions/total_positions*100:.1f}%)")
370
+ print(f" VAE image tokens: {vae_positions} ({vae_positions/total_positions*100:.1f}%)")
371
+ print(f" Positions with CE loss: {ce_positions} ({ce_positions/total_positions*100:.1f}%)")
372
+ print(f" Positions with MSE loss: {mse_positions} ({mse_positions/total_positions*100:.1f}%)")
373
+ print(f" Positions with no loss: {no_loss_positions} ({no_loss_positions/total_positions*100:.1f}%)")
374
+
375
+ print("="*120 + "\n")
376
+
377
+ def __iter__(self):
378
+ total_weights = sum(self.grouped_weights)
379
+ assert total_weights > 0.0
380
+ group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
381
+ for i in range(len(self.grouped_weights))]
382
+ sequence_status = self.set_sequence_status()
383
+ batch_data_indexes = []
384
+
385
+ buffer = []
386
+ while True:
387
+ # Ensure at least one sample from each group
388
+ if sequence_status['curr'] == 0:
389
+ for group_index, group_iter in enumerate(self.dataset_iters):
390
+ if self.is_mandatory[group_index]:
391
+ while True:
392
+ sample = next(group_iter)
393
+ # if a sample is too long, skip it
394
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
395
+ if num_tokens < self.max_num_tokens_per_sample:
396
+ sequence_status = self.pack_sequence(sample, sequence_status)
397
+ batch_data_indexes.append(sample['data_indexes'])
398
+ break
399
+ else:
400
+ print(f"skip a sample with length {num_tokens}")
401
+ continue
402
+
403
+ if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
404
+ sample = buffer.pop(0)
405
+ sample_from_buffer = True
406
+ else:
407
+ # sample normally across all groups
408
+ n = random.random()
409
+ group_index = 0
410
+ for i, cumprob in enumerate(group_cumprobs):
411
+ if n < cumprob:
412
+ group_index = i
413
+ break
414
+ sample = next(self.dataset_iters[group_index])
415
+ sample_from_buffer = False
416
+
417
+ # if a sample is too long, skip it
418
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
419
+ if num_tokens > self.max_num_tokens_per_sample:
420
+ print(f"skip a sample with length {num_tokens}")
421
+ continue
422
+
423
+ if sequence_status['curr'] + num_tokens > self.max_num_tokens:
424
+ if len(buffer) < self.max_buffer_size and not sample_from_buffer:
425
+ buffer.append(sample)
426
+ else:
427
+ print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
428
+ data = self.to_tensor(sequence_status)
429
+ data['batch_data_indexes'] = batch_data_indexes
430
+ yield data
431
+ sequence_status = self.set_sequence_status()
432
+ batch_data_indexes = []
433
+ continue
434
+
435
+ sequence_status = self.pack_sequence(sample, sequence_status)
436
+ batch_data_indexes.append(sample['data_indexes'])
437
+
438
+ if sequence_status['curr'] >= self.expected_num_tokens:
439
+ data = self.to_tensor(sequence_status)
440
+ data['batch_data_indexes'] = batch_data_indexes
441
+ yield data
442
+ sequence_status = self.set_sequence_status()
443
+ batch_data_indexes = []
444
+
445
+ def pack_sequence(self, sample, sequence_status):
446
+ image_tensor_list = sample['image_tensor_list']
447
+ text_ids_list = sample['text_ids_list']
448
+ sequence_plan = sample['sequence_plan']
449
+
450
+ split_lens, attn_modes = list(), list()
451
+ curr = sequence_status['curr']
452
+ curr_rope_id = 0
453
+ sample_lens = 0
454
+
455
+ for item in sequence_plan:
456
+ split_start = item.get('split_start', True)
457
+ if split_start:
458
+ curr_split_len = 0
459
+
460
+ if item['type'] == 'text':
461
+ text_ids = text_ids_list.pop(0)
462
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
463
+ continue
464
+
465
+ shifted_text_ids = [self.bos_token_id] + text_ids
466
+ sequence_status['packed_text_ids'].extend(shifted_text_ids)
467
+ sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
468
+ if item['loss'] == 1:
469
+ sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
470
+ sequence_status['ce_loss_weights'].extend(
471
+ [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
472
+ )
473
+ sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
474
+ curr += len(shifted_text_ids)
475
+ curr_split_len += len(shifted_text_ids)
476
+
477
+ # add a <|im_end|> token
478
+ sequence_status['packed_text_ids'].append(self.eos_token_id)
479
+ sequence_status['packed_text_indexes'].append(curr)
480
+ if item['special_token_loss'] == 1: # <|im_end|> may have loss
481
+ sequence_status['ce_loss_indexes'].append(curr)
482
+ sequence_status['ce_loss_weights'].append(1.0)
483
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
484
+ curr += 1
485
+ curr_split_len += 1
486
+
487
+ # update sequence status
488
+ attn_modes.append("causal")
489
+ sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
490
+ curr_rope_id += curr_split_len
491
+
492
+ elif item['type'] == 'vit_image':
493
+ image_tensor = image_tensor_list.pop(0)
494
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
495
+ curr_rope_id += 1
496
+ continue
497
+
498
+ # add a <|startofimage|> token
499
+ sequence_status['packed_text_ids'].append(self.start_of_image)
500
+ sequence_status['packed_text_indexes'].append(curr)
501
+ curr += 1
502
+ curr_split_len += 1
503
+
504
+ # preprocess image
505
+ vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
506
+ num_img_tokens = vit_tokens.shape[0]
507
+ sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
508
+ curr += num_img_tokens
509
+ curr_split_len += num_img_tokens
510
+
511
+ sequence_status['packed_vit_tokens'].append(vit_tokens)
512
+ sequence_status['vit_token_seqlens'].append(num_img_tokens)
513
+ sequence_status['packed_vit_position_ids'].append(
514
+ self.get_flattened_position_ids(
515
+ image_tensor.size(1), image_tensor.size(2),
516
+ self.data_config.vit_patch_size,
517
+ max_num_patches_per_side=self.data_config.max_num_patch_per_side
518
+ )
519
+ )
520
+
521
+ # add a <|endofimage|> token
522
+ sequence_status['packed_text_ids'].append(self.end_of_image)
523
+ sequence_status['packed_text_indexes'].append(curr)
524
+ if item['special_token_loss'] == 1: # <|endofimage|> may have loss
525
+ sequence_status['ce_loss_indexes'].append(curr)
526
+ sequence_status['ce_loss_weights'].append(1.0)
527
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
528
+ curr += 1
529
+ curr_split_len += 1
530
+
531
+ # update sequence status
532
+ attn_modes.append("full")
533
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
534
+ curr_rope_id += 1
535
+
536
+ elif item['type'] == 'vae_image':
537
+ image_tensor = image_tensor_list.pop(0)
538
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
539
+ # FIXME fix vae dropout in video2video setting.
540
+ curr_rope_id += 1
541
+ continue
542
+
543
+
544
+
545
+
546
+ # add a <|startofimage|> token
547
+ sequence_status['packed_text_ids'].append(self.start_of_image)
548
+ sequence_status['packed_text_indexes'].append(curr)
549
+
550
+ if item['special_token_loss'] == 1:
551
+ sequence_status['ce_loss_indexes'].append(curr)
552
+ sequence_status['ce_loss_weights'].append(1.0)
553
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
554
+
555
+ curr += 1
556
+ curr_split_len += 1
557
+
558
+ # preprocess image
559
+ sequence_status['vae_image_tensors'].append(image_tensor)
560
+ sequence_status['packed_latent_position_ids'].append(
561
+ self.get_flattened_position_ids(
562
+ image_tensor.size(1), image_tensor.size(2),
563
+ self.data_config.vae_image_downsample,
564
+ max_num_patches_per_side=self.data_config.max_latent_size
565
+ )
566
+ )
567
+ H, W = image_tensor.shape[1:]
568
+ h = H // self.data_config.vae_image_downsample
569
+ w = W // self.data_config.vae_image_downsample
570
+ sequence_status['vae_latent_shapes'].append((h, w))
571
+
572
+ num_img_tokens = w * h
573
+ sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
574
+ if item['loss'] == 1:
575
+ sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
576
+ if split_start:
577
+ timestep = np.random.randn()
578
+ else:
579
+ timestep = float('-inf')
580
+
581
+ sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
582
+ curr += num_img_tokens
583
+ curr_split_len += num_img_tokens
584
+
585
+ # add a <|endofimage|> token
586
+ sequence_status['packed_text_ids'].append(self.end_of_image)
587
+ sequence_status['packed_text_indexes'].append(curr)
588
+ # <|endofimage|> may have loss
589
+ if item['special_token_loss'] == 1:
590
+ sequence_status['ce_loss_indexes'].append(curr)
591
+ sequence_status['ce_loss_weights'].append(1.0)
592
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
593
+ curr += 1
594
+ curr_split_len += 1
595
+
596
+ # update sequence status
597
+ if split_start:
598
+ if item['loss'] == 1 and 'frame_delta' not in item.keys():
599
+ attn_modes.append("noise")
600
+ else:
601
+ attn_modes.append("full")
602
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
603
+ if 'frame_delta' in item.keys():
604
+ curr_rope_id += item['frame_delta']
605
+ elif item['loss'] == 0:
606
+ curr_rope_id += 1
607
+
608
+ if item.get('split_end', True):
609
+ split_lens.append(curr_split_len)
610
+ sample_lens += curr_split_len
611
+
612
+ sequence_status['curr'] = curr
613
+ sequence_status['sample_lens'].append(sample_lens)
614
+ # prepare attention mask
615
+ if not self.use_flex:
616
+ sequence_status['nested_attention_masks'].append(
617
+ prepare_attention_mask_per_sample(split_lens, attn_modes)
618
+ )
619
+ else:
620
+ sequence_status['split_lens'].extend(split_lens)
621
+ sequence_status['attn_modes'].extend(attn_modes)
622
+
623
+ return sequence_status
624
+
625
+
626
+ class SimpleCustomBatch:
627
+ def __init__(self, batch):
628
+ data = batch[0]
629
+ self.batch_data_indexes = data['batch_data_indexes']
630
+ self.sequence_length = data["sequence_length"]
631
+ self.sample_lens = data["sample_lens"]
632
+ self.packed_text_ids = data["packed_text_ids"]
633
+ self.packed_text_indexes = data["packed_text_indexes"]
634
+ self.packed_position_ids = data["packed_position_ids"]
635
+
636
+ self.use_flex = "nested_attention_masks" not in data.keys()
637
+
638
+ if self.use_flex:
639
+ self.split_lens = data["split_lens"]
640
+ self.attn_modes = data["attn_modes"]
641
+ else:
642
+ self.nested_attention_masks = data["nested_attention_masks"]
643
+
644
+ if "padded_images" in data.keys():
645
+ self.padded_images = data["padded_images"]
646
+ self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
647
+ self.packed_latent_position_ids = data["packed_latent_position_ids"]
648
+ self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
649
+
650
+ if "packed_vit_tokens" in data.keys():
651
+ self.packed_vit_tokens = data["packed_vit_tokens"]
652
+ self.packed_vit_position_ids = data["packed_vit_position_ids"]
653
+ self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
654
+ self.vit_token_seqlens = data["vit_token_seqlens"]
655
+
656
+ if "packed_timesteps" in data.keys():
657
+ self.packed_timesteps = data["packed_timesteps"]
658
+ self.mse_loss_indexes = data["mse_loss_indexes"]
659
+
660
+ if "packed_label_ids" in data.keys():
661
+ self.packed_label_ids = data["packed_label_ids"]
662
+ self.ce_loss_indexes = data["ce_loss_indexes"]
663
+ self.ce_loss_weights = data["ce_loss_weights"]
664
+
665
+ def pin_memory(self):
666
+ self.packed_text_ids = self.packed_text_ids.pin_memory()
667
+ self.packed_text_indexes = self.packed_text_indexes.pin_memory()
668
+ self.packed_position_ids = self.packed_position_ids.pin_memory()
669
+
670
+ if not self.use_flex:
671
+ self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
672
+
673
+ if hasattr(self, 'padded_images'):
674
+ self.padded_images = self.padded_images.pin_memory()
675
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
676
+ self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
677
+
678
+ if hasattr(self, 'packed_timesteps'):
679
+ self.packed_timesteps = self.packed_timesteps.pin_memory()
680
+ self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
681
+
682
+ if hasattr(self, 'packed_vit_tokens'):
683
+ self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
684
+ self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
685
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
686
+ self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
687
+
688
+ if hasattr(self, 'packed_label_ids'):
689
+ self.packed_label_ids = self.packed_label_ids.pin_memory()
690
+ self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
691
+ self.ce_loss_weights = self.ce_loss_weights.pin_memory()
692
+
693
+ return self
694
+
695
+ def cuda(self, device):
696
+ self.packed_text_ids = self.packed_text_ids.to(device)
697
+ self.packed_text_indexes = self.packed_text_indexes.to(device)
698
+ self.packed_position_ids = self.packed_position_ids.to(device)
699
+
700
+ if not self.use_flex:
701
+ self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
702
+
703
+ if hasattr(self, 'padded_images'):
704
+ self.padded_images = self.padded_images.to(device)
705
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
706
+ self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
707
+
708
+ if hasattr(self, 'packed_timesteps'):
709
+ self.packed_timesteps = self.packed_timesteps.to(device)
710
+ self.mse_loss_indexes = self.mse_loss_indexes.to(device)
711
+
712
+ if hasattr(self, 'packed_vit_tokens'):
713
+ self.packed_vit_tokens = self.packed_vit_tokens.to(device)
714
+ self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
715
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
716
+ self.vit_token_seqlens = self.vit_token_seqlens.to(device)
717
+
718
+ if hasattr(self, 'packed_label_ids'):
719
+ self.packed_label_ids = self.packed_label_ids.to(device)
720
+ self.ce_loss_indexes = self.ce_loss_indexes.to(device)
721
+ self.ce_loss_weights = self.ce_loss_weights.to(device)
722
+
723
+ return self
724
+
725
+ def to_dict(self):
726
+ data = dict(
727
+ sequence_length = self.sequence_length,
728
+ sample_lens = self.sample_lens,
729
+ packed_text_ids = self.packed_text_ids,
730
+ packed_text_indexes = self.packed_text_indexes,
731
+ packed_position_ids = self.packed_position_ids,
732
+ batch_data_indexes = self.batch_data_indexes,
733
+ )
734
+
735
+ if not self.use_flex:
736
+ data['nested_attention_masks'] = self.nested_attention_masks
737
+ else:
738
+ data['split_lens'] = self.split_lens
739
+ data['attn_modes'] = self.attn_modes
740
+
741
+ if hasattr(self, 'padded_images'):
742
+ data['padded_images'] = self.padded_images
743
+ data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
744
+ data['packed_latent_position_ids'] = self.packed_latent_position_ids
745
+ data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
746
+
747
+ if hasattr(self, 'packed_vit_tokens'):
748
+ data['packed_vit_tokens'] = self.packed_vit_tokens
749
+ data['packed_vit_position_ids'] = self.packed_vit_position_ids
750
+ data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
751
+ data['vit_token_seqlens'] = self.vit_token_seqlens
752
+
753
+ if hasattr(self, 'packed_timesteps'):
754
+ data['packed_timesteps'] = self.packed_timesteps
755
+ data['mse_loss_indexes'] = self.mse_loss_indexes
756
+
757
+ if hasattr(self, 'packed_label_ids'):
758
+ data['packed_label_ids'] = self.packed_label_ids
759
+ data['ce_loss_indexes'] = self.ce_loss_indexes
760
+ data['ce_loss_weights'] = self.ce_loss_weights
761
+
762
+ return data
763
+
764
+
765
+ def collate_wrapper():
766
+ def collate_fn(batch):
767
+ return SimpleCustomBatch(batch)
768
+ return collate_fn
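
A usage sketch (editorial, not part of this commit) of how collate_wrapper and SimpleCustomBatch above are meant to be wired into a torch DataLoader; the worker count and device id below are placeholders:

import torch

from data.dataset_base import PackedDataset, collate_wrapper

def build_loader(packed_dataset: PackedDataset) -> torch.utils.data.DataLoader:
    # batch_size=1 because PackedDataset already yields a fully packed sequence per item.
    return torch.utils.data.DataLoader(
        packed_dataset,
        batch_size=1,
        num_workers=1,
        collate_fn=collate_wrapper(),  # wraps the packed dict in a SimpleCustomBatch
        pin_memory=True,               # triggers SimpleCustomBatch.pin_memory()
    )

def first_batch_as_dict(loader: torch.utils.data.DataLoader) -> dict:
    batch = next(iter(loader)).cuda(torch.device("cuda:0"))  # SimpleCustomBatch.cuda()
    return batch.to_dict()  # plain dict of packed tensors for the model forward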
data/dataset_info.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .interleave_datasets import UnifiedEditIterableDataset
5
+ from .t2i_dataset import T2IIterableDataset
6
+ from .vlm_dataset import SftJSONLIterableDataset
7
+ from .interleave_datasets.think_trace_dataset import ThinkTraceJSONLIterableDataset
8
+
9
+
10
+ DATASET_REGISTRY = {
11
+ 't2i_pretrain': T2IIterableDataset,
12
+ 'vlm_sft': SftJSONLIterableDataset,
13
+ 'unified_edit': UnifiedEditIterableDataset,
14
+ 'think_trace': ThinkTraceJSONLIterableDataset,
15
+ 'block_dataset': ThinkTraceJSONLIterableDataset,
16
+ 'block_dataset_random': ThinkTraceJSONLIterableDataset,
17
+ }
18
+
19
+
20
+ DATASET_INFO = {
21
+ 'think_trace': {
22
+ 'think_trace_dataset': {
23
+ 'data_dir': '/scratch/by2593/project/SpaCU/interleaved-co3dv2/data',
24
+ 'jsonl_path': '/scratch/by2593/project/SpaCU/interleaved-co3dv2/data/merged_train.jsonl',
25
+ 'image_prefix_dir': '/scratch/by2593/project/SpaCU/restored_data2', # Base path for relative image paths
26
+ # 'num_total_samples': 100,
27
+ },
28
+ },
29
+ 'block_dataset': {
30
+ 'block_dataset': {
31
+ 'data_dir': "/scratch/by2593/project/SMM/semantic_blocks_part1",
32
+ # 'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1_v2_reordered.jsonl',
33
+ 'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl',
34
+ 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', # Base path for relative image paths
35
+ # 'num_total_samples': 100,
36
+ },
37
+ },
38
+ 'block_dataset_random': {
39
+ 'block_dataset_random': {
40
+ 'data_dir': "/scratch/by2593/project/SMM/random_pipeline/random_blocks",
41
+ 'jsonl_path': '/scratch/by2593/project/SMM/SMM_data/random_block.jsonl',
42
+ 'image_prefix_dir': '/scratch/by2593/project/SMM/random_pipeline/random_blocks', # Base path for relative image paths
43
+ # 'num_total_samples': 100,
44
+ },
45
+ },
46
+ }
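
To extend this registry, a new group needs an entry in both dicts plus a matching top-level key in the training YAML. A hypothetical example (the group name and all paths below are placeholders, not part of this commit):

from data.dataset_info import DATASET_INFO, DATASET_REGISTRY
from data.interleave_datasets.think_trace_dataset import ThinkTraceJSONLIterableDataset

DATASET_REGISTRY['my_trace'] = ThinkTraceJSONLIterableDataset
DATASET_INFO['my_trace'] = {
    'my_trace_subset': {
        'data_dir': '/path/to/my_trace',
        'jsonl_path': '/path/to/my_trace/train.jsonl',
        'image_prefix_dir': '/path/to/my_trace/images',  # base path for relative image paths
    },
}
# A config YAML then selects this group with a top-level `my_trace:` block whose
# `dataset_names` list contains `my_trace_subset`.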
data/distributed_iterable_dataset.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import random
5
+ import torch
6
+
7
+
8
+ class DistributedIterableDataset(torch.utils.data.IterableDataset):
9
+ def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
10
+ self.dataset_name = dataset_name
11
+ self.local_rank = local_rank
12
+ self.world_size = world_size
13
+ self.num_workers = num_workers
14
+ self.rng = random.Random()
15
+ self.data_paths = None
16
+
17
+ def get_data_paths(self, *args, **kwargs):
18
+ raise NotImplementedError
19
+
20
+ def set_epoch(self, seed=42):
21
+ if self.data_paths is None:
22
+ return
23
+
24
+ if isinstance(self.data_paths[0], tuple):
25
+ data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
26
+ elif isinstance(self.data_paths[0], str):
27
+ data_paths = sorted(self.data_paths)
28
+ else:
29
+ raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
30
+
31
+ self.rng.seed(seed)
32
+ self.rng.shuffle(data_paths)
33
+
34
+ num_files_per_rank = len(data_paths) // self.world_size
35
+ local_start = self.local_rank * num_files_per_rank
36
+ local_end = (self.local_rank + 1) * num_files_per_rank
37
+ self.num_files_per_rank = num_files_per_rank
38
+ self.data_paths_per_rank = data_paths[local_start:local_end]
39
+
40
+ def get_data_paths_per_worker(self):
41
+ if self.data_paths is None:
42
+ return None
43
+
44
+ info = torch.utils.data.get_worker_info()
45
+ if info is None:
46
+ # Single worker: Use all files assigned to the rank
47
+ return self.data_paths_per_rank, 0
48
+
49
+ worker_id = info.id
50
+ num_files_per_worker = self.num_files_per_rank // info.num_workers
51
+ start = num_files_per_worker * worker_id
52
+ end = num_files_per_worker * (worker_id + 1)
53
+ data_paths_per_worker = self.data_paths_per_rank[start:end]
54
+
55
+ return data_paths_per_worker[::-1], worker_id
56
+
57
+ def __iter__(self):
58
+ raise NotImplementedError
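
Worth noting about the sharding above (editorial note): each rank keeps len(data_paths) // world_size files, and each DataLoader worker keeps a further 1/num_workers slice, so both integer divisions silently drop remainder files. A small sanity-check sketch:

def files_per_worker(num_files: int, world_size: int, num_workers: int) -> int:
    # Mirrors set_epoch() and get_data_paths_per_worker() above.
    return (num_files // world_size) // num_workers

# 1000 files over 8 ranks x 4 workers: 1000 // 8 = 125, 125 // 4 = 31,
# so 8 * 4 * 31 = 992 files are read per epoch and 8 are dropped.
assert files_per_worker(1000, world_size=8, num_workers=4) == 31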
data/interleave_datasets/edit_dataset.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import io
5
+ import random
6
+ from PIL import Image, ImageFile, PngImagePlugin
7
+
8
+ from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
9
+ from ..data_utils import pil_img2rgb
10
+
11
+
12
+ Image.MAX_IMAGE_PIXELS = 200000000
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ MaximumDecompressedSize = 1024
15
+ MegaByte = 2 ** 20
16
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
17
+
18
+
19
+ class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
20
+
21
+ def parse_row(self, row):
22
+ image_num = len(row["image_list"])
23
+ # randomly choose start and end, return [0, 1] when only two images
24
+ start_idx = random.choice(range(image_num - 1))
25
+ max_end = min(start_idx + 3, image_num)
26
+ end_idx = random.choice(range(start_idx + 1, max_end))
27
+
28
+ data = self._init_data()
29
+ data = self._add_image(
30
+ data,
31
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
32
+ need_loss=False,
33
+ need_vae=True,
34
+ need_vit=True,
35
+ )
36
+
37
+ if end_idx - start_idx > 1 and random.random() < 0.5: # concatenate multiple instructions
38
+ if end_idx == image_num - 1:
39
+ end_idx -= 1
40
+
41
+ instruction = ""
42
+ for idx in range(start_idx + 1, end_idx + 1):
43
+ instruction += random.choice(row["instruction_list"][idx-1]) + ". "
44
+ data = self._add_text(data, instruction.rstrip(), need_loss=False)
45
+ data = self._add_image(
46
+ data,
47
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
48
+ need_loss=True,
49
+ need_vae=False,
50
+ need_vit=False,
51
+ )
52
+ else:
53
+ for idx in range(start_idx + 1, end_idx + 1):
54
+ instruction = random.choice(row["instruction_list"][idx-1])
55
+ data = self._add_text(data, instruction, need_loss=False)
56
+ if idx != end_idx:
57
+ data = self._add_image(
58
+ data,
59
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
60
+ need_loss=True,
61
+ need_vae=True,
62
+ need_vit=True,
63
+ )
64
+ else:
65
+ data = self._add_image(
66
+ data,
67
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
68
+ need_loss=True,
69
+ need_vae=False,
70
+ need_vit=False,
71
+ )
72
+ return data
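
For readability, this is roughly the dict that parse_row assembles in the simplest case (one conditioning image, one instruction, one target image); an illustration with placeholder values, not captured output. The helpers _init_data / _add_text / _add_image are defined in interleave_t2i_dataset.py below.

example_sample = {
    'sequence_plan': [
        {'type': 'vae_image', 'enable_cfg': 1, 'loss': 0, 'special_token_loss': 0, 'special_token_label': None},
        {'type': 'vit_image', 'enable_cfg': 1, 'loss': 0, 'special_token_loss': 0, 'special_token_label': None},
        {'type': 'text',      'enable_cfg': 1, 'loss': 0, 'special_token_loss': 0, 'special_token_label': None},
        {'type': 'vae_image', 'enable_cfg': 0, 'loss': 1, 'special_token_loss': 0, 'special_token_label': None},
    ],
    'text_ids_list': [[151644, 3837]],  # one token-id list per 'text' entry (placeholder ids)
    'image_tensor_list': ['cond VAE tensor', 'cond ViT tensor', 'target VAE tensor'],  # 3xHxW tensors, in plan order
    'num_tokens': 1234,  # running text+image token total used for packing (placeholder)
}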
data/interleave_datasets/interleave_t2i_dataset.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pyarrow.parquet as pq
5
+
6
+ from ..distributed_iterable_dataset import DistributedIterableDataset
7
+ from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
8
+
9
+
10
+ class InterleavedBaseIterableDataset(DistributedIterableDataset):
11
+
12
+ def _init_data(self):
13
+ data = {
14
+ 'sequence_plan': [],
15
+ 'text_ids_list': [],
16
+ 'image_tensor_list': [],
17
+ 'num_tokens': 0,
18
+ }
19
+ return data
20
+
21
+ def _add_text(self, data, text, need_loss, enable_cfg=True, next_token_label=None):
22
+ text_ids = self.tokenizer.encode(text)
23
+ data['num_tokens'] += len(text_ids)
24
+ data['text_ids_list'].append(text_ids)
25
+
26
+ # If next_token_label is provided, the im_end token should predict it
27
+ special_token_loss = 1 if next_token_label is not None else 0
28
+
29
+ data['sequence_plan'].append(
30
+ {
31
+ 'type': 'text',
32
+ 'enable_cfg': int(enable_cfg),
33
+ 'loss': int(need_loss),
34
+ 'special_token_loss': special_token_loss,
35
+ 'special_token_label': next_token_label,
36
+ }
37
+ )
38
+ return data
39
+
40
+ def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True, special_token_label=None):
41
+ assert need_loss or need_vae or need_vit
42
+
43
+ if need_loss:
44
+ # For loss images, don't add special_token_loss on the start token
45
+ # The previous text token should predict the vision_start token
46
+ data['sequence_plan'].append(
47
+ {
48
+ 'type': 'vae_image',
49
+ 'enable_cfg': 0,
50
+ 'loss': 1,
51
+ 'special_token_loss': 0, # No loss on start token itself
52
+ 'special_token_label': None,
53
+ }
54
+ )
55
+
56
+ image_tensor = self.transform(image)
57
+ height, width = image_tensor.shape[1:]
58
+ data['num_tokens'] += width * height // self.transform.stride ** 2
59
+ data['image_tensor_list'].append(image_tensor)
60
+
61
+ if need_vae:
62
+ data['sequence_plan'].append(
63
+ {
64
+ 'type': 'vae_image',
65
+ 'enable_cfg': int(enable_cfg),
66
+ 'loss': 0,
67
+ 'special_token_loss': 0,
68
+ 'special_token_label': None,
69
+ }
70
+ )
71
+
72
+ image_tensor = self.transform(image)
73
+ height, width = image_tensor.shape[1:]
74
+ data['num_tokens'] += width * height // self.transform.stride ** 2
75
+ data['image_tensor_list'].append(image_tensor.clone())
76
+
77
+ if need_vit:
78
+ data['sequence_plan'].append(
79
+ {
80
+ 'type': 'vit_image',
81
+ 'enable_cfg': int(enable_cfg),
82
+ 'loss': 0,
83
+ 'special_token_loss': 0,
84
+ 'special_token_label': None,
85
+ },
86
+ )
87
+ vit_image_tensor = self.vit_transform(image)
88
+ height, width = vit_image_tensor.shape[1:]
89
+ data['num_tokens'] += width * height // self.vit_transform.stride ** 2
90
+ data['image_tensor_list'].append(vit_image_tensor)
91
+
92
+ return data
93
+
94
+ def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
95
+ assert int(need_loss) + int(need_vae) == 1
96
+
97
+ if need_loss:
98
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
99
+ current_sequence_plan = {
100
+ 'type': 'vae_image',
101
+ 'enable_cfg': 0,
102
+ 'loss': 1,
103
+ 'special_token_loss': 0,
104
+ 'special_token_label': None,
105
+ 'split_start': idx == 0,
106
+ 'split_end': idx == len(frames) - 1,
107
+ }
108
+ if idx < len(frame_indexes) - 1:
109
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
110
+ data['sequence_plan'].append(current_sequence_plan)
111
+ image_tensor = self.transform(image)
112
+ height, width = image_tensor.shape[1:]
113
+ data['image_tensor_list'].append(image_tensor)
114
+ data['num_tokens'] += width * height // self.transform.stride ** 2
115
+
116
+ elif need_vae:
117
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
118
+ current_sequence_plan = {
119
+ 'type': 'vae_image',
120
+ 'enable_cfg': int(enable_cfg),
121
+ 'loss': 0,
122
+ 'special_token_loss': 0,
123
+ 'special_token_label': None,
124
+ 'split_start': idx == 0,
125
+ 'split_end': idx == len(frames) - 1,
126
+ }
127
+ if idx < len(frame_indexes) - 1:
128
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
129
+ data['sequence_plan'].append(current_sequence_plan)
130
+ image_tensor = self.transform(image)
131
+ height, width = image_tensor.shape[1:]
132
+ data['image_tensor_list'].append(image_tensor)
133
+ data['num_tokens'] += width * height // self.transform.stride ** 2
134
+
135
+ return data
136
+
137
+
138
+ class ParquetStandardIterableDataset(DistributedIterableDataset):
139
+
140
+ def __init__(
141
+ self, dataset_name, transform, tokenizer, vit_transform,
142
+ data_dir_list, num_used_data, parquet_info,
143
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
144
+ ):
145
+ """
146
+ data_dir_list: list of data directories contains parquet files
147
+ num_used_data: list of number of sampled data paths for each data directory
148
+ vit_transform: input transform for vit model.
149
+ """
150
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
151
+ self.transform = transform
152
+ self.vit_transform = vit_transform
153
+ self.tokenizer = tokenizer
154
+ self.data_status = data_status
155
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
156
+ self.set_epoch()
157
+
158
+ def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
159
+ row_groups = []
160
+ for data_dir, num_data_path in zip(data_dir_list, num_used_data):
161
+ data_paths = get_parquet_data_paths([data_dir], [num_data_path])
162
+ for data_path in data_paths:
163
+ if data_path in parquet_info.keys():
164
+ num_row_groups = parquet_info[data_path]['num_row_groups']
165
+ for rg_idx in range(num_row_groups):
166
+ row_groups.append((data_path, rg_idx))
167
+ return row_groups
168
+
169
+ def parse_row(self, row):
170
+ raise NotImplementedError
171
+
172
+ def __iter__(self):
173
+ file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
174
+ if self.data_status is not None:
175
+ global_row_group_start_id = self.data_status[worker_id][0]
176
+ row_start_id = self.data_status[worker_id][1] + 1
177
+ else:
178
+ global_row_group_start_id = 0
179
+ row_start_id = 0
180
+
181
+ print(
182
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
183
+ f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
184
+ )
185
+
186
+ while True:
187
+ file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
188
+ for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
189
+ file_paths_per_worker_, start=global_row_group_start_id
190
+ ):
191
+ fs = init_arrow_pf_fs(parquet_file_path)
192
+ with fs.open_input_file(parquet_file_path) as f:
193
+ try:
194
+ fr = pq.ParquetFile(f)
195
+ df = fr.read_row_group(row_group_id).to_pandas()
196
+ df = df.iloc[row_start_id:]
197
+ except Exception as e:
198
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
199
+ continue
200
+
201
+ for row_idx, row in df.iterrows():
202
+ try:
203
+ data = self.parse_row(row)
204
+ if len(data) == 0:
205
+ continue
206
+ data['data_indexes'] = {
207
+ "data_indexes": [global_row_group_idx, row_idx],
208
+ "worker_id": worker_id,
209
+ "dataset_name": self.dataset_name,
210
+ }
211
+ except Exception as e:
212
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
213
+ continue
214
+ yield data
215
+
216
+ row_start_id = 0
217
+ global_row_group_start_id = 0
218
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
data/interleave_datasets/think_trace_dataset.py ADDED
@@ -0,0 +1,289 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import traceback
5
+ from PIL import Image, ImageFile, PngImagePlugin
6
+
7
+ from .interleave_t2i_dataset import InterleavedBaseIterableDataset
8
+ from ..data_utils import pil_img2rgb
9
+ from ..distributed_iterable_dataset import DistributedIterableDataset
10
+
11
+
12
+ Image.MAX_IMAGE_PIXELS = 200000000
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ MaximumDecompressedSize = 1024
15
+ MegaByte = 2 ** 20
16
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
17
+
18
+
19
+ class ThinkTraceJSONLIterableDataset(InterleavedBaseIterableDataset, DistributedIterableDataset):
20
+ def __init__(
21
+ self,
22
+ dataset_name,
23
+ transform,
24
+ tokenizer,
25
+ vit_transform,
26
+ jsonl_path_list,
27
+ data_dir_list,
28
+ num_used_data,
29
+ local_rank=0,
30
+ world_size=1,
31
+ num_workers=8,
32
+ data_status=None,
33
+ shuffle_lines=True,
34
+ shuffle_seed=0,
35
+ image_prefix_dir=None,
36
+ ):
37
+ """
38
+ Dataset for think-trace style JSONL files with interleaved text and images.
39
+
40
+ Args:
41
+ dataset_name: Name of the dataset
42
+ transform: Transform for VAE images
43
+ tokenizer: Text tokenizer
44
+ vit_transform: Transform for VIT images
45
+ jsonl_path_list: List of JSONL file paths
46
+ data_dir_list: List of base directories (should match jsonl_path_list)
47
+ num_used_data: List of number of samples to use from each JSONL. If a value is None or non-positive, all data from that JSONL will be used.
48
+ image_prefix_dir: Absolute path to prepend to relative image paths
49
+ Other args: Standard distributed dataset args
50
+ """
51
+ DistributedIterableDataset.__init__(self, dataset_name, local_rank, world_size, num_workers)
52
+ self.transform = transform
53
+ self.vit_transform = vit_transform
54
+ self.tokenizer = tokenizer
55
+ self.data_status = data_status
56
+ self.image_prefix_dir = image_prefix_dir or ""
57
+
58
+ self.start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
59
+ self.end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
60
+ self.im_start = tokenizer.convert_tokens_to_ids('<|im_start|>')
61
+
62
+ self.data_paths = self.get_data_paths(
63
+ jsonl_path_list,
64
+ num_used_data,
65
+ shuffle_lines,
66
+ shuffle_seed,
67
+ )
68
+ self.set_epoch()
69
+
70
+ def get_data_paths(self, jsonl_path_list, num_used_data, shuffle_lines, shuffle_seed):
71
+ data_paths = []
72
+ if not isinstance(num_used_data, list):
73
+ num_used_data = [num_used_data] * len(jsonl_path_list)
74
+
75
+ for jsonl_path, num_data_point in zip(jsonl_path_list, num_used_data):
76
+ with open(jsonl_path, 'r') as f:
77
+ raw_data = f.readlines()
78
+ if shuffle_lines:
79
+ self.rng.seed(shuffle_seed)
80
+ self.rng.shuffle(raw_data)
81
+
82
+ # Convert 'None' string to None type
83
+ if num_data_point == 'None':
84
+ num_data_point = None
85
+
86
+ if num_data_point is not None and int(num_data_point) > 0:
87
+ raw_data = raw_data[:int(num_data_point)]
88
+
89
+ data_paths.extend(raw_data)
90
+ return data_paths
91
+
92
+ def extract_image_references(self, text):
93
+ """Extract image references from text like <image_start>[problem_image_1]<image_end>"""
94
+ pattern = r'<image_start>\[([^\]]+)\]<image_end>'
95
+ matches = re.findall(pattern, text)
96
+ return matches
97
+
98
+ def replace_image_references(self, text):
99
+ """Replace image references with placeholder tokens for processing"""
100
+ pattern = r'<image_start>\[([^\]]+)\]<image_end>'
101
+ # Replace with a special placeholder that we'll process later
102
+ return re.sub(pattern, '<IMAGE_PLACEHOLDER>', text)
103
+
104
+ def remove_thought_patterns(self, text):
105
+ """Remove THOUGHT x: patterns from text"""
106
+ # Remove patterns like "THOUGHT 1:", "THOUGHT 2:", etc.
107
+ pattern = r'THOUGHT\s*\d+:\s*'
108
+ return re.sub(pattern, '', text)
109
+
110
+ def load_image_safely(self, data_item, image_key):
111
+ """Load image with null checking and path resolution"""
112
+ if image_key not in data_item or data_item[image_key] is None:
113
+ return None
114
+
115
+ image_path = data_item[image_key]
116
+ full_path = os.path.join(self.image_prefix_dir, image_path)
117
+
118
+ try:
119
+ return pil_img2rgb(Image.open(full_path))
120
+ except Exception as e:
121
+ print(f"Failed to load image {full_path}: {e}")
122
+ return None
123
+
124
+ def parse_row(self, json_line):
125
+ """Parse a single JSON line into the required format"""
126
+ try:
127
+ data_item = json.loads(json_line.strip())
128
+ except Exception:
129
+ traceback.print_exc()
130
+ return {}
131
+
132
+ # Extract the main fields
133
+ prompt = "You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and generate visual aids to enhance your problem-solving. You should first think about the reasoning and planning process in the mind before generating visual aids. Wrap your text reasoning with <think></think> tokens, and wrap your final conclusion with <answer></answer> tokens. Provide your final conclusion clearly in the format of '<answer>Final Answer: <answer here></answer>'"
134
+ question = data_item.get('Question', '')
135
+ question = f'Question: {question}'
136
+ reasoning_trace = data_item.get('Text Reasoning Trace', '')
137
+ reasoning_trace = f'{reasoning_trace}'
138
+ final_answer = data_item.get('Final Answer', '')
139
+ final_answer = f'<answer>Final Answer: {final_answer}</answer>'
140
+
141
+ if not data_item.get('Question') or not data_item.get('Text Reasoning Trace') or not data_item.get('Final Answer'):  # check the raw fields; the formatted strings above are never empty
142
+ return {}
143
+
144
+ # Build the sequence
145
+ data = self._init_data()
146
+
147
+ # 0. Add prompt
148
+ data = self._add_text(data, prompt, need_loss=False, enable_cfg=True)
149
+
150
+ # 1. Add question (with image parsing)
151
+ question_image_refs = self.extract_image_references(question)
152
+ if question_image_refs:
153
+ clean_question = self.replace_image_references(question)
154
+ question_text_parts = clean_question.split('<IMAGE_PLACEHOLDER>')
155
+
156
+ if len(question_text_parts) != len(question_image_refs) + 1:
157
+ print(f"Mismatch in question: text parts {len(question_text_parts)}, images {len(question_image_refs)}")
158
+ return {}
159
+
160
+ question_images = []
161
+ for image_ref in question_image_refs:
162
+ image = self.load_image_safely(data_item, image_ref)
163
+ if image is None:
164
+ print(f"Skipping sample due to missing image in question: {image_ref}")
165
+ return {}
166
+ question_images.append(image)
167
+
168
+
169
+ for i, text_part in enumerate(question_text_parts):
170
+ if text_part.strip():
171
+ # Question text has no loss, so no need for vision start prediction
172
+ data = self._add_text(data, text_part.strip(), need_loss=False, enable_cfg=True)
173
+ if i < len(question_images):
174
+ data = self._add_image(
175
+ data, question_images[i],
176
+ need_loss=False, # No loss for question images
177
+ need_vae=False, # no VAE (generation) conditioning for question images
178
+ need_vit=True, # VIT understanding
179
+ enable_cfg=True,
180
+ )
181
+ else:
182
+ # Original behavior if no images in question
183
+ data = self._add_text(data, question, need_loss=False, enable_cfg=True)
184
+
185
+ # 2. Load the images referenced in the reasoning trace
186
+ image_refs = self.extract_image_references(reasoning_trace)
187
+
188
+ loaded_images = []
189
+ for image_ref in image_refs:
190
+ image = self.load_image_safely(data_item, image_ref)
191
+ if image is not None:
192
+ loaded_images.append(image)
193
+ else:
194
+ # If image fails to load, skip this sample
195
+ print(f"Skipping sample due to missing image: {image_ref}")
196
+ return {}
197
+
198
+ # Clean reasoning trace by removing image references for text processing
199
+ clean_reasoning_trace = self.replace_image_references(reasoning_trace)
200
+
201
+ # Remove THOUGHT patterns from the reasoning trace
202
+ clean_reasoning_trace = self.remove_thought_patterns(clean_reasoning_trace)
203
+
204
+ # Append final answer to the reasoning trace
205
+ # clean_reasoning_trace += f"\n\nFinal Answer: {final_answer}"
206
+
207
+ # Split reasoning trace by image placeholders to interleave text and images
208
+ text_parts = clean_reasoning_trace.split('<IMAGE_PLACEHOLDER>')
209
+
210
+ if len(text_parts) != len(loaded_images) + 1:
211
+ print(f"Mismatch between text parts ({len(text_parts)}) and images ({len(loaded_images)})")
212
+ return {}
213
+
214
+ # 3. Interleave text parts and images from the reasoning trace
215
+ for i, text_part in enumerate(text_parts):
216
+ # Add text part if not empty
217
+ if text_part.strip():
218
+ # Wrap reasoning text with <think></think> tokens
219
+ wrapped_text = f"<think>{text_part.strip()}</think>"
220
+
221
+ # Determine what the im_end token should predict
222
+ if i < len(loaded_images):
223
+ # If this text part is followed by an image, predict vision_start
224
+ next_token_label = self.start_of_image
225
+ elif i == len(text_parts) - 1:
226
+ # If this is the last text part, predict im_start for final answer
227
+ next_token_label = self.im_start
228
+ else:
229
+ next_token_label = None
230
+
231
+ data = self._add_text(data, wrapped_text, need_loss=True, enable_cfg=True, next_token_label=next_token_label)
232
+
233
+ # Add image if available
234
+ if i < len(loaded_images):
235
+ # Add image with both VAE and VIT processing for full capability
236
+ data = self._add_image(
237
+ data,
238
+ loaded_images[i],
239
+ need_loss=True, # VAE generation loss
240
+ need_vae=True, # VAE conditioning
241
+ need_vit=True, # VIT understanding
242
+ enable_cfg=True,
243
+ )
244
+
245
+ # 4. Add the final answer
246
+ data = self._add_text(data, final_answer, need_loss=True, enable_cfg=True)  # loss on the final answer; set need_loss=False to exclude it
247
+
248
+ return data
249
+
250
+ def __iter__(self):
251
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
252
+ if self.data_status is not None:
253
+ row_start_id = self.data_status[worker_id] + 1
254
+ else:
255
+ row_start_id = 0
256
+
257
+ print(
258
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
259
+ f"resuming data at row#{row_start_id}"
260
+ )
261
+
262
+ while True:
263
+ data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
264
+ for row_idx, json_line in enumerate(data_paths_per_worker_, start=row_start_id):
265
+ try:
266
+ data = self.parse_row(json_line)
267
+ if len(data) == 0:
268
+ continue
269
+
270
+ # Check if we have any loss
271
+ has_loss = any(item['loss'] for item in data['sequence_plan'])
272
+ if not has_loss:
273
+ print('No loss defined, skipped.')
274
+ continue
275
+
276
+ data['data_indexes'] = {
277
+ "data_indexes": row_idx,
278
+ "worker_id": worker_id,
279
+ "dataset_name": self.dataset_name,
280
+ }
281
+ yield data
282
+
283
+ except Exception as e:
284
+ print(f"Error processing row {row_idx}: {e}")
285
+ traceback.print_exc()
286
+ continue
287
+
288
+ row_start_id = 0
289
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
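A minimal sketch of the image-reference convention ThinkTraceJSONLIterableDataset parses above: references of the form <image_start>[key]<image_end> are extracted with a regex, replaced by a placeholder, and the text is split so that there is always exactly one more text part than image. The trace string and keys below are invented for illustration:

import re

trace = (
    "First, inspect the grid. <image_start>[reasoning_image_1]<image_end> "
    "Then count the shaded cells. <image_start>[reasoning_image_2]<image_end> Done."
)

pattern = r'<image_start>\[([^\]]+)\]<image_end>'
image_keys = re.findall(pattern, trace)        # ['reasoning_image_1', 'reasoning_image_2']
text_parts = re.sub(pattern, '<IMAGE_PLACEHOLDER>', trace).split('<IMAGE_PLACEHOLDER>')

assert len(text_parts) == len(image_keys) + 1  # parse_row() drops the sample otherwise
for text, key in zip(text_parts, image_keys + [None]):
    print(repr(text.strip()), '->', key)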
modeling/__init__.py ADDED
@@ -0,0 +4 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from . import bagel, qwen2, siglip, autoencoder
modeling/autoencoder.py ADDED
@@ -0,0 +360 @@
1
+ # Copyright (c) 2024 Black Forest Labs.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ from einops import rearrange
16
+ from torch import Tensor, nn
17
+ from safetensors.torch import load_file as load_sft
18
+
19
+
20
+ @dataclass
21
+ class AutoEncoderParams:
22
+ resolution: int
23
+ in_channels: int
24
+ downsample: int
25
+ ch: int
26
+ out_ch: int
27
+ ch_mult: list[int]
28
+ num_res_blocks: int
29
+ z_channels: int
30
+ scale_factor: float
31
+ shift_factor: float
32
+
33
+
34
+ def swish(x: Tensor) -> Tensor:
35
+ return x * torch.sigmoid(x)
36
+
37
+
38
+ class AttnBlock(nn.Module):
39
+ def __init__(self, in_channels: int):
40
+ super().__init__()
41
+ self.in_channels = in_channels
42
+
43
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
44
+
45
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
46
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
47
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+
50
+ def attention(self, h_: Tensor) -> Tensor:
51
+ h_ = self.norm(h_)
52
+ q = self.q(h_)
53
+ k = self.k(h_)
54
+ v = self.v(h_)
55
+
56
+ b, c, h, w = q.shape
57
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
58
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
59
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
60
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
61
+
62
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
63
+
64
+ def forward(self, x: Tensor) -> Tensor:
65
+ return x + self.proj_out(self.attention(x))
66
+
67
+
68
+ class ResnetBlock(nn.Module):
69
+ def __init__(self, in_channels: int, out_channels: int):
70
+ super().__init__()
71
+ self.in_channels = in_channels
72
+ out_channels = in_channels if out_channels is None else out_channels
73
+ self.out_channels = out_channels
74
+
75
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
76
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
77
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
78
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+ if self.in_channels != self.out_channels:
80
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
81
+
82
+ def forward(self, x):
83
+ h = x
84
+ h = self.norm1(h)
85
+ h = swish(h)
86
+ h = self.conv1(h)
87
+
88
+ h = self.norm2(h)
89
+ h = swish(h)
90
+ h = self.conv2(h)
91
+
92
+ if self.in_channels != self.out_channels:
93
+ x = self.nin_shortcut(x)
94
+
95
+ return x + h
96
+
97
+
98
+ class Downsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ # no asymmetric padding in torch conv, must do it ourselves
102
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
103
+
104
+ def forward(self, x: Tensor):
105
+ pad = (0, 1, 0, 1)
106
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
107
+ x = self.conv(x)
108
+ return x
109
+
110
+
111
+ class Upsample(nn.Module):
112
+ def __init__(self, in_channels: int):
113
+ super().__init__()
114
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
115
+
116
+ def forward(self, x: Tensor):
117
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
118
+ x = self.conv(x)
119
+ return x
120
+
121
+
122
+ class Encoder(nn.Module):
123
+ def __init__(
124
+ self,
125
+ resolution: int,
126
+ in_channels: int,
127
+ ch: int,
128
+ ch_mult: list[int],
129
+ num_res_blocks: int,
130
+ z_channels: int,
131
+ ):
132
+ super().__init__()
133
+ self.ch = ch
134
+ self.num_resolutions = len(ch_mult)
135
+ self.num_res_blocks = num_res_blocks
136
+ self.resolution = resolution
137
+ self.in_channels = in_channels
138
+ # downsampling
139
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
140
+
141
+ curr_res = resolution
142
+ in_ch_mult = (1,) + tuple(ch_mult)
143
+ self.in_ch_mult = in_ch_mult
144
+ self.down = nn.ModuleList()
145
+ block_in = self.ch
146
+ for i_level in range(self.num_resolutions):
147
+ block = nn.ModuleList()
148
+ attn = nn.ModuleList()
149
+ block_in = ch * in_ch_mult[i_level]
150
+ block_out = ch * ch_mult[i_level]
151
+ for _ in range(self.num_res_blocks):
152
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
153
+ block_in = block_out
154
+ down = nn.Module()
155
+ down.block = block
156
+ down.attn = attn
157
+ if i_level != self.num_resolutions - 1:
158
+ down.downsample = Downsample(block_in)
159
+ curr_res = curr_res // 2
160
+ self.down.append(down)
161
+
162
+ # middle
163
+ self.mid = nn.Module()
164
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
165
+ self.mid.attn_1 = AttnBlock(block_in)
166
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167
+
168
+ # end
169
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
170
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
171
+
172
+ def forward(self, x: Tensor) -> Tensor:
173
+ # downsampling
174
+ hs = [self.conv_in(x)]
175
+ for i_level in range(self.num_resolutions):
176
+ for i_block in range(self.num_res_blocks):
177
+ h = self.down[i_level].block[i_block](hs[-1])
178
+ if len(self.down[i_level].attn) > 0:
179
+ h = self.down[i_level].attn[i_block](h)
180
+ hs.append(h)
181
+ if i_level != self.num_resolutions - 1:
182
+ hs.append(self.down[i_level].downsample(hs[-1]))
183
+
184
+ # middle
185
+ h = hs[-1]
186
+ h = self.mid.block_1(h)
187
+ h = self.mid.attn_1(h)
188
+ h = self.mid.block_2(h)
189
+ # end
190
+ h = self.norm_out(h)
191
+ h = swish(h)
192
+ h = self.conv_out(h)
193
+ return h
194
+
195
+
196
+ class Decoder(nn.Module):
197
+ def __init__(
198
+ self,
199
+ ch: int,
200
+ out_ch: int,
201
+ ch_mult: list[int],
202
+ num_res_blocks: int,
203
+ in_channels: int,
204
+ resolution: int,
205
+ z_channels: int,
206
+ ):
207
+ super().__init__()
208
+ self.ch = ch
209
+ self.num_resolutions = len(ch_mult)
210
+ self.num_res_blocks = num_res_blocks
211
+ self.resolution = resolution
212
+ self.in_channels = in_channels
213
+ self.ffactor = 2 ** (self.num_resolutions - 1)
214
+
215
+ # compute in_ch_mult, block_in and curr_res at lowest res
216
+ block_in = ch * ch_mult[self.num_resolutions - 1]
217
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
218
+ self.z_shape = (1, z_channels, curr_res, curr_res)
219
+
220
+ # z to block_in
221
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
222
+
223
+ # middle
224
+ self.mid = nn.Module()
225
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
226
+ self.mid.attn_1 = AttnBlock(block_in)
227
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228
+
229
+ # upsampling
230
+ self.up = nn.ModuleList()
231
+ for i_level in reversed(range(self.num_resolutions)):
232
+ block = nn.ModuleList()
233
+ attn = nn.ModuleList()
234
+ block_out = ch * ch_mult[i_level]
235
+ for _ in range(self.num_res_blocks + 1):
236
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
237
+ block_in = block_out
238
+ up = nn.Module()
239
+ up.block = block
240
+ up.attn = attn
241
+ if i_level != 0:
242
+ up.upsample = Upsample(block_in)
243
+ curr_res = curr_res * 2
244
+ self.up.insert(0, up) # prepend to get consistent order
245
+
246
+ # end
247
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
248
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
249
+
250
+ def forward(self, z: Tensor) -> Tensor:
251
+ # z to block_in
252
+ h = self.conv_in(z)
253
+
254
+ # middle
255
+ h = self.mid.block_1(h)
256
+ h = self.mid.attn_1(h)
257
+ h = self.mid.block_2(h)
258
+
259
+ # upsampling
260
+ for i_level in reversed(range(self.num_resolutions)):
261
+ for i_block in range(self.num_res_blocks + 1):
262
+ h = self.up[i_level].block[i_block](h)
263
+ if len(self.up[i_level].attn) > 0:
264
+ h = self.up[i_level].attn[i_block](h)
265
+ if i_level != 0:
266
+ h = self.up[i_level].upsample(h)
267
+
268
+ # end
269
+ h = self.norm_out(h)
270
+ h = swish(h)
271
+ h = self.conv_out(h)
272
+ return h
273
+
274
+
275
+ class DiagonalGaussian(nn.Module):
276
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
277
+ super().__init__()
278
+ self.sample = sample
279
+ self.chunk_dim = chunk_dim
280
+
281
+ def forward(self, z: Tensor) -> Tensor:
282
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
283
+ if self.sample:
284
+ std = torch.exp(0.5 * logvar)
285
+ return mean + std * torch.randn_like(mean)
286
+ else:
287
+ return mean
288
+
289
+
290
+ class AutoEncoder(nn.Module):
291
+ def __init__(self, params: AutoEncoderParams):
292
+ super().__init__()
293
+ self.encoder = Encoder(
294
+ resolution=params.resolution,
295
+ in_channels=params.in_channels,
296
+ ch=params.ch,
297
+ ch_mult=params.ch_mult,
298
+ num_res_blocks=params.num_res_blocks,
299
+ z_channels=params.z_channels,
300
+ )
301
+ self.decoder = Decoder(
302
+ resolution=params.resolution,
303
+ in_channels=params.in_channels,
304
+ ch=params.ch,
305
+ out_ch=params.out_ch,
306
+ ch_mult=params.ch_mult,
307
+ num_res_blocks=params.num_res_blocks,
308
+ z_channels=params.z_channels,
309
+ )
310
+ self.reg = DiagonalGaussian()
311
+
312
+ self.scale_factor = params.scale_factor
313
+ self.shift_factor = params.shift_factor
314
+
315
+ def encode(self, x: Tensor) -> Tensor:
316
+ z = self.reg(self.encoder(x))
317
+ z = self.scale_factor * (z - self.shift_factor)
318
+ return z
319
+
320
+ def decode(self, z: Tensor) -> Tensor:
321
+ z = z / self.scale_factor + self.shift_factor
322
+ return self.decoder(z)
323
+
324
+ def forward(self, x: Tensor) -> Tensor:
325
+ return self.decode(self.encode(x))
326
+
327
+
328
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
329
+ if len(missing) > 0 and len(unexpected) > 0:
330
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
331
+ print("\n" + "-" * 79 + "\n")
332
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
333
+ elif len(missing) > 0:
334
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
335
+ elif len(unexpected) > 0:
336
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
337
+
338
+
339
+ def load_ae(local_path: str) -> tuple[AutoEncoder, AutoEncoderParams]:
340
+ ae_params = AutoEncoderParams(
341
+ resolution=256,
342
+ in_channels=3,
343
+ downsample=8,
344
+ ch=128,
345
+ out_ch=3,
346
+ ch_mult=[1, 2, 4, 4],
347
+ num_res_blocks=2,
348
+ z_channels=16,
349
+ scale_factor=0.3611,
350
+ shift_factor=0.1159,
351
+ )
352
+
353
+ # Loading the autoencoder
354
+ ae = AutoEncoder(ae_params)
355
+
356
+ if local_path is not None:
357
+ sd = load_sft(local_path)
358
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
359
+ print_load_warning(missing, unexpected)
360
+ return ae, ae_params
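As a quick sanity check of the latent scaling above: encode() applies z -> scale * (z - shift) and decode() inverts that affine map before running the decoder, so the two cancel exactly. A minimal sketch using the same constants as load_ae (the random tensor is only a stand-in for a real encoder latent):

import torch

scale_factor, shift_factor = 0.3611, 0.1159             # constants from load_ae above

z = torch.randn(1, 16, 32, 32)                          # stand-in for the raw encoder latent
z_scaled = scale_factor * (z - shift_factor)            # what AutoEncoder.encode applies
z_restored = z_scaled / scale_factor + shift_factor     # what AutoEncoder.decode undoes first

print(torch.allclose(z, z_restored))                    # True: the affine maps cancel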
modeling/bagel/bagel.py ADDED
@@ -0,0 +1068 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import copy
5
+ from typing import List, Tuple, Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from torch.nn.attention.flex_attention import create_block_mask
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.modeling_utils import PreTrainedModel
13
+
14
+ from data.data_utils import (
15
+ create_sparse_mask,
16
+ get_flattened_position_ids_extrapolate,
17
+ get_flattened_position_ids_interpolate,
18
+ patchify,
19
+ )
20
+ from .qwen2_navit import NaiveCache
21
+ from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
22
+
23
+ from tqdm import tqdm
24
+
25
+
26
+ class BagelConfig(PretrainedConfig):
27
+ def __init__(
28
+ self,
29
+ visual_gen=True,
30
+ visual_und=True,
31
+ llm_config=None,
32
+ vit_config=None,
33
+ vae_config=None,
34
+ latent_patch_size=2,
35
+ max_latent_size=32,
36
+ vit_max_num_patch_per_side=70,
37
+ connector_act="gelu_pytorch_tanh",
38
+ interpolate_pos=False,
39
+ timestep_shift=1.0,
40
+ **kwargs
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.visual_gen = visual_gen
44
+ self.visual_und = visual_und
45
+ self.llm_config = llm_config
46
+ self.vit_config = vit_config
47
+ self.vae_config = vae_config
48
+ self.latent_patch_size = latent_patch_size
49
+ self.max_latent_size = max_latent_size
50
+ self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
+ self.connector_act = connector_act
52
+ self.interpolate_pos = interpolate_pos
53
+ self.timestep_shift = timestep_shift
54
+
55
+
56
+ class Bagel(PreTrainedModel):
57
+ config_class = BagelConfig
58
+ base_model_prefix = 'bagel'
59
+
60
+ def __init__(self, language_model, vit_model, config: BagelConfig):
61
+ super().__init__(config)
62
+ self.language_model = language_model
63
+ self.hidden_size = config.llm_config.hidden_size
64
+ self.use_moe = "Mo" in config.llm_config.layer_module
65
+ self.num_heads = config.llm_config.num_attention_heads
66
+
67
+ if config.visual_gen:
68
+ self.latent_patch_size = config.latent_patch_size
69
+ self.timestep_shift = config.timestep_shift
70
+ self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
+ self.max_latent_size = config.max_latent_size
72
+ self.latent_channel = config.vae_config.z_channels
73
+ self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
+ self.time_embedder = TimestepEmbedder(self.hidden_size)
75
+ self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
+ self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
+ self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
+
79
+ if config.visual_und:
80
+ self.vit_model = vit_model
81
+ self.vit_patch_size = config.vit_config.patch_size
82
+ self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
+ self.vit_hidden_size = config.vit_config.hidden_size
84
+ self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
+ self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
+
87
+ if config.interpolate_pos:
88
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
+ else:
90
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
+
92
+ self.config = config
93
+ self._init_weights()
94
+
95
+ def _init_weights(self):
96
+ if self.config.visual_gen:
97
+ nn.init.constant_(self.llm2vae.weight, 0)
98
+ nn.init.constant_(self.llm2vae.bias, 0)
99
+
100
+ def forward(
101
+ self,
102
+ sequence_length: int,
103
+ packed_text_ids: torch.LongTensor,
104
+ packed_text_indexes: torch.LongTensor,
105
+ sample_lens: List[int],
106
+ packed_position_ids: torch.LongTensor,
107
+ nested_attention_masks: List[torch.Tensor] = None,
108
+ split_lens: List[int] = None,
109
+ attn_modes: List[str] = None,
110
+ # for visual understanding
111
+ ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
+ packed_label_ids: Optional[torch.LongTensor] = None,
113
+ packed_vit_tokens: Optional[torch.Tensor] = None,
114
+ packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
+ packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
+ vit_token_seqlens: Optional[torch.IntTensor] = None,
117
+ # for visual generation
118
+ padded_latent: Optional[torch.Tensor] = None,
119
+ patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
+ packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
+ packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
+ packed_timesteps: Optional[torch.LongTensor] = None,
123
+ mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Args:
127
+ sequence_length: length of sequence.
128
+ packed_text_ids: 1-D int tensor, packed text token ids.
129
+ packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
+ sample_lens: A list of N ints, length of each sample in packed_sequence.
131
+ nested_attention_masks: A list of N 2-D float tensors, where 0.0 means attend and
132
+ -inf means ignore.
133
+ packed_position_ids: packed 1-D positions; an image has only one global position shared
134
+ by all latent tokens.
135
+
136
+ packed_vit_tokens: packed patchified image tokens for vit model.
137
+ packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
+ packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
+ vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
140
+ packed_label_ids: 1-D int tensor, packed label token ids.
141
+ ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
+
143
+ padded_latent: padded latent from VAE encoder.
144
+ patchified_vae_latent_shapes: A list of (h, w) tuples, the patchified latent shape of each image.
145
+ packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
+ packed_vae_token_indexes: 1-D int tensor, packed image token indexes in sequence.
147
+ packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
+ mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
+ """
150
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
+ packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
+ packed_sequence[packed_text_indexes] = packed_text_embedding
153
+
154
+ if nested_attention_masks is None:
155
+ sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
+ seqlen = sum(sample_lens)
157
+ block_mask = create_block_mask(
158
+ sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
+ device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
+ )
161
+ attention_mask = block_mask
162
+ else:
163
+ attention_mask = nested_attention_masks
164
+
165
+ # if self.config.visual_und and vit_token_seqlens is not None:
166
+ if self.config.visual_und:
167
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
168
+ cu_seqlens = cu_seqlens.to(torch.int32)
169
+ max_seqlen = torch.max(vit_token_seqlens).item()
170
+ packed_vit_token_embed = self.vit_model(
171
+ packed_pixel_values=packed_vit_tokens,
172
+ packed_flattened_position_ids=packed_vit_position_ids,
173
+ cu_seqlens=cu_seqlens,
174
+ max_seqlen=max_seqlen,
175
+ )
176
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
177
+ vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
178
+ packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
179
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
180
+
181
+ if self.config.visual_gen:
182
+ p = self.latent_patch_size
183
+ packed_latent = []
184
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
185
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
186
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
187
+ packed_latent.append(latent)
188
+ packed_latent_clean = torch.cat(packed_latent, dim=0)
189
+
190
+ noise = torch.randn_like(packed_latent_clean)
191
+ packed_timesteps = torch.sigmoid(packed_timesteps)
192
+ packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
193
+ packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
194
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
195
+ latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
196
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
197
+ packed_sequence[packed_vae_token_indexes] = packed_latent
198
+
199
+ extra_inputs = {}
200
+ if self.use_moe:
201
+ packed_und_token_indexes = packed_text_indexes
202
+ if packed_vit_token_indexes is not None:
203
+ packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
204
+ extra_inputs.update(
205
+ packed_und_token_indexes=packed_und_token_indexes,
206
+ packed_gen_token_indexes=packed_vae_token_indexes,
207
+ )
208
+
209
+ last_hidden_state = self.language_model(
210
+ packed_sequence=packed_sequence,
211
+ sample_lens=sample_lens,
212
+ attention_mask=attention_mask,
213
+ packed_position_ids=packed_position_ids,
214
+ **extra_inputs,
215
+ )
216
+
217
+ mse = None
218
+ if self.config.visual_gen:
219
+ packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
220
+ target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
221
+ has_mse = packed_timesteps > 0
222
+ mse = (packed_mse_preds - target[has_mse]) ** 2
223
+
224
+ ce = None
225
+ if ce_loss_indexes is not None:
226
+ packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
227
+ ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
228
+
229
+ return dict(mse=mse, ce=ce)
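# Worked sketch of the flow parameterization used in forward() above (tensors are illustrative):
# x_t = (1 - t) * x_0 + t * noise, so v_t = dx_t/dt = noise - x_0, which matches the MSE
# regression target (noise - packed_latent_clean).
#
#   import torch
#   x0, x1 = torch.randn(4, 8), torch.randn(4, 8)   # clean latent, noise
#   t = torch.full((4, 1), 0.3)
#   x_t = (1 - t) * x0 + t * x1
#   v_target = x1 - x0                              # what llm2vae is trained to predict
#   assert torch.allclose(x_t + (1 - t) * v_target, x1, atol=1e-6)  # stepping to t=1 recovers the noise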
230
+
231
+
232
+ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
233
+ packed_text_ids = list()
234
+ packed_text_position_ids = list()
235
+ text_token_lens = list()
236
+ packed_text_indexes = list()
237
+ packed_key_value_indexes = list()
238
+
239
+ curr = 0
240
+ newlens, new_rope = list(), list()
241
+ for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
242
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
243
+ curr += curr_kvlen
244
+
245
+ text_ids = tokenizer.encode(prompt)
246
+ text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
247
+ text_token_lens.append(len(text_ids))
248
+ packed_text_ids.extend(text_ids)
249
+ packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
250
+ packed_text_indexes.extend(range(curr, curr + len(text_ids)))
251
+ newlens.append(curr_kvlen + len(text_ids))
252
+ new_rope.append(curr_position_id + len(text_ids))
253
+ curr += len(text_ids)
254
+
255
+ generation_input = {
256
+ "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
257
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
258
+ "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
259
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
260
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
261
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
262
+ }
263
+
264
+ return generation_input, newlens, new_rope
265
+
266
+ @torch.no_grad
267
+ def forward_cache_update_text(
268
+ self,
269
+ past_key_values: NaiveCache,
270
+ packed_text_ids: torch.IntTensor,
271
+ packed_text_position_ids: torch.LongTensor,
272
+ text_token_lens: torch.LongTensor,
273
+ packed_text_indexes: torch.LongTensor,
274
+ packed_key_value_indexes: torch.LongTensor,
275
+ key_values_lens: torch.IntTensor,
276
+ ):
277
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
278
+
279
+ extra_inputs = {}
280
+ if self.use_moe:
281
+ extra_inputs = {"mode": "und"}
282
+
283
+ output = self.language_model.forward_inference(
284
+ packed_query_sequence=packed_text_embedding,
285
+ query_lens=text_token_lens,
286
+ packed_query_position_ids=packed_text_position_ids,
287
+ packed_query_indexes=packed_text_indexes,
288
+ past_key_values=past_key_values,
289
+ packed_key_value_indexes=packed_key_value_indexes,
290
+ key_values_lens=key_values_lens,
291
+ update_past_key_values=True,
292
+ is_causal=True,
293
+ **extra_inputs,
294
+ )
295
+ past_key_values = output.past_key_values
296
+
297
+ return past_key_values
298
+
299
+ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
300
+ packed_vit_token_indexes = list()
301
+ vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
302
+ packed_text_ids, packed_text_indexes = list(), list()
303
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
304
+ packed_key_value_indexes = list()
305
+
306
+ _curr = curr = 0
307
+ newlens, new_rope = list(), list()
308
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
309
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
310
+ curr += curr_kvlen
311
+
312
+ packed_text_ids.append(new_token_ids['start_of_image'])
313
+ packed_text_indexes.append(_curr)
314
+ packed_indexes.append(curr)
315
+ curr += 1
316
+ _curr += 1
317
+
318
+ image_tensor = transforms(image)
319
+ vit_position_ids = self.get_flattened_position_ids(
320
+ image_tensor.size(1), image_tensor.size(2),
321
+ self.vit_patch_size,
322
+ max_num_patches_per_side=self.vit_max_num_patch_per_side
323
+ )
324
+ vit_tokens = patchify(image_tensor, self.vit_patch_size)
325
+ packed_vit_tokens.append(vit_tokens)
326
+ num_img_tokens = vit_tokens.shape[0]
327
+ packed_vit_position_ids.append(vit_position_ids)
328
+ vit_token_seqlens.append(num_img_tokens)
329
+ packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
330
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
331
+ curr += num_img_tokens
332
+ _curr += num_img_tokens
333
+
334
+ packed_text_ids.append(new_token_ids['end_of_image'])
335
+ packed_text_indexes.append(_curr)
336
+ packed_indexes.append(curr)
337
+ curr += 1
338
+ _curr += 1
339
+
340
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
341
+ packed_seqlens.append(num_img_tokens + 2)
342
+ newlens.append(curr_kvlen + num_img_tokens + 2)
343
+ new_rope.append(curr_position_id + 1)
344
+
345
+ generation_input = {
346
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
347
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
348
+ "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
349
+ "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
350
+ "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
351
+ "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
352
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
353
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
354
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
355
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
356
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
357
+ }
358
+
359
+ return generation_input, newlens, new_rope
360
+
361
+ @torch.no_grad
362
+ def forward_cache_update_vit(
363
+ self,
364
+ past_key_values: NaiveCache,
365
+ packed_text_ids: torch.LongTensor,
366
+ packed_text_indexes: torch.LongTensor,
367
+ packed_vit_tokens: torch.Tensor,
368
+ packed_vit_token_indexes: torch.LongTensor,
369
+ packed_vit_position_ids: torch.LongTensor,
370
+ vit_token_seqlens: torch.IntTensor,
371
+ packed_position_ids: torch.LongTensor,
372
+ packed_seqlens: torch.IntTensor,
373
+ packed_indexes: torch.LongTensor,
374
+ packed_key_value_indexes: torch.LongTensor,
375
+ key_values_lens: torch.IntTensor,
376
+ ):
377
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
378
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
379
+ packed_sequence[packed_text_indexes] = packed_text_embedding
380
+
381
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
382
+ cu_seqlens = cu_seqlens.to(torch.int32)
383
+ max_seqlen = torch.max(vit_token_seqlens).item()
384
+ packed_vit_token_embed = self.vit_model(
385
+ packed_pixel_values=packed_vit_tokens,
386
+ packed_flattened_position_ids=packed_vit_position_ids,
387
+ cu_seqlens=cu_seqlens,
388
+ max_seqlen=max_seqlen,
389
+ )
390
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
391
+ pos_emb = self.vit_pos_embed(packed_vit_position_ids)
392
+ packed_vit_token_embed = packed_vit_token_embed + pos_emb
393
+ if packed_vit_token_embed.dtype != packed_sequence.dtype:
394
+ packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
395
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
396
+
397
+ extra_inputs = {}
398
+ if self.use_moe:
399
+ extra_inputs = {"mode": "und"}
400
+
401
+ output = self.language_model.forward_inference(
402
+ packed_query_sequence=packed_sequence,
403
+ query_lens=packed_seqlens,
404
+ packed_query_position_ids=packed_position_ids,
405
+ packed_query_indexes=packed_indexes,
406
+ past_key_values=past_key_values,
407
+ packed_key_value_indexes=packed_key_value_indexes,
408
+ key_values_lens=key_values_lens,
409
+ update_past_key_values=True,
410
+ is_causal=False,
411
+ **extra_inputs,
412
+ )
413
+ past_key_values = output.past_key_values
414
+
415
+ return past_key_values
416
+
417
+ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
418
+ patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
419
+ packed_vae_token_indexes = list()
420
+ packed_text_ids, packed_text_indexes = list(), list()
421
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
422
+ packed_key_value_indexes = list()
423
+
424
+ _curr = curr = 0
425
+ vae_image_tensors = list()
426
+ newlens, new_rope = list(), list()
427
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
428
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
429
+ curr += curr_kvlen
430
+
431
+ packed_text_ids.append(new_token_ids['start_of_image'])
432
+ packed_text_indexes.append(_curr)
433
+ packed_indexes.append(curr)
434
+ curr += 1
435
+ _curr += 1
436
+
437
+ image_tensor = transforms(image)
438
+ vae_image_tensors.append(image_tensor)
439
+ vae_position_ids = self.get_flattened_position_ids(
440
+ image_tensor.size(1), image_tensor.size(2),
441
+ self.latent_downsample,
442
+ max_num_patches_per_side=self.max_latent_size
443
+ )
444
+ packed_vae_position_ids.append(vae_position_ids)
445
+ H, W = image_tensor.shape[1:]
446
+ h = H // self.latent_downsample
447
+ w = W // self.latent_downsample
448
+ patchified_vae_latent_shapes.append((h, w))
449
+
450
+ num_img_tokens = w * h
451
+ packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
452
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
453
+ curr += num_img_tokens
454
+ _curr += num_img_tokens
455
+
456
+ packed_text_ids.append(new_token_ids['end_of_image'])
457
+ packed_text_indexes.append(_curr)
458
+ packed_indexes.append(curr)
459
+ curr += 1
460
+ _curr += 1
461
+
462
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
463
+ packed_seqlens.append(num_img_tokens + 2)
464
+ newlens.append(curr_kvlen + num_img_tokens + 2)
465
+ new_rope.append(curr_position_id + 1)
466
+
467
+ image_sizes = [item.shape for item in vae_image_tensors]
468
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
469
+ padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
470
+ for i, image_tensor in enumerate(vae_image_tensors):
471
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
472
+
473
+ generation_input = {
474
+ "padded_images": padded_images,
475
+ "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
476
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
477
+ "packed_timesteps": torch.tensor([timestep]),
478
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
479
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
480
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
481
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
482
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
483
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
484
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
485
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
486
+ }
487
+
488
+ return generation_input, newlens, new_rope
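# Note on the token count above (sizes illustrative): with the VAE downsample of 8 (see load_ae)
# and the default latent_patch_size of 2, latent_downsample = 16, so a 512x768 image gives
# h, w = 32, 48 and num_img_tokens = 1536, i.e. 1538 packed positions once the
# start_of_image / end_of_image tokens are included.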
489
+
490
+ @torch.no_grad
491
+ def forward_cache_update_vae(
492
+ self,
493
+ vae_model,
494
+ past_key_values: NaiveCache,
495
+ padded_images: torch.Tensor,
496
+ patchified_vae_latent_shapes: List,
497
+ packed_vae_position_ids: torch.LongTensor,
498
+ packed_timesteps: torch.Tensor,
499
+ packed_vae_token_indexes: torch.LongTensor,
500
+ packed_text_ids: torch.LongTensor,
501
+ packed_text_indexes: torch.LongTensor,
502
+ packed_position_ids: torch.LongTensor,
503
+ packed_seqlens: torch.IntTensor,
504
+ packed_indexes: torch.LongTensor,
505
+ key_values_lens: torch.IntTensor,
506
+ packed_key_value_indexes: torch.Tensor,
507
+ ):
508
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
509
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
510
+ packed_sequence[packed_text_indexes] = packed_text_embedding
511
+
512
+ padded_latent = vae_model.encode(padded_images)
513
+
514
+ p = self.latent_patch_size
515
+ packed_latent = list()
516
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
517
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
518
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
519
+ packed_latent.append(latent)
520
+ packed_latent = torch.cat(packed_latent, dim=0)
521
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
522
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
523
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
524
+ if packed_latent.dtype != packed_sequence.dtype:
525
+ packed_latent = packed_latent.to(packed_sequence.dtype)
526
+ packed_sequence[packed_vae_token_indexes] = packed_latent
527
+
528
+ extra_inputs = {}
529
+ if self.use_moe:
530
+ extra_inputs = {
531
+ "mode": "gen",
532
+ "packed_vae_token_indexes": packed_vae_token_indexes,
533
+ "packed_text_indexes": packed_text_indexes
534
+ }
535
+
536
+ output = self.language_model.forward_inference(
537
+ packed_query_sequence=packed_sequence,
538
+ query_lens=packed_seqlens,
539
+ packed_query_position_ids=packed_position_ids,
540
+ packed_query_indexes=packed_indexes,
541
+ past_key_values=past_key_values,
542
+ key_values_lens=key_values_lens,
543
+ packed_key_value_indexes=packed_key_value_indexes,
544
+ update_past_key_values=True,
545
+ is_causal=False,
546
+ **extra_inputs,
547
+ )
548
+ past_key_values = output.past_key_values
549
+
550
+ return past_key_values
551
+
552
+ def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
553
+ packed_text_ids, packed_text_indexes = list(), list()
554
+ packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
555
+ packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
556
+ packed_key_value_indexes = list()
557
+
558
+ query_curr = curr = 0
559
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
560
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
561
+ curr += curr_kvlen
562
+
563
+ packed_text_ids.append(new_token_ids['start_of_image'])
564
+ packed_text_indexes.append(query_curr)
565
+ packed_indexes.append(curr)
566
+ curr += 1
567
+ query_curr += 1
568
+
569
+ vae_position_ids = self.get_flattened_position_ids(
570
+ H, W,
571
+ self.latent_downsample,
572
+ max_num_patches_per_side=self.max_latent_size
573
+ )
574
+ packed_vae_position_ids.append(vae_position_ids)
575
+
576
+ h, w = H // self.latent_downsample, W // self.latent_downsample
577
+ num_image_tokens = h * w
578
+ packed_init_noises.append(
579
+ torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
580
+ )
581
+ packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
582
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
583
+ curr += num_image_tokens
584
+ query_curr += num_image_tokens
585
+
586
+ packed_text_ids.append(new_token_ids['end_of_image'])
587
+ packed_text_indexes.append(query_curr)
588
+ packed_indexes.append(curr)
589
+ curr += 1
590
+ query_curr += 1
591
+
592
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
593
+ packed_seqlens.append(num_image_tokens + 2)
594
+
595
+ generation_input = {
596
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
597
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
598
+ "packed_init_noises": torch.cat(packed_init_noises, dim=0),
599
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
600
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
601
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
602
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
603
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
604
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
605
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
606
+ }
607
+
608
+ return generation_input
609
+
610
+ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
611
+ packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
612
+
613
+ query_curr = curr = 0
614
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
615
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
616
+ curr += curr_kvlen
617
+
618
+ packed_indexes.append(curr)
619
+ curr += 1
620
+ query_curr += 1
621
+
622
+ h, w = H // self.latent_downsample, W // self.latent_downsample
623
+ num_image_tokens = h * w
624
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
625
+ curr += num_image_tokens
626
+ query_curr += num_image_tokens
627
+
628
+ packed_indexes.append(curr)
629
+ curr += 1
630
+ query_curr += 1
631
+
632
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
633
+
634
+ generation_input = {
635
+ "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
636
+ "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
637
+ "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
638
+ "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
639
+ }
640
+
641
+ return generation_input
642
+
643
+ @torch.no_grad
644
+ def generate_image(
645
+ self,
646
+ packed_text_ids: torch.LongTensor,
647
+ packed_text_indexes: torch.LongTensor,
648
+ packed_init_noises: torch.Tensor,
649
+ packed_vae_position_ids: torch.LongTensor,
650
+ packed_vae_token_indexes: torch.LongTensor,
651
+ packed_seqlens: torch.IntTensor,
652
+ packed_position_ids: torch.LongTensor,
653
+ packed_indexes: torch.LongTensor,
654
+ past_key_values: NaiveCache,
655
+ key_values_lens: torch.IntTensor,
656
+ packed_key_value_indexes: torch.LongTensor,
657
+ num_timesteps: int = 24,
658
+ timestep_shift: float = 1.0,
659
+ cfg_renorm_min: float = 0.0,
660
+ cfg_renorm_type: str = "global",
661
+ cfg_interval: Optional[Tuple[float, float]] = (0.0, 1.0),
662
+ # cfg_text
663
+ cfg_text_scale: float = 1.0,
664
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
665
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
666
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
667
+ cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
668
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
669
+ # cfg_img
670
+ cfg_img_scale: float = 1.0,
671
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
672
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
673
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
674
+ cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
675
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
676
+ cfg_type: str = "parallel",
677
+ ):
678
+ x_t = packed_init_noises
679
+
680
+ timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
681
+ timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
682
+ dts = timesteps[:-1] - timesteps[1:]
683
+ timesteps = timesteps[:-1]
684
+
685
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
686
+
687
+ timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
688
+ if t > cfg_interval[0] and t <= cfg_interval[1]:
689
+ cfg_text_scale_ = cfg_text_scale
690
+ cfg_img_scale_ = cfg_img_scale
691
+ else:
692
+ cfg_text_scale_ = 1.0
693
+ cfg_img_scale_ = 1.0
694
+ v_t = self._forward_flow(
695
+ x_t=x_t,
696
+ timestep=timestep,
697
+ packed_vae_token_indexes=packed_vae_token_indexes,
698
+ packed_vae_position_ids=packed_vae_position_ids,
699
+ packed_text_ids=packed_text_ids,
700
+ packed_text_indexes=packed_text_indexes,
701
+ packed_position_ids=packed_position_ids,
702
+ packed_indexes=packed_indexes,
703
+ packed_seqlens=packed_seqlens,
704
+ key_values_lens=key_values_lens,
705
+ past_key_values=past_key_values,
706
+ packed_key_value_indexes=packed_key_value_indexes,
707
+ cfg_renorm_min=cfg_renorm_min,
708
+ cfg_renorm_type=cfg_renorm_type,
709
+ # cfg_text
710
+ cfg_text_scale=cfg_text_scale_,
711
+ cfg_text_packed_position_ids=cfg_text_packed_position_ids,
712
+ cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
713
+ cfg_text_key_values_lens=cfg_text_key_values_lens,
714
+ cfg_text_past_key_values=cfg_text_past_key_values,
715
+ cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
716
+ # cfg_img
717
+ cfg_img_scale=cfg_img_scale_,
718
+ cfg_img_packed_position_ids=cfg_img_packed_position_ids,
719
+ cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
720
+ cfg_img_key_values_lens=cfg_img_key_values_lens,
721
+ cfg_img_past_key_values=cfg_img_past_key_values,
722
+ cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
723
+ cfg_type=cfg_type,
724
+ )
725
+
726
+ x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
727
+
728
+ unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
729
+ return unpacked_latent
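# Sketch of the shifted schedule and Euler update used in generate_image above (numbers illustrative):
#
#   import torch
#   num_timesteps, timestep_shift = 5, 3.0
#   t = torch.linspace(1, 0, num_timesteps)
#   t = timestep_shift * t / (1 + (timestep_shift - 1) * t)  # same warp as above; keeps t=1 and t=0 fixed
#   dts = t[:-1] - t[1:]                                     # positive step sizes
#   # each step applies x_t <- x_t - v_t * dt, walking from pure noise (t=1) toward the clean latent (t=0)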
730
+
731
+ @torch.no_grad
732
+ def _forward_flow(
733
+ self,
734
+ x_t: torch.Tensor,
735
+ timestep: torch.LongTensor,
736
+ packed_vae_token_indexes: torch.LongTensor,
737
+ packed_vae_position_ids: torch.LongTensor,
738
+ packed_text_ids: torch.LongTensor,
739
+ packed_text_indexes: torch.LongTensor,
740
+ packed_indexes: torch.LongTensor,
741
+ packed_position_ids: torch.LongTensor,
742
+ packed_seqlens: torch.IntTensor,
743
+ key_values_lens: torch.IntTensor,
744
+ past_key_values: NaiveCache,
745
+ packed_key_value_indexes: torch.LongTensor,
746
+ cfg_renorm_min: float = 0.0,
747
+ cfg_renorm_type: str = "global",
748
+ # cfg_text
749
+ cfg_text_scale: float = 1.0,
750
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
751
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
752
+ cfg_text_key_values_lens: Optional[torch.Tensor] = None,
753
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
754
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
755
+ # cfg_img
756
+ cfg_img_scale: float = 1.0,
757
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
758
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
759
+ cfg_img_key_values_lens: Optional[torch.Tensor] = None,
760
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
761
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
762
+ cfg_type: str = "parallel",
763
+ ):
764
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
765
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
766
+ packed_sequence[packed_text_indexes] = packed_text_embedding
767
+
768
+ assert timestep.unique().shape[0] == 1
769
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
770
+ packed_timestep_embeds = self.time_embedder(timestep)
771
+ x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
772
+ if x_t.dtype != packed_sequence.dtype:
773
+ x_t = x_t.to(packed_sequence.dtype)
774
+ packed_sequence[packed_vae_token_indexes] = x_t
775
+
776
+ extra_inputs = {}
777
+ if self.use_moe:
778
+ extra_inputs = {
779
+ "mode": "gen",
780
+ "packed_vae_token_indexes": packed_vae_token_indexes,
781
+ "packed_text_indexes": packed_text_indexes
782
+ }
783
+
784
+ output = self.language_model.forward_inference(
785
+ packed_query_sequence=packed_sequence,
786
+ query_lens=packed_seqlens,
787
+ packed_query_position_ids=packed_position_ids,
788
+ packed_query_indexes=packed_indexes,
789
+ past_key_values=past_key_values,
790
+ key_values_lens=key_values_lens,
791
+ packed_key_value_indexes=packed_key_value_indexes,
792
+ update_past_key_values=False,
793
+ is_causal=False,
794
+ **extra_inputs,
795
+ )
796
+ v_t = self.llm2vae(output.packed_query_sequence)
797
+ v_t = v_t[packed_vae_token_indexes]
798
+
799
+ if cfg_text_scale > 1.0:
800
+ cfg_text_output = self.language_model.forward_inference(
801
+ packed_query_sequence=packed_sequence,
802
+ query_lens=packed_seqlens,
803
+ packed_query_position_ids=cfg_text_packed_position_ids,
804
+ packed_query_indexes=cfg_text_packed_query_indexes,
805
+ past_key_values=cfg_text_past_key_values,
806
+ key_values_lens=cfg_text_key_values_lens,
807
+ packed_key_value_indexes=cfg_text_packed_key_value_indexes,
808
+ update_past_key_values=False,
809
+ is_causal=False,
810
+ **extra_inputs,
811
+ )
812
+ cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
813
+ cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
814
+
815
+ if cfg_img_scale > 1.0:
816
+ cfg_img_output = self.language_model.forward_inference(
817
+ packed_query_sequence=packed_sequence,
818
+ query_lens=packed_seqlens,
819
+ packed_query_position_ids=cfg_img_packed_position_ids,
820
+ packed_query_indexes=cfg_img_packed_query_indexes,
821
+ past_key_values=cfg_img_past_key_values,
822
+ key_values_lens=cfg_img_key_values_lens,
823
+ packed_key_value_indexes=cfg_img_packed_key_value_indexes,
824
+ update_past_key_values=False,
825
+ is_causal=False,
826
+ **extra_inputs,
827
+ )
828
+ cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
829
+ cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
830
+
831
+ if cfg_text_scale > 1.0:
832
+ if cfg_renorm_type == "text_channel":
833
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
834
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
835
+ norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
836
+ scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
837
+ v_t_text = v_t_text_ * scale
838
+ if cfg_img_scale > 1.0:
839
+ v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
840
+ else:
841
+ v_t = v_t_text
842
+ else:
843
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
844
+
845
+ if cfg_img_scale > 1.0:
846
+ v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
847
+ else:
848
+ v_t_ = v_t_text_
849
+
850
+ # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
851
+ if cfg_renorm_type == "global":
852
+ norm_v_t = torch.norm(v_t)
853
+ norm_v_t_ = torch.norm(v_t_)
854
+ elif cfg_renorm_type == "channel":
855
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
856
+ norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
857
+ else:
858
+ raise NotImplementedError(f"{cfg_renorm_type} is not supported")
859
+ scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
860
+ v_t = v_t_ * scale
861
+ else:
862
+ # No CFG
863
+ pass
864
+
865
+ return v_t
866
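A compact sketch of the guidance-plus-renormalization pattern applied above, shown for a single conditional/unconditional pair (the method itself chains text and image guidance):

    import torch

    def guide_and_renorm(v_cond, v_uncond, scale, renorm_type="global", renorm_min=0.0):
        v_guided = v_uncond + scale * (v_cond - v_uncond)
        if renorm_type == "global":
            norm, norm_g = torch.norm(v_cond), torch.norm(v_guided)
        elif renorm_type == "channel":
            norm = torch.norm(v_cond, dim=-1, keepdim=True)
            norm_g = torch.norm(v_guided, dim=-1, keepdim=True)
        else:
            raise NotImplementedError(f"{renorm_type} is not supported")
        rescale = (norm / (norm_g + 1e-8)).clamp(min=renorm_min, max=1.0)
        return v_guided * rescale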
+
867
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
868
+ packed_start_tokens, packed_key_value_indexes = list(), list()
869
+ packed_query_position_ids = list()
870
+
871
+ curr = 0
872
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
873
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
874
+ packed_start_tokens.append(new_token_ids['bos_token_id'])
875
+ packed_query_position_ids.append(curr_position_id)
876
+ curr += curr_kvlen
877
+
878
+ generation_input = {
879
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
880
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
881
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
882
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
883
+ }
884
+
885
+ return generation_input
886
+
887
+ @torch.no_grad
888
+ def generate_text(
889
+ self,
890
+ past_key_values: NaiveCache,
891
+ packed_key_value_indexes: torch.LongTensor,
892
+ key_values_lens: torch.IntTensor,
893
+ packed_start_tokens: torch.LongTensor,
894
+ packed_query_position_ids: torch.LongTensor,
895
+ max_length: int,
896
+ do_sample: bool = False,
897
+ temperature: float = 1.0,
898
+ end_token_id: int = None,
899
+ ):
900
+ step = 0
901
+ generated_sequence = []
902
+ curr_tokens = packed_start_tokens
903
+ while step < max_length:
904
+ generated_sequence.append(curr_tokens)
905
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
906
+ query_lens = torch.ones_like(curr_tokens)
907
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
908
+ 0, len(key_values_lens),
909
+ device=key_values_lens.device,
910
+ dtype=key_values_lens.dtype
911
+ )
912
+
913
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
914
+ for i in range(len(uppacked)):
915
+ uppacked[i] += i
916
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
917
+
918
+ extra_inputs = {}
919
+ if self.use_moe:
920
+ extra_inputs = {"mode": "und"}
921
+
922
+ output = self.language_model.forward_inference(
923
+ packed_query_sequence=packed_text_embedding,
924
+ query_lens=query_lens,
925
+ packed_query_position_ids=packed_query_position_ids,
926
+ packed_query_indexes=packed_query_indexes,
927
+ past_key_values=past_key_values,
928
+ key_values_lens=key_values_lens,
929
+ packed_key_value_indexes=packed_key_value_indexes,
930
+ update_past_key_values=True,
931
+ is_causal=True,
932
+ **extra_inputs,
933
+ )
934
+ past_key_values = output.past_key_values
935
+ packed_query_sequence = output.packed_query_sequence
936
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
937
+
938
+ if do_sample:
939
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
940
+ curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
941
+ else:
942
+ curr_tokens = torch.argmax(pred_logits, dim=-1)
943
+
944
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
945
+ for i in range(len(uppacked)):
946
+ uppacked[i] = torch.cat(
947
+ [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
948
+ )
949
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
950
+ key_values_lens = key_values_lens + 1
951
+ packed_query_position_ids = packed_query_position_ids + 1
952
+ step += 1
953
+
954
+ if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
955
+ # Check if next token would be vision_start (151652)
956
+ generated_sequence.append(curr_tokens) # Add the end token
957
+
958
+ # Generate one more token to check if it's vision_start
959
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
960
+
961
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
962
+ for i in range(len(uppacked)):
963
+ uppacked[i] += i
964
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
965
+
966
+ output = self.language_model.forward_inference(
967
+ packed_query_sequence=packed_text_embedding,
968
+ query_lens=query_lens,
969
+ packed_query_position_ids=packed_query_position_ids,
970
+ packed_query_indexes=packed_query_indexes,
971
+ past_key_values=past_key_values,
972
+ key_values_lens=key_values_lens,
973
+ packed_key_value_indexes=packed_key_value_indexes,
974
+ update_past_key_values=False,
975
+ is_causal=True,
976
+ **extra_inputs,
977
+ )
978
+
979
+ pred_logits = self.language_model.lm_head(output.packed_query_sequence)
980
+ if do_sample:
981
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
982
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
983
+ else:
984
+ next_token = torch.argmax(pred_logits, dim=-1)
985
+
986
+ # If next token is vision_start (151652), include it
987
+ if next_token[0] == 151652:
988
+ generated_sequence.append(next_token)
989
+
990
+ break
991
+
992
+ output_device = generated_sequence[0].device
993
+ return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
994
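The per-step token choice in `generate_text` reduces to temperature sampling or greedy argmax over the head logits; a minimal sketch:

    import torch
    from torch import nn

    def select_next_token(pred_logits, do_sample=False, temperature=1.0):
        if do_sample:
            probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        return torch.argmax(pred_logits, dim=-1)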
+
995
+ # for evaluation
996
+ @torch.no_grad()
997
+ def chat(
998
+ self,
999
+ tokenizer,
1000
+ new_token_ids,
1001
+ image_transform,
1002
+ images,
1003
+ prompt,
1004
+ max_length: int,
1005
+ do_sample: bool = False,
1006
+ temperature: float = 1.0,
1007
+ ):
1008
+ device = next(self.parameters()).device
1009
+
1010
+ if isinstance(new_token_ids, dict):
1011
+ for k, v in new_token_ids.items():
1012
+ if torch.is_tensor(v):
1013
+ new_token_ids[k] = v.to(device)
1014
+ elif torch.is_tensor(new_token_ids):
1015
+ new_token_ids = new_token_ids.to(device)
1016
+
1017
+ # prefill
1018
+ past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
1019
+ newlens = [0]
1020
+ new_rope = [0]
1021
+
1022
+ # add images
1023
+ for image in images:
1024
+ generation_input, newlens, new_rope = self.prepare_vit_images(
1025
+ curr_kvlens=newlens,
1026
+ curr_rope=new_rope,
1027
+ images=[image],
1028
+ transforms=image_transform,
1029
+ new_token_ids=new_token_ids,
1030
+ )
1031
+ for k, v in generation_input.items():
1032
+ if torch.is_tensor(v):
1033
+ generation_input[k] = v.to(device)
1034
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1035
+ past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
1036
+
1037
+ # add text
1038
+ generation_input, newlens, new_rope = self.prepare_prompts(
1039
+ curr_kvlens=newlens,
1040
+ curr_rope=new_rope,
1041
+ prompts=[prompt],
1042
+ tokenizer=tokenizer,
1043
+ new_token_ids=new_token_ids,
1044
+ )
1045
+ for k, v in generation_input.items():
1046
+ if torch.is_tensor(v):
1047
+ generation_input[k] = v.to(device)
1048
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1049
+ past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1050
+
1051
+ # decode
1052
+ generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1053
+ for k, v in generation_input.items():
1054
+ if torch.is_tensor(v):
1055
+ generation_input[k] = v.to(device)
1056
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1057
+ unpacked_latent = self.generate_text(
1058
+ past_key_values=past_key_values,
1059
+ max_length=max_length,
1060
+ do_sample=do_sample,
1061
+ temperature=temperature,
1062
+ end_token_id=new_token_ids['eos_token_id'],
1063
+ **generation_input,
1064
+ )
1065
+ output = tokenizer.decode(unpacked_latent[:,0])
1066
+ output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
1067
+
1068
+ return output
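A hypothetical usage sketch for `chat`; the `model`, `tokenizer`, `new_token_ids`, and `image_transform` objects are assumed to come from the usual inference setup and are not defined here:

    from PIL import Image

    images = [Image.open("test_images/meme.jpg").convert("RGB")]
    answer = model.chat(
        tokenizer=tokenizer,
        new_token_ids=new_token_ids,
        image_transform=image_transform,
        images=images,
        prompt="Describe this image.",
        max_length=256,
        do_sample=False,
    )
    print(answer)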
modeling/bagel/modeling_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) 2022 Facebook, Inc. and its affiliates.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: CC BY-NC 4.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under CC BY-NC 4.0, with the full license text
8
+ # available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import math
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+ from transformers.activations import ACT2FN
18
+
19
+ # --------------------------------------------------------
20
+ # 2D sine-cosine position embedding
21
+ # References:
22
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
23
+ # --------------------------------------------------------
24
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
25
+ grid_h = np.arange(grid_size, dtype=np.float32)
26
+ grid_w = np.arange(grid_size, dtype=np.float32)
27
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
28
+ grid = np.stack(grid, axis=0)
29
+
30
+ grid = grid.reshape([2, 1, grid_size, grid_size])
31
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32
+ if cls_token and extra_tokens > 0:
33
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
34
+ return pos_embed
35
+
36
+
37
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
+ assert embed_dim % 2 == 0
39
+
40
+ # use half of dimensions to encode grid_h
41
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43
+
44
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45
+ return emb
46
+
47
+
48
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49
+ """
50
+ embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (M,)
52
+ out: (M, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
56
+ omega /= embed_dim / 2.
57
+ omega = 1. / 10000**omega # (D/2,)
58
+
59
+ pos = pos.reshape(-1) # (M,)
60
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61
+
62
+ emb_sin = np.sin(out) # (M, D/2)
63
+ emb_cos = np.cos(out) # (M, D/2)
64
+
65
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66
+ return emb
67
+
68
+
69
+ # --------------------------------------------------------
70
+ # TimestepEmbedder
71
+ # Reference:
72
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
73
+ # --------------------------------------------------------
74
+ class TimestepEmbedder(nn.Module):
75
+ """
76
+ Embeds scalar timesteps into vector representations.
77
+ """
78
+ def __init__(self, hidden_size, frequency_embedding_size=256):
79
+ super().__init__()
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
82
+ nn.SiLU(),
83
+ nn.Linear(hidden_size, hidden_size, bias=True),
84
+ )
85
+ self.frequency_embedding_size = frequency_embedding_size
86
+
87
+ @staticmethod
88
+ def timestep_embedding(t, dim, max_period=10000):
89
+ """
90
+ Create sinusoidal timestep embeddings.
91
+ :param t: a 1-D Tensor of N indices, one per batch element.
92
+ These may be fractional.
93
+ :param dim: the dimension of the output.
94
+ :param max_period: controls the minimum frequency of the embeddings.
95
+ :return: an (N, D) Tensor of positional embeddings.
96
+ """
97
+ half = dim // 2
98
+ freqs = torch.exp(
99
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
100
+ ).to(device=t.device)
101
+ args = t[:, None].float() * freqs[None]
102
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
103
+ if dim % 2:
104
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
105
+ return embedding
106
+
107
+ def forward(self, t):
108
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
109
+ t_emb = self.mlp(t_freq)
110
+ return t_emb
111
+
112
+
113
+ class MLPconnector(nn.Module):
114
+ def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
115
+ super().__init__()
116
+ self.activation_fn = ACT2FN[hidden_act]
117
+ self.fc1 = nn.Linear(in_dim, out_dim)
118
+ self.fc2 = nn.Linear(out_dim, out_dim)
119
+
120
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
121
+ hidden_states = self.fc1(hidden_states)
122
+ hidden_states = self.activation_fn(hidden_states)
123
+ hidden_states = self.fc2(hidden_states)
124
+ return hidden_states
125
+
126
+
127
+ class PositionEmbedding(nn.Module):
128
+ def __init__(self, max_num_patch_per_side, hidden_size):
129
+ super().__init__()
130
+ self.max_num_patch_per_side = max_num_patch_per_side
131
+ self.hidden_size = hidden_size
132
+ self.pos_embed = nn.Parameter(
133
+ torch.zeros(max_num_patch_per_side ** 2, hidden_size),
134
+ requires_grad=False
135
+ )
136
+ self._init_weights()
137
+
138
+ def _init_weights(self):
139
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
140
+ pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
141
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
142
+
143
+ def forward(self, position_ids):
144
+ return self.pos_embed[position_ids]
modeling/bagel/qwen2_navit.py ADDED
@@ -0,0 +1,1157 @@
1
+ # Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ from dataclasses import dataclass
14
+ from functools import partial
15
+ from typing import List, Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn.attention import SDPBackend, sdpa_kernel
20
+ from torch.nn.attention.flex_attention import flex_attention
21
+ from torch.nn.functional import scaled_dot_product_attention
22
+ from transformers.utils import ModelOutput
23
+
24
+ from flash_attn import flash_attn_varlen_func
25
+ from modeling.qwen2.modeling_qwen2 import (
26
+ Qwen2Attention,
27
+ Qwen2MLP,
28
+ Qwen2PreTrainedModel,
29
+ Qwen2RMSNorm,
30
+ Qwen2RotaryEmbedding,
31
+ apply_rotary_pos_emb,
32
+ )
33
+
34
+ from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
35
+
36
+
37
+ torch._dynamo.config.cache_size_limit = 512
38
+ torch._dynamo.config.accumulated_cache_size_limit = 4096
39
+ # flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
40
+ flex_attention = torch.compile(flex_attention)
41
+
42
+
43
+ class Qwen2Config(_Qwen2Config):
44
+ r"""
45
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
46
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
47
+ with the defaults will yield a similar configuration to that of
48
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
49
+
50
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
51
+ documentation from [`PretrainedConfig`] for more information.
52
+
53
+ Args:
54
+ vocab_size (`int`, *optional*, defaults to 151936):
55
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
56
+ `inputs_ids` passed when calling [`Qwen2Model`]
57
+ hidden_size (`int`, *optional*, defaults to 4096):
58
+ Dimension of the hidden representations.
59
+ intermediate_size (`int`, *optional*, defaults to 22016):
60
+ Dimension of the MLP representations.
61
+ num_hidden_layers (`int`, *optional*, defaults to 32):
62
+ Number of hidden layers in the Transformer encoder.
63
+ num_attention_heads (`int`, *optional*, defaults to 32):
64
+ Number of attention heads for each attention layer in the Transformer encoder.
65
+ num_key_value_heads (`int`, *optional*, defaults to 32):
66
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
67
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
68
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
69
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
70
+ by meanpooling all the original heads within that group. For more details checkout [this
71
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
73
+ The non-linear activation function (function or string) in the decoder.
74
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
75
+ The maximum sequence length that this model might ever be used with.
76
+ initializer_range (`float`, *optional*, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
79
+ The epsilon used by the rms normalization layers.
80
+ use_cache (`bool`, *optional*, defaults to `True`):
81
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
82
+ relevant if `config.is_decoder=True`.
83
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
84
+ Whether the model's input and output word embeddings should be tied.
85
+ rope_theta (`float`, *optional*, defaults to 10000.0):
86
+ The base period of the RoPE embeddings.
87
+ rope_scaling (`Dict`, *optional*):
88
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
89
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
90
+ accordingly.
91
+ Expected contents:
92
+ `rope_type` (`str`):
93
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
94
+ 'llama3'], with 'default' being the original RoPE implementation.
95
+ `factor` (`float`, *optional*):
96
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
97
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
98
+ original maximum pre-trained length.
99
+ `original_max_position_embeddings` (`int`, *optional*):
100
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
101
+ pretraining.
102
+ `attention_factor` (`float`, *optional*):
103
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
104
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
105
+ `factor` field to infer the suggested value.
106
+ `beta_fast` (`float`, *optional*):
107
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
108
+ ramp function. If unspecified, it defaults to 32.
109
+ `beta_slow` (`float`, *optional*):
110
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
111
+ ramp function. If unspecified, it defaults to 1.
112
+ `short_factor` (`List[float]`, *optional*):
113
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
114
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
115
+ size divided by the number of attention heads divided by 2
116
+ `long_factor` (`List[float]`, *optional*):
117
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
118
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
119
+ size divided by the number of attention heads divided by 2
120
+ `low_freq_factor` (`float`, *optional*):
121
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
122
+ `high_freq_factor` (`float`, *optional*):
123
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
124
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
125
+ Whether to use sliding window attention.
126
+ sliding_window (`int`, *optional*, defaults to 4096):
127
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
128
+ max_window_layers (`int`, *optional*, defaults to 28):
129
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
130
+ attention_dropout (`float`, *optional*, defaults to 0.0):
131
+ The dropout ratio for the attention probabilities.
132
+
133
+ ```python
134
+ >>> from transformers import Qwen2Model, Qwen2Config
135
+
136
+ >>> # Initializing a Qwen2 style configuration
137
+ >>> configuration = Qwen2Config()
138
+
139
+ >>> # Initializing a model from the Qwen2-7B style configuration
140
+ >>> model = Qwen2Model(configuration)
141
+
142
+ >>> # Accessing the model configuration
143
+ >>> configuration = model.config
144
+ ```"""
145
+
146
+ model_type = "qwen2"
147
+ keys_to_ignore_at_inference = ["past_key_values"]
148
+
149
+ def __init__(
150
+ self,
151
+ vocab_size=151936,
152
+ hidden_size=4096,
153
+ intermediate_size=22016,
154
+ num_hidden_layers=32,
155
+ num_attention_heads=32,
156
+ num_key_value_heads=32,
157
+ hidden_act="silu",
158
+ max_position_embeddings=32768,
159
+ initializer_range=0.02,
160
+ rms_norm_eps=1e-6,
161
+ use_cache=True,
162
+ tie_word_embeddings=False,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ use_sliding_window=False,
166
+ sliding_window=4096,
167
+ max_window_layers=28,
168
+ attention_dropout=0.0,
169
+ is_causal=True,
170
+ _attn_implementation="flash_attention_2",
171
+ qk_norm=True,
172
+ layer_module="Qwen2DecoderLayer",
173
+ freeze_und=False,
174
+ **kwargs,
175
+ ):
176
+ super().__init__(
177
+ vocab_size=vocab_size,
178
+ hidden_size=hidden_size,
179
+ intermediate_size=intermediate_size,
180
+ num_hidden_layers=num_hidden_layers,
181
+ num_attention_heads=num_attention_heads,
182
+ num_key_value_heads=num_key_value_heads,
183
+ hidden_act=hidden_act,
184
+ max_position_embeddings=max_position_embeddings,
185
+ initializer_range=initializer_range,
186
+ rms_norm_eps=rms_norm_eps,
187
+ use_cache=use_cache,
188
+ tie_word_embeddings=tie_word_embeddings,
189
+ rope_theta=rope_theta,
190
+ rope_scaling=rope_scaling,
191
+ use_sliding_window=use_sliding_window,
192
+ sliding_window=sliding_window,
193
+ max_window_layers=max_window_layers,
194
+ attention_dropout=attention_dropout,
195
+ is_causal=is_causal,
196
+ _attn_implementation=_attn_implementation,
197
+ **kwargs,
198
+ )
199
+ self.qk_norm = qk_norm
200
+ self.layer_module = layer_module
201
+ self.freeze_und = freeze_und
202
+
203
+
204
+ class NaiveCache:
205
+ def __init__(self, num_layers):
206
+ self.key_cache = {k: None for k in range(num_layers)}
207
+ self.value_cache = {k: None for k in range(num_layers)}
208
+
209
+ @property
210
+ def num_layers(self):
211
+ return len(self.key_cache)
212
+
213
+ @property
214
+ def seq_lens(self):
215
+ if self.key_cache[0] is not None:
216
+ return self.key_cache[0].shape[0]
217
+ else:
218
+ return 0
219
+
220
+
221
+ @dataclass
222
+ class BaseNavitOutputWithPast(ModelOutput):
223
+ packed_query_sequence: torch.FloatTensor = None
224
+ past_key_values: Optional[NaiveCache] = None
225
+
226
+
227
+ def pad_sequence(tensor, pad_size):
228
+ H, L, D = tensor.shape
229
+ pad_tensor = tensor.new_zeros((H, pad_size, D))
230
+ return torch.cat([tensor, pad_tensor], dim=1)
231
+
232
+
233
+ class PackedAttention(Qwen2Attention):
234
+ def __init__(self, config, layer_idx: Optional[int] = None):
235
+ super().__init__(config, layer_idx)
236
+ if self.config.qk_norm:
237
+ self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
238
+ self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
239
+ else:
240
+ self.q_norm = nn.Identity()
241
+ self.k_norm = nn.Identity()
242
+
243
+ def forward(self, *args, **kwargs):
244
+ if self.training:
245
+ return self.forward_train(*args, **kwargs)
246
+ else:
247
+ return self.forward_inference(*args, **kwargs)
248
+
249
+ def forward_train(
250
+ self,
251
+ packed_sequence: torch.Tensor,
252
+ sample_lens: List[int],
253
+ attention_mask: List[torch.Tensor],
254
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
255
+ ):
256
+ packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
257
+ packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
258
+ packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
259
+
260
+ packed_query_states = self.q_norm(packed_query_states)
261
+ packed_key_states = self.k_norm(packed_key_states)
262
+
263
+ packed_cos, packed_sin = packed_position_embeddings
264
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
265
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
266
+ )
267
+
268
+ if isinstance(attention_mask, List):
269
+ packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
270
+ packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
271
+ packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
272
+ packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
273
+
274
+ unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
275
+ unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
276
+ unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
277
+ upacked_attn_output = []
278
+ for query_states, key_states, value_states, attention_mask_per_sample in zip(
279
+ unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
280
+ ):
281
+ with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
282
+ attn_output = scaled_dot_product_attention(
283
+ query_states.to(torch.bfloat16).unsqueeze(0),
284
+ key_states.to(torch.bfloat16).unsqueeze(0),
285
+ value_states.to(torch.bfloat16).unsqueeze(0),
286
+ attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
287
+ )
288
+ upacked_attn_output.append(attn_output.squeeze(0))
289
+ packed_attn_output = torch.cat(upacked_attn_output, dim=1)
290
+ else:
291
+ pad_size = sum(sample_lens) - packed_query_states.shape[0]
292
+ packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
293
+ packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
294
+ packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
295
+ packed_attn_output = flex_attention(
296
+ packed_query_states.unsqueeze(0),
297
+ packed_key_states.unsqueeze(0),
298
+ packed_value_states.unsqueeze(0),
299
+ enable_gqa=True,
300
+ block_mask=attention_mask,
301
+ )
302
+ end_index = packed_attn_output.shape[2] - pad_size
303
+ packed_attn_output = packed_attn_output[0, :, :end_index, :]
304
+
305
+ packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
306
+ packed_attn_output = self.o_proj(packed_attn_output)
307
+
308
+ return packed_attn_output
309
+
310
+ def forward_inference(
311
+ self,
312
+ packed_query_sequence: torch.Tensor,
313
+ query_lens: torch.Tensor,
314
+ packed_query_position_embeddings: torch.Tensor,
315
+ packed_query_indexes: torch.Tensor,
316
+ past_key_values: Optional[NaiveCache] = None,
317
+ key_values_lens: Optional[torch.Tensor] = None,
318
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
319
+ update_past_key_values=True,
320
+ is_causal=True,
321
+ ):
322
+ packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
323
+ packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
324
+ packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
325
+
326
+ packed_query_states = self.q_norm(packed_query_states)
327
+ packed_key_states = self.k_norm(packed_key_states)
328
+
329
+ packed_cos, packed_sin = packed_query_position_embeddings
330
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
331
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
332
+ )
333
+
334
+ packed_query_states = packed_query_states.to(torch.bfloat16)
335
+ packed_key_states = packed_key_states.to(torch.bfloat16)
336
+ packed_value_states = packed_value_states.to(torch.bfloat16)
337
+
338
+ if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
339
+ past_key_states = past_key_values.key_cache[self.layer_idx]
340
+ past_value_states = past_key_values.value_cache[self.layer_idx]
341
+
342
+ seqlens = sum(query_lens) + sum(key_values_lens)
343
+ merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
344
+ merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
345
+ merged_key_states[packed_query_indexes] = packed_key_states
346
+ merged_key_states[packed_key_value_indexes] = past_key_states
347
+ merged_value_states[packed_query_indexes] = packed_value_states
348
+ merged_value_states[packed_key_value_indexes] = past_value_states
349
+ key_values_lens = key_values_lens + query_lens
350
+ else:
351
+ merged_key_states = packed_key_states
352
+ merged_value_states = packed_value_states
353
+ key_values_lens = query_lens
354
+
355
+ cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
356
+ cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
357
+
358
+ packed_attn_output = flash_attn_varlen_func(
359
+ q=packed_query_states,
360
+ k=merged_key_states,
361
+ v=merged_value_states,
362
+ cu_seqlens_q=cu_seqlens_q.to(torch.int32),
363
+ cu_seqlens_k=cu_seqlens_k.to(torch.int32),
364
+ max_seqlen_q=max(query_lens).item(),
365
+ max_seqlen_k=max(key_values_lens).item(),
366
+ causal=is_causal,
367
+ )
368
+ packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
369
+ packed_attn_output = self.o_proj(packed_attn_output)
370
+
371
+ if update_past_key_values:
372
+ past_key_values.key_cache[self.layer_idx] = merged_key_states
373
+ past_key_values.value_cache[self.layer_idx] = merged_value_states
374
+
375
+ return packed_attn_output, past_key_values
376
+
377
+
378
+ class PackedAttentionMoT(Qwen2Attention):
379
+ def __init__(self, config, layer_idx: Optional[int] = None):
380
+ super().__init__(config, layer_idx)
381
+ if self.config.qk_norm:
382
+ self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
383
+ self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
384
+ self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
385
+ self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
386
+ else:
387
+ self.q_norm = nn.Identity()
388
+ self.k_norm = nn.Identity()
389
+ self.q_norm_moe_gen = nn.Identity()
390
+ self.k_norm_moe_gen = nn.Identity()
391
+
392
+ self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
393
+ self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
394
+ self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
395
+ self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
396
+
397
+ def forward(self, *args, **kwargs):
398
+ if self.training:
399
+ return self.forward_train(*args, **kwargs)
400
+ else:
401
+ return self.forward_inference(*args, **kwargs)
402
+
403
+ def forward_train(
404
+ self,
405
+ packed_sequence: torch.Tensor,
406
+ sample_lens: List[int],
407
+ attention_mask,
408
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
409
+ packed_und_token_indexes: torch.LongTensor,
410
+ packed_gen_token_indexes: torch.LongTensor,
411
+ ):
412
+ packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
413
+ packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
414
+ packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
415
+
416
+ packed_sequence_und = packed_sequence[packed_und_token_indexes]
417
+ packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
418
+
419
+ packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
420
+ packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)
421
+
422
+ packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
423
+ packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)
424
+
425
+ packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
426
+ packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)
427
+
428
+ packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
429
+ packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
430
+ packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
431
+ if self.config.freeze_und:
432
+ packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()
433
+
434
+ packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
435
+ packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
436
+
437
+ packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
438
+ if self.config.freeze_und:
439
+ packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
440
+ packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])
441
+
442
+ packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
443
+ if self.config.freeze_und:
444
+ packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
445
+ packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])
446
+
447
+ packed_cos, packed_sin = packed_position_embeddings
448
+ packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
449
+ packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
450
+ )
451
+
452
+ if isinstance(attention_mask, List):
453
+ packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
454
+ packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
455
+ packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
456
+ packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
457
+
458
+ unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
459
+ unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
460
+ unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
461
+ upacked_attn_output = []
462
+ for query_states, key_states, value_states, attention_mask_per_sample in zip(
463
+ unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
464
+ ):
465
+ with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
466
+ attn_output = scaled_dot_product_attention(
467
+ query_states.to(torch.bfloat16).unsqueeze(0),
468
+ key_states.to(torch.bfloat16).unsqueeze(0),
469
+ value_states.to(torch.bfloat16).unsqueeze(0),
470
+ attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
471
+ )
472
+ upacked_attn_output.append(attn_output.squeeze(0))
473
+ packed_attn_output = torch.cat(upacked_attn_output, dim=1)
474
+ else:
475
+ pad_size = sum(sample_lens) - packed_query_states.shape[0]
476
+ packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
477
+ packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
478
+ packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
479
+ packed_attn_output = flex_attention(
480
+ packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
481
+ packed_key_states_.unsqueeze(0),
482
+ packed_value_states.unsqueeze(0),
483
+ enable_gqa=True,
484
+ block_mask=attention_mask,
485
+ )
486
+ end_index = packed_attn_output.shape[2] - pad_size
487
+ packed_attn_output = packed_attn_output[0, :, :end_index, :]
488
+
489
+ packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
490
+ packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
491
+ packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
492
+ packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])
493
+
494
+ return packed_attn_output_
495
+
496
+ def forward_inference(
497
+ self,
498
+ packed_query_sequence: torch.Tensor,
499
+ query_lens: torch.Tensor,
500
+ packed_query_position_embeddings: torch.Tensor,
501
+ packed_query_indexes: torch.Tensor,
502
+ past_key_values: Optional[NaiveCache] = None,
503
+ key_values_lens: Optional[torch.Tensor] = None,
504
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
505
+ update_past_key_values=True,
506
+ is_causal=True,
507
+ mode="und",
508
+ packed_vae_token_indexes=None,
509
+ packed_text_indexes=None,
510
+ ):
511
+ if mode == 'und':
512
+ packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
513
+ packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
514
+ packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
515
+ packed_query_states = self.q_norm(packed_query_states)
516
+ packed_key_states = self.k_norm(packed_key_states)
517
+ elif mode == 'gen':
518
+ packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
519
+ packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
520
+ packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
521
+ packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
522
+
523
+ packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
524
+ packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
525
+
526
+ packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
527
+ packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
528
+
529
+ packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
530
+ packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
531
+
532
+ packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
533
+ packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)
534
+
535
+ packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
536
+ packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
537
+ packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
538
+
539
+ packed_query_states = packed_query_states.to(torch.float32)
540
+ packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
541
+ packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])
542
+
543
+ packed_key_states = packed_key_states.to(torch.float32)
544
+ packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
545
+ packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])
546
+
547
+ packed_cos, packed_sin = packed_query_position_embeddings
548
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
549
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
550
+ )
551
+
552
+ packed_query_states = packed_query_states.to(torch.bfloat16)
553
+ packed_key_states = packed_key_states.to(torch.bfloat16)
554
+ packed_value_states = packed_value_states.to(torch.bfloat16)
555
+
556
+ if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
557
+ past_key_states = past_key_values.key_cache[self.layer_idx]
558
+ past_value_states = past_key_values.value_cache[self.layer_idx]
559
+
560
+ seqlens = sum(query_lens) + sum(key_values_lens)
561
+ merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
562
+ merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
563
+ merged_key_states[packed_query_indexes] = packed_key_states
564
+ merged_key_states[packed_key_value_indexes] = past_key_states
565
+ merged_value_states[packed_query_indexes] = packed_value_states
566
+ merged_value_states[packed_key_value_indexes] = past_value_states
567
+ key_values_lens = key_values_lens + query_lens
568
+ else:
569
+ merged_key_states = packed_key_states
570
+ merged_value_states = packed_value_states
571
+ key_values_lens = query_lens
572
+
573
+ cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
574
+ cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
575
+
576
+ packed_attn_output = flash_attn_varlen_func(
577
+ q=packed_query_states,
578
+ k=merged_key_states,
579
+ v=merged_value_states,
580
+ cu_seqlens_q=cu_seqlens_q.to(torch.int32),
581
+ cu_seqlens_k=cu_seqlens_k.to(torch.int32),
582
+ max_seqlen_q=max(query_lens).item(),
583
+ max_seqlen_k=max(key_values_lens).item(),
584
+ causal=is_causal,
585
+ )
586
+ packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
587
+ if mode == 'und':
588
+ packed_attn_output = self.o_proj(packed_attn_output)
589
+ elif mode == 'gen':
590
+ packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
591
+ packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])
592
+
593
+ if update_past_key_values:
594
+ past_key_values.key_cache[self.layer_idx] = merged_key_states
595
+ past_key_values.value_cache[self.layer_idx] = merged_value_states
596
+
597
+ return packed_attn_output, past_key_values
598
+
599
+
600
+ class Qwen2DecoderLayer(nn.Module):
601
+ def __init__(self, config, layer_idx: Optional[int] = None):
602
+ super().__init__()
603
+ self.hidden_size = config.hidden_size
604
+
605
+ self.self_attn = PackedAttention(config, layer_idx)
606
+
607
+ self.mlp = Qwen2MLP(config)
608
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
609
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
610
+
611
+ def forward(self, *args, **kwargs):
612
+ if self.training:
613
+ return self.forward_train(*args, **kwargs)
614
+ else:
615
+ return self.forward_inference(*args, **kwargs)
616
+
617
+ def forward_train(
618
+ self,
619
+ packed_sequence: torch.Tensor,
620
+ sample_lens: List[int],
621
+ attention_mask,
622
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
623
+ ) -> torch.Tensor:
624
+
625
+ residual = packed_sequence
626
+ packed_sequence = self.input_layernorm(packed_sequence)
627
+
628
+ # Self Attention
629
+ packed_sequence = self.self_attn(
630
+ packed_sequence=packed_sequence,
631
+ sample_lens=sample_lens,
632
+ attention_mask=attention_mask,
633
+ packed_position_embeddings=packed_position_embeddings,
634
+ )
635
+ packed_sequence = residual + packed_sequence
636
+
637
+ # Fully Connected
638
+ residual = packed_sequence
639
+ packed_sequence = self.post_attention_layernorm(packed_sequence)
640
+ packed_sequence = self.mlp(packed_sequence)
641
+ packed_sequence = residual + packed_sequence
642
+
643
+ return packed_sequence
644
+
645
+ def forward_inference(
646
+ self,
647
+ packed_query_sequence: torch.Tensor,
648
+ query_lens: torch.Tensor,
649
+ packed_query_position_embeddings: torch.Tensor,
650
+ packed_query_indexes: torch.Tensor,
651
+ past_key_values: Optional[NaiveCache] = None,
652
+ key_values_lens: Optional[torch.Tensor] = None,
653
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
654
+ update_past_key_values=True,
655
+ is_causal=True,
656
+ ) -> BaseNavitOutputWithPast:
657
+
658
+ residual = packed_query_sequence
659
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
660
+
661
+ # Self Attention
662
+ packed_query_sequence, past_key_values = self.self_attn(
663
+ packed_query_sequence=packed_query_sequence,
664
+ query_lens=query_lens,
665
+ packed_query_position_embeddings=packed_query_position_embeddings,
666
+ packed_query_indexes=packed_query_indexes,
667
+ past_key_values=past_key_values,
668
+ key_values_lens=key_values_lens,
669
+ packed_key_value_indexes=packed_key_value_indexes,
670
+ update_past_key_values=update_past_key_values,
671
+ is_causal=is_causal,
672
+ )
673
+ packed_query_sequence = residual + packed_query_sequence
674
+
675
+ # Fully Connected
676
+ residual = packed_query_sequence
677
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
678
+ packed_query_sequence = self.mlp(packed_query_sequence)
679
+ packed_query_sequence = residual + packed_query_sequence
680
+
681
+ return packed_query_sequence, past_key_values
682
+
683
+
684
+ class Qwen2MoTDecoderLayer(nn.Module):
685
+ def __init__(
686
+ self,
687
+ config,
688
+ layer_idx: Optional[int] = None,
689
+ attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
690
+ ):
691
+ super().__init__()
692
+ self.hidden_size = config.hidden_size
693
+ self.freeze_und = config.freeze_und
694
+
695
+ self.self_attn = attn_module(config, layer_idx)
696
+
697
+ self.mlp = Qwen2MLP(config)
698
+ self.mlp_moe_gen = Qwen2MLP(config)
699
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
700
+ self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
701
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
702
+ self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
703
+
704
+ def forward(self, *args, **kwargs):
705
+ if self.training:
706
+ return self.forward_train(*args, **kwargs)
707
+ else:
708
+ return self.forward_inference(*args, **kwargs)
709
+
710
+ def forward_train(
711
+ self,
712
+ packed_sequence: torch.Tensor,
713
+ sample_lens: List[int],
714
+ attention_mask,
715
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
716
+ packed_und_token_indexes: torch.LongTensor,
717
+ packed_gen_token_indexes: torch.LongTensor,
718
+ ) -> torch.Tensor:
719
+
720
+ residual = packed_sequence
721
+ packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
722
+ packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
723
+ packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
724
+
725
+ # Self Attention
726
+ packed_sequence_ = self.self_attn(
727
+ packed_sequence=packed_sequence_,
728
+ sample_lens=sample_lens,
729
+ attention_mask=attention_mask,
730
+ packed_position_embeddings=packed_position_embeddings,
731
+ packed_und_token_indexes=packed_und_token_indexes,
732
+ packed_gen_token_indexes=packed_gen_token_indexes,
733
+ )
734
+ if self.freeze_und:
735
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
736
+ packed_sequence = residual + packed_sequence_
737
+
738
+ # Fully Connected
739
+ residual = packed_sequence
740
+ packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
741
+ packed_sequence_[packed_und_token_indexes] = self.mlp(
742
+ self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
743
+ )
744
+ if self.freeze_und:
745
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
746
+
747
+ packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
748
+ self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
749
+ )
750
+ packed_sequence = residual + packed_sequence_
751
+
752
+ return packed_sequence
753
+
754
+ def forward_inference(
755
+ self,
756
+ packed_query_sequence: torch.Tensor,
757
+ query_lens: torch.Tensor,
758
+ packed_query_position_embeddings: torch.Tensor,
759
+ packed_query_indexes: torch.Tensor,
760
+ past_key_values: Optional[NaiveCache] = None,
761
+ key_values_lens: Optional[torch.Tensor] = None,
762
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
763
+ update_past_key_values=True,
764
+ is_causal=True,
765
+ mode="und",
766
+ packed_vae_token_indexes=None,
767
+ packed_text_indexes=None,
768
+ ) -> BaseNavitOutputWithPast:
769
+
770
+ residual = packed_query_sequence
771
+ if mode == "und":
772
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
773
+ elif mode == "gen":
774
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
775
+ packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
776
+ packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
777
+ packed_query_sequence = packed_query_sequence_
778
+
779
+ # Self Attention
780
+ packed_query_sequence, past_key_values = self.self_attn(
781
+ packed_query_sequence=packed_query_sequence,
782
+ query_lens=query_lens,
783
+ packed_query_position_embeddings=packed_query_position_embeddings,
784
+ packed_query_indexes=packed_query_indexes,
785
+ past_key_values=past_key_values,
786
+ key_values_lens=key_values_lens,
787
+ packed_key_value_indexes=packed_key_value_indexes,
788
+ update_past_key_values=update_past_key_values,
789
+ is_causal=is_causal,
790
+ mode=mode,
791
+ packed_vae_token_indexes=packed_vae_token_indexes,
792
+ packed_text_indexes=packed_text_indexes,
793
+ )
794
+ packed_query_sequence = residual + packed_query_sequence
795
+
796
+ # Fully Connected
797
+ residual = packed_query_sequence
798
+ if mode == "und":
799
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
800
+ packed_query_sequence = self.mlp(packed_query_sequence)
801
+ elif mode == "gen":
802
+ packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
803
+ packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
804
+ packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
805
+ packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)
806
+
807
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
808
+ packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
809
+ packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
810
+ packed_query_sequence = packed_query_sequence_
811
+
812
+ packed_query_sequence = residual + packed_query_sequence
813
+ return packed_query_sequence, past_key_values
814
+
815
+
816
+ class Qwen2MoEDecoderLayer(nn.Module):
817
+ def __init__(self, config, layer_idx: Optional[int] = None):
818
+ super().__init__()
819
+ self.hidden_size = config.hidden_size
820
+
821
+ self.self_attn = PackedAttention(config, layer_idx)
822
+
823
+ self.mlp = Qwen2MLP(config)
824
+ self.mlp_moe_gen = Qwen2MLP(config)
825
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
826
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
827
+
828
+ def forward(self, *args, **kwargs):
829
+ if self.training:
830
+ return self.forward_train(*args, **kwargs)
831
+ else:
832
+ return self.forward_inference(*args, **kwargs)
833
+
834
+ def forward_train(
835
+ self,
836
+ packed_sequence: torch.Tensor,
837
+ sample_lens: List[int],
838
+ attention_mask,
839
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
840
+ packed_und_token_indexes: torch.LongTensor,
841
+ packed_gen_token_indexes: torch.LongTensor,
842
+ ) -> torch.Tensor:
843
+
844
+ residual = packed_sequence
845
+ packed_sequence = self.input_layernorm(packed_sequence)
846
+
847
+ # Self Attention
848
+ packed_sequence = self.self_attn(
849
+ packed_sequence=packed_sequence,
850
+ sample_lens=sample_lens,
851
+ attention_mask=attention_mask,
852
+ packed_position_embeddings=packed_position_embeddings,
853
+ )
854
+ packed_sequence = residual + packed_sequence
855
+
856
+ # Fully Connected
857
+ residual = packed_sequence
858
+ packed_sequence = self.post_attention_layernorm(packed_sequence)
859
+
860
+ packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
861
+ packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
862
+ packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
863
+ packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
864
+ packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
865
+
866
+ packed_sequence = residual + packed_sequence_new
867
+
868
+ return packed_sequence
869
+
870
+ def forward_inference(
871
+ self,
872
+ packed_query_sequence: torch.Tensor,
873
+ query_lens: torch.Tensor,
874
+ packed_query_position_embeddings: torch.Tensor,
875
+ packed_query_indexes: torch.Tensor,
876
+ past_key_values: Optional[NaiveCache] = None,
877
+ key_values_lens: Optional[torch.Tensor] = None,
878
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
879
+ update_past_key_values=True,
880
+ is_causal=True,
881
+ mode="und",
882
+ packed_vae_token_indexes=None,
883
+ packed_text_indexes=None,
884
+ ) -> BaseNavitOutputWithPast:
885
+
886
+ residual = packed_query_sequence
887
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
888
+
889
+ # Self Attention
890
+ packed_query_sequence, past_key_values = self.self_attn(
891
+ packed_query_sequence=packed_query_sequence,
892
+ query_lens=query_lens,
893
+ packed_query_position_embeddings=packed_query_position_embeddings,
894
+ packed_query_indexes=packed_query_indexes,
895
+ past_key_values=past_key_values,
896
+ key_values_lens=key_values_lens,
897
+ packed_key_value_indexes=packed_key_value_indexes,
898
+ update_past_key_values=update_past_key_values,
899
+ is_causal=is_causal,
900
+ )
901
+ packed_query_sequence = residual + packed_query_sequence
902
+
903
+ # Fully Connected
904
+ residual = packed_query_sequence
905
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
906
+ if mode == "und":
907
+ packed_query_sequence = self.mlp(packed_query_sequence)
908
+ elif mode == "gen":
909
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
910
+ packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
911
+ packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
912
+ packed_query_sequence = packed_query_sequence_
913
+ packed_query_sequence = residual + packed_query_sequence
914
+
915
+ return packed_query_sequence, past_key_values
916
+
917
+
918
+ Decoder_layer_dict = {
919
+ "Qwen2DecoderLayer": Qwen2DecoderLayer,
920
+ "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
921
+ "Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
922
+ }
923
+
924
+
925
+ class Qwen2Model(Qwen2PreTrainedModel):
926
+ def __init__(self, config):
927
+ super().__init__(config)
928
+ self.padding_idx = config.pad_token_id
929
+ self.vocab_size = config.vocab_size
930
+ self.use_moe = 'Mo' in config.layer_module
931
+
932
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
933
+ layer_module = Decoder_layer_dict[config.layer_module]
934
+ self.layers = nn.ModuleList(
935
+ [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
936
+ )
937
+
938
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
939
+ if self.use_moe:
940
+ self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
941
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
942
+
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ def forward(self, *args, **kwargs):
947
+ if self.training:
948
+ return self.forward_train(*args, **kwargs)
949
+ else:
950
+ return self.forward_inference(*args, **kwargs)
951
+
952
+ def forward_train(
953
+ self,
954
+ packed_sequence: torch.Tensor,
955
+ sample_lens: List[int],
956
+ attention_mask,
957
+ packed_position_ids: torch.Tensor,
958
+ packed_und_token_indexes: Optional[torch.LongTensor] = None,
959
+ packed_gen_token_indexes: Optional[torch.LongTensor] = None,
960
+ ) -> torch.Tensor:
961
+
962
+ if self.config.freeze_und:
963
+ packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()
964
+
965
+ # create position embeddings to be shared across the decoder layers
966
+ cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
967
+ cos = cos.squeeze(0)
968
+ sin = sin.squeeze(0)
969
+ packed_position_embeddings = (cos, sin)
970
+
971
+ extra_inputs = {}
972
+ if self.use_moe:
973
+ assert packed_und_token_indexes is not None
974
+ if packed_gen_token_indexes is None:
975
+ packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
976
+ extra_inputs.update(
977
+ packed_und_token_indexes=packed_und_token_indexes,
978
+ packed_gen_token_indexes=packed_gen_token_indexes,
979
+ )
980
+
981
+ for decoder_layer in self.layers:
982
+ packed_sequence = decoder_layer(
983
+ packed_sequence=packed_sequence,
984
+ sample_lens=sample_lens,
985
+ attention_mask=attention_mask,
986
+ packed_position_embeddings=packed_position_embeddings,
987
+ **extra_inputs
988
+ )
989
+
990
+ if self.use_moe:
991
+ packed_sequence_ = torch.zeros_like(packed_sequence)
992
+ packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
993
+ if self.config.freeze_und:
994
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
995
+ packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
996
+ return packed_sequence_
997
+ else:
998
+ return self.norm(packed_sequence)
999
+
1000
+ def forward_inference(
1001
+ self,
1002
+ packed_query_sequence: torch.Tensor,
1003
+ query_lens: torch.Tensor,
1004
+ packed_query_position_ids: torch.Tensor,
1005
+ packed_query_indexes: torch.Tensor,
1006
+ past_key_values: Optional[NaiveCache] = None,
1007
+ key_values_lens: Optional[torch.Tensor] = None,
1008
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
1009
+ update_past_key_values=True,
1010
+ is_causal=True,
1011
+ mode="und",
1012
+ packed_vae_token_indexes=None,
1013
+ packed_text_indexes=None,
1014
+ ) -> BaseNavitOutputWithPast:
1015
+
1016
+ # create position embeddings to be shared across the decoder layers
1017
+ cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))
1018
+ cos = cos.squeeze(0)
1019
+ sin = sin.squeeze(0)
1020
+ packed_query_position_embeddings = (cos, sin)
1021
+
1022
+ extra_inputs = {}
1023
+ if self.use_moe:
1024
+ extra_inputs.update(mode=mode)
1025
+ if mode == 'gen':
1026
+ assert packed_vae_token_indexes is not None
1027
+ assert packed_text_indexes is not None
1028
+ extra_inputs.update(
1029
+ packed_vae_token_indexes=packed_vae_token_indexes,
1030
+ packed_text_indexes=packed_text_indexes,
1031
+ )
1032
+
1033
+ for decoder_layer in self.layers:
1034
+ packed_query_sequence, past_key_values = decoder_layer(
1035
+ packed_query_sequence=packed_query_sequence,
1036
+ query_lens=query_lens,
1037
+ packed_query_position_embeddings=packed_query_position_embeddings,
1038
+ packed_query_indexes=packed_query_indexes,
1039
+ past_key_values=past_key_values,
1040
+ key_values_lens=key_values_lens,
1041
+ packed_key_value_indexes=packed_key_value_indexes,
1042
+ update_past_key_values=update_past_key_values,
1043
+ is_causal=is_causal,
1044
+ **extra_inputs,
1045
+ )
1046
+
1047
+ if self.use_moe:
1048
+ if mode == "und":
1049
+ packed_query_sequence = self.norm(packed_query_sequence)
1050
+ elif mode == "gen":
1051
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
1052
+ packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
1053
+ packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
1054
+ packed_query_sequence = packed_query_sequence_
1055
+ else:
1056
+ packed_query_sequence = self.norm(packed_query_sequence)
1057
+
1058
+ return BaseNavitOutputWithPast(
1059
+ packed_query_sequence=packed_query_sequence,
1060
+ past_key_values=past_key_values,
1061
+ )
1062
+
1063
+
1064
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1065
+ _tied_weights_keys = ["lm_head.weight"]
1066
+
1067
+ def __init__(self, config):
1068
+ super().__init__(config)
1069
+ self.model = Qwen2Model(config)
1070
+ self.vocab_size = config.vocab_size
1071
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1072
+
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def init_moe(self):
1077
+ for name, param in self.named_parameters():
1078
+ if "moe_gen" in name:
1079
+ original_name = name.replace("_moe_gen", "")
1080
+ param.data.copy_(self.state_dict()[original_name].data)
1081
+
1082
+ def get_input_embeddings(self):
1083
+ return self.model.embed_tokens
1084
+
1085
+ def set_input_embeddings(self, value):
1086
+ self.model.embed_tokens = value
1087
+
1088
+ def get_output_embeddings(self):
1089
+ return self.lm_head
1090
+
1091
+ def set_output_embeddings(self, new_embeddings):
1092
+ self.lm_head = new_embeddings
1093
+
1094
+ def set_decoder(self, decoder):
1095
+ self.model = decoder
1096
+
1097
+ def get_decoder(self):
1098
+ return self.model
1099
+
1100
+ def forward(self, *args, **kwargs):
1101
+ if self.training:
1102
+ return self.forward_train(*args, **kwargs)
1103
+ else:
1104
+ return self.forward_inference(*args, **kwargs)
1105
+
1106
+ def forward_train(
1107
+ self,
1108
+ packed_sequence: torch.Tensor,
1109
+ sample_lens: List[int],
1110
+ attention_mask,
1111
+ packed_position_ids: torch.Tensor,
1112
+ packed_und_token_indexes: Optional[torch.LongTensor] = None,
1113
+ packed_gen_token_indexes: Optional[torch.LongTensor] = None,
1114
+ ) -> torch.Tensor:
1115
+
1116
+ outputs = self.model(
1117
+ packed_sequence=packed_sequence,
1118
+ sample_lens=sample_lens,
1119
+ packed_position_ids=packed_position_ids,
1120
+ attention_mask=attention_mask,
1121
+ packed_und_token_indexes=packed_und_token_indexes,
1122
+ packed_gen_token_indexes=packed_gen_token_indexes,
1123
+ )
1124
+ return outputs
1125
+
1126
+ def forward_inference(
1127
+ self,
1128
+ packed_query_sequence: torch.Tensor,
1129
+ query_lens: torch.Tensor,
1130
+ packed_query_position_ids: torch.Tensor,
1131
+ packed_query_indexes: torch.Tensor,
1132
+ past_key_values: Optional[NaiveCache] = None,
1133
+ key_values_lens: Optional[torch.Tensor] = None,
1134
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
1135
+ update_past_key_values=True,
1136
+ is_causal=True,
1137
+ mode="und",
1138
+ packed_vae_token_indexes=None,
1139
+ packed_text_indexes=None,
1140
+ ) -> BaseNavitOutputWithPast:
1141
+
1142
+ outputs = self.model(
1143
+ packed_query_sequence=packed_query_sequence,
1144
+ query_lens=query_lens,
1145
+ packed_query_position_ids=packed_query_position_ids,
1146
+ packed_query_indexes=packed_query_indexes,
1147
+ past_key_values=past_key_values,
1148
+ key_values_lens=key_values_lens,
1149
+ packed_key_value_indexes=packed_key_value_indexes,
1150
+ update_past_key_values=update_past_key_values,
1151
+ is_causal=is_causal,
1152
+ mode=mode,
1153
+ packed_vae_token_indexes=packed_vae_token_indexes,
1154
+ packed_text_indexes=packed_text_indexes,
1155
+ )
1156
+
1157
+ return outputs
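Note: the MoT/MoE decoder layers above route tokens to experts purely by position in the packed sequence: understanding (text/ViT) tokens go through `input_layernorm`/`mlp`, generation (VAE) tokens through the `_moe_gen` copies, and both results are scattered back into their original slots before the residual add. Below is a minimal, self-contained sketch of that scatter-by-index pattern; all names and sizes are illustrative, not taken from the checkpoint.

    import torch
    import torch.nn as nn

    hidden = 8
    packed_sequence = torch.randn(6, hidden)             # six packed tokens from several samples
    packed_und_token_indexes = torch.tensor([0, 1, 2])   # e.g. text / ViT tokens
    packed_gen_token_indexes = torch.tensor([3, 4, 5])   # e.g. VAE latent tokens

    expert_und = nn.Linear(hidden, hidden)               # stands in for self.mlp
    expert_gen = nn.Linear(hidden, hidden)               # stands in for self.mlp_moe_gen

    routed = packed_sequence.new_zeros(packed_sequence.shape)
    routed[packed_und_token_indexes] = expert_und(packed_sequence[packed_und_token_indexes])
    routed[packed_gen_token_indexes] = expert_gen(packed_sequence[packed_gen_token_indexes])
    out = packed_sequence + routed                       # residual add, as in forward_train

The same indexing trick is what lets `freeze_und` detach only the understanding branch while gradients keep flowing through the generation experts.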
modeling/bagel/siglip_navit.py ADDED
@@ -0,0 +1,402 @@
1
+ # Copyright (c) 2024 The HuggingFace Inc. team.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from transformers.activations import ACT2FN
16
+ from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
17
+ from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
18
+ from flash_attn import flash_attn_varlen_func
19
+
20
+
21
+ class SiglipVisionConfig(_SiglipVisionConfig):
22
+ r"""
23
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
24
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
25
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
26
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ hidden_size (`int`, *optional*, defaults to 768):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ intermediate_size (`int`, *optional*, defaults to 3072):
35
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
36
+ num_hidden_layers (`int`, *optional*, defaults to 12):
37
+ Number of hidden layers in the Transformer encoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 12):
39
+ Number of attention heads for each attention layer in the Transformer encoder.
40
+ num_channels (`int`, *optional*, defaults to 3):
41
+ Number of channels in the input images.
42
+ image_size (`int`, *optional*, defaults to 224):
43
+ The size (resolution) of each image.
44
+ patch_size (`int`, *optional*, defaults to 16):
45
+ The size (resolution) of each patch.
46
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
49
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
50
+ The epsilon used by the layer normalization layers.
51
+ attention_dropout (`float`, *optional*, defaults to 0.0):
52
+ The dropout ratio for the attention probabilities.
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
58
+
59
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
60
+ >>> configuration = SiglipVisionConfig()
61
+
62
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
63
+ >>> model = SiglipVisionModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+
69
+ model_type = "siglip_vision_model"
70
+
71
+ def __init__(
72
+ self,
73
+ hidden_size=768,
74
+ intermediate_size=3072,
75
+ num_hidden_layers=12,
76
+ num_attention_heads=12,
77
+ num_channels=3,
78
+ image_size=224,
79
+ patch_size=16,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ rope=True,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(
87
+ hidden_size=hidden_size,
88
+ intermediate_size=intermediate_size,
89
+ num_hidden_layers=num_hidden_layers,
90
+ num_attention_heads=num_attention_heads,
91
+ num_channels=num_channels,
92
+ image_size=image_size,
93
+ patch_size=patch_size,
94
+ hidden_act=hidden_act,
95
+ layer_norm_eps=layer_norm_eps,
96
+ attention_dropout=attention_dropout,
97
+ **kwargs)
98
+
99
+ self.rope = rope
100
+
101
+
102
+ class RotaryEmbedding2D(torch.nn.Module):
103
+ def __init__(self, dim, max_h, max_w, base=10000):
104
+ super().__init__()
105
+ freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
106
+ inv_freq = 1.0 / (base ** freq)
107
+
108
+ grid_h = torch.arange(0, max_h)
109
+ grid_h = grid_h.to(inv_freq.dtype)
110
+ grid_h = grid_h[:, None].repeat(1, max_w)
111
+
112
+ grid_w = torch.arange(0, max_w)
113
+ grid_w = grid_w.to(inv_freq.dtype)
114
+ grid_w = grid_w[None, :].repeat(max_h, 1)
115
+
116
+ cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
117
+ cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
118
+
119
+ self.register_buffer("cos_h", cos_h)
120
+ self.register_buffer("sin_h", sin_h)
121
+ self.register_buffer("cos_w", cos_w)
122
+ self.register_buffer("sin_w", sin_w)
123
+
124
+ def _forward_one_side(self, grid, inv_freq):
125
+ freqs = grid[..., None] * inv_freq[None, None, :]
126
+ emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
127
+ return emb.cos(), emb.sin()
128
+
129
+
130
+ def rotate_half(x):
131
+ x1 = x[..., : x.shape[-1] // 2]
132
+ x2 = x[..., x.shape[-1] // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+
136
+ def apply_rotary_pos_emb(q, k, cos, sin):
137
+ # unsqueeze due to the head dimension
138
+ cos = cos.unsqueeze(1)
139
+ sin = sin.unsqueeze(1)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
145
+ class SiglipVisionEmbeddings(nn.Module):
146
+ def __init__(self, config: SiglipVisionConfig):
147
+ super().__init__()
148
+ self.config = config
149
+ self.embed_dim = config.hidden_size
150
+ self.image_size = config.image_size
151
+ self.patch_size = config.patch_size
152
+
153
+ self.patch_embedding = nn.Conv2d(
154
+ in_channels=config.num_channels,
155
+ out_channels=self.embed_dim,
156
+ kernel_size=self.patch_size,
157
+ stride=self.patch_size,
158
+ padding="valid",
159
+ )
160
+
161
+ self.num_patches_per_side = self.image_size // self.patch_size
162
+ self.num_patches = self.num_patches_per_side**2
163
+ self.num_positions = self.num_patches
164
+ if not config.rope:
165
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
166
+
167
+ def convert_conv2d_to_linear(self, config, meta=False):
168
+ if meta:
169
+ linear_patch_embedding = nn.Linear(
170
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
171
+ )
172
+ else:
173
+ linear_patch_embedding = nn.Linear(
174
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
175
+ )
176
+ W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
177
+ self.embed_dim, config.num_channels * self.patch_size ** 2
178
+ )
179
+ linear_patch_embedding.weight.data = W
180
+ linear_patch_embedding.bias.data = self.patch_embedding.bias.data
181
+ del self.patch_embedding
182
+ self.patch_embedding = linear_patch_embedding
183
+
184
+ def forward(
185
+ self,
186
+ packed_pixel_values: torch.FloatTensor,
187
+ packed_flattened_position_ids: torch.LongTensor
188
+ ) -> torch.Tensor:
189
+
190
+ patch_embeds = self.patch_embedding(packed_pixel_values)
191
+ if not self.config.rope:
192
+ embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
193
+ else:
194
+ embeddings = patch_embeds
195
+ return embeddings
196
+
197
+
198
+ class SiglipFlashAttention2(SiglipAttention):
199
+ def __init__(self, *args, **kwargs):
200
+ super().__init__(*args, **kwargs)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ cu_seqlens: torch.IntTensor,
206
+ max_seqlen: int,
207
+ cos_h: torch.Tensor = None,
208
+ sin_h: torch.Tensor = None,
209
+ cos_w: torch.Tensor = None,
210
+ sin_w: torch.Tensor = None,
211
+ **kwargs,
212
+ ) -> torch.Tensor:
213
+
214
+ total_q_len, _ = hidden_states.size()
215
+
216
+ query_states = self.q_proj(hidden_states)
217
+ key_states = self.k_proj(hidden_states)
218
+ value_states = self.v_proj(hidden_states)
219
+
220
+ query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
221
+ key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
222
+ value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
223
+
224
+ if self.config.rope:
225
+ qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
226
+ kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
227
+ qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
228
+ qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
229
+ query_states = torch.cat([qh, qw], dim=-1)
230
+ key_states = torch.cat([kh, kw], dim=-1)
231
+
232
+ attn_output = flash_attn_varlen_func(
233
+ query_states.to(torch.bfloat16),
234
+ key_states.to(torch.bfloat16),
235
+ value_states.to(torch.bfloat16),
236
+ cu_seqlens_q=cu_seqlens,
237
+ cu_seqlens_k=cu_seqlens,
238
+ max_seqlen_q=max_seqlen,
239
+ max_seqlen_k=max_seqlen,
240
+ causal=False,
241
+ )
242
+
243
+ attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
244
+ return attn_output
245
+
246
+
247
+ class SiglipMLP(nn.Module):
248
+ def __init__(self, config):
249
+ super().__init__()
250
+ self.config = config
251
+ self.activation_fn = ACT2FN[config.hidden_act]
252
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
253
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
254
+
255
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
256
+ hidden_states = self.fc1(hidden_states)
257
+ hidden_states = self.activation_fn(hidden_states)
258
+ hidden_states = self.fc2(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class SiglipEncoderLayer(nn.Module):
263
+ def __init__(self, config: SiglipVisionConfig):
264
+ super().__init__()
265
+ self.embed_dim = config.hidden_size
266
+ self.self_attn = SiglipFlashAttention2(config)
267
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
268
+ self.mlp = SiglipMLP(config)
269
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states: torch.Tensor,
274
+ cu_seqlens: torch.IntTensor,
275
+ max_seqlen: int,
276
+ cos_h: torch.Tensor = None,
277
+ sin_h: torch.Tensor = None,
278
+ cos_w: torch.Tensor = None,
279
+ sin_w: torch.Tensor = None
280
+ ) -> torch.Tensor:
281
+ residual = hidden_states
282
+
283
+ hidden_states = self.layer_norm1(hidden_states)
284
+ hidden_states = self.self_attn(
285
+ hidden_states=hidden_states,
286
+ cu_seqlens=cu_seqlens,
287
+ max_seqlen=max_seqlen,
288
+ cos_h=cos_h,
289
+ sin_h=sin_h,
290
+ cos_w=cos_w,
291
+ sin_w=sin_w
292
+ )
293
+ hidden_states = residual + hidden_states
294
+
295
+ residual = hidden_states
296
+ hidden_states = self.layer_norm2(hidden_states)
297
+ hidden_states = self.mlp(hidden_states)
298
+ hidden_states = residual + hidden_states
299
+
300
+ return hidden_states
301
+
302
+
303
+ class SiglipEncoder(nn.Module):
304
+ def __init__(self, config: SiglipVisionConfig):
305
+ super().__init__()
306
+ self.config = config
307
+ self.layers = nn.ModuleList(
308
+ [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
309
+ )
310
+
311
+ def forward(
312
+ self,
313
+ inputs_embeds: torch.Tensor,
314
+ cu_seqlens: torch.IntTensor,
315
+ max_seqlen: int,
316
+ cos_h: torch.Tensor = None,
317
+ sin_h: torch.Tensor = None,
318
+ cos_w: torch.Tensor = None,
319
+ sin_w: torch.Tensor = None,
320
+ ) -> torch.Tensor:
321
+
322
+ hidden_states = inputs_embeds
323
+ for encoder_layer in self.layers:
324
+ hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
325
+ cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
326
+
327
+ return hidden_states
328
+
329
+
330
+ class SiglipVisionTransformer(nn.Module):
331
+ def __init__(self, config: SiglipVisionConfig):
332
+ super().__init__()
333
+ self.config = config
334
+ embed_dim = config.hidden_size
335
+
336
+ self.embeddings = SiglipVisionEmbeddings(config)
337
+ if config.rope:
338
+ max_size = config.image_size // config.patch_size
339
+ dim_head = config.hidden_size // config.num_attention_heads
340
+ self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
341
+
342
+ self.encoder = SiglipEncoder(config)
343
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
344
+
345
+ def forward(
346
+ self,
347
+ packed_pixel_values: torch.Tensor,
348
+ packed_flattened_position_ids: torch.LongTensor,
349
+ cu_seqlens: torch.IntTensor,
350
+ max_seqlen: int,
351
+ ) -> torch.Tensor:
352
+ hidden_states = self.embeddings(
353
+ packed_pixel_values=packed_pixel_values,
354
+ packed_flattened_position_ids=packed_flattened_position_ids
355
+ )
356
+
357
+ extra_inputs = {}
358
+ if self.config.rope:
359
+ extra_inputs.update(
360
+ cos_h = self.rope.cos_h[packed_flattened_position_ids],
361
+ sin_h = self.rope.sin_h[packed_flattened_position_ids],
362
+ cos_w = self.rope.cos_w[packed_flattened_position_ids],
363
+ sin_w = self.rope.sin_w[packed_flattened_position_ids]
364
+ )
365
+
366
+ last_hidden_state = self.encoder(
367
+ inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
368
+ **extra_inputs
369
+ )
370
+ last_hidden_state = self.post_layernorm(last_hidden_state)
371
+ return last_hidden_state
372
+
373
+
374
+ class SiglipVisionModel(SiglipPreTrainedModel):
375
+ config_class = SiglipVisionConfig
376
+ main_input_name = "packed_pixel_values"
377
+
378
+ def __init__(self, config: SiglipVisionConfig):
379
+ super().__init__(config)
380
+
381
+ self.vision_model = SiglipVisionTransformer(config)
382
+
383
+ # Initialize weights and apply final processing
384
+ self.post_init()
385
+
386
+ def get_input_embeddings(self) -> nn.Module:
387
+ return self.vision_model.embeddings.patch_embedding
388
+
389
+ def forward(
390
+ self,
391
+ packed_pixel_values: torch.Tensor,
392
+ packed_flattened_position_ids: torch.LongTensor,
393
+ cu_seqlens: torch.IntTensor,
394
+ max_seqlen: int,
395
+ ) -> torch.Tensor:
396
+
397
+ return self.vision_model(
398
+ packed_pixel_values=packed_pixel_values,
399
+ packed_flattened_position_ids=packed_flattened_position_ids,
400
+ cu_seqlens=cu_seqlens,
401
+ max_seqlen=max_seqlen,
402
+ )
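Note: `RotaryEmbedding2D` above factorizes each patch position into a row phase and a column phase; in `SiglipFlashAttention2` the head dimension is split in half, the first half rotated with the height tables and the second with the width tables. A small sketch (with illustrative sizes) of how the flattened tables are built and then indexed by `packed_flattened_position_ids`:

    import torch

    max_h = max_w = 4                          # num_patches_per_side in this sketch
    dim = 4                                    # corresponds to head_dim // 2 in the model
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

    grid_h = torch.arange(max_h).float()[:, None].repeat(1, max_w)   # row index of each patch
    grid_w = torch.arange(max_w).float()[None, :].repeat(max_h, 1)   # column index of each patch

    emb_h = torch.cat([grid_h[..., None] * inv_freq] * 2, dim=-1).flatten(0, 1)   # (max_h*max_w, dim)
    emb_w = torch.cat([grid_w[..., None] * inv_freq] * 2, dim=-1).flatten(0, 1)

    pos_id = 2 * max_w + 3                     # flattened id of the patch at (row=2, col=3)
    cos_h, sin_h = emb_h[pos_id].cos(), emb_h[pos_id].sin()   # varies with the row only
    cos_w, sin_w = emb_w[pos_id].cos(), emb_w[pos_id].sin()   # varies with the column only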
modeling/qwen2/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_tokenizers_available,
10
+ is_torch_available,
11
+ )
12
+
13
+
14
+ _import_structure = {
15
+ "configuration_qwen2": ["Qwen2Config"],
16
+ "tokenization_qwen2": ["Qwen2Tokenizer"],
17
+ }
18
+
19
+ try:
20
+ if not is_tokenizers_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ pass
24
+ else:
25
+ _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
26
+
27
+ try:
28
+ if not is_torch_available():
29
+ raise OptionalDependencyNotAvailable()
30
+ except OptionalDependencyNotAvailable:
31
+ pass
32
+ else:
33
+ _import_structure["modeling_qwen2"] = [
34
+ "Qwen2ForCausalLM",
35
+ "Qwen2Model",
36
+ "Qwen2PreTrainedModel",
37
+ ]
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from .configuration_qwen2 import Qwen2Config
42
+ from .tokenization_qwen2 import Qwen2Tokenizer
43
+
44
+ try:
45
+ if not is_tokenizers_available():
46
+ raise OptionalDependencyNotAvailable()
47
+ except OptionalDependencyNotAvailable:
48
+ pass
49
+ else:
50
+ from .tokenization_qwen2_fast import Qwen2TokenizerFast
51
+
52
+ try:
53
+ if not is_torch_available():
54
+ raise OptionalDependencyNotAvailable()
55
+ except OptionalDependencyNotAvailable:
56
+ pass
57
+ else:
58
+ from .modeling_qwen2 import (
59
+ Qwen2ForCausalLM,
60
+ Qwen2Model,
61
+ Qwen2PreTrainedModel,
62
+ )
63
+
64
+
65
+ else:
66
+ import sys
67
+
68
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
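Note: with the `_LazyModule` above, importing from the package only materializes the submodules that are actually requested, and the guarded symbols are exposed only when their optional dependencies are installed. A minimal usage sketch, assuming this repository layout:

    from modeling.qwen2 import Qwen2Config, Qwen2Tokenizer   # always available
    from modeling.qwen2 import Qwen2ForCausalLM              # only if torch is installed
    from modeling.qwen2 import Qwen2TokenizerFast            # only if tokenizers is installed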
modeling/qwen2/configuration_qwen2.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Qwen2 model configuration"""
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_rope_utils import rope_config_validation
8
+ from transformers.utils import logging
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class Qwen2Config(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
17
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
18
+ with the defaults will yield a similar configuration to that of
19
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 151936):
27
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`Qwen2Model`]
29
+ hidden_size (`int`, *optional*, defaults to 4096):
30
+ Dimension of the hidden representations.
31
+ intermediate_size (`int`, *optional*, defaults to 22016):
32
+ Dimension of the MLP representations.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer encoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ num_key_value_heads (`int`, *optional*, defaults to 32):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details checkout [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
45
+ The non-linear activation function (function or string) in the decoder.
46
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
47
+ The maximum sequence length that this model might ever be used with.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
+ The epsilon used by the rms normalization layers.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
54
+ relevant if `config.is_decoder=True`.
55
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
+ Whether the model's input and output word embeddings should be tied.
57
+ rope_theta (`float`, *optional*, defaults to 10000.0):
58
+ The base period of the RoPE embeddings.
59
+ rope_scaling (`Dict`, *optional*):
60
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
62
+ accordingly.
63
+ Expected contents:
64
+ `rope_type` (`str`):
65
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
+ 'llama3'], with 'default' being the original RoPE implementation.
67
+ `factor` (`float`, *optional*):
68
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
+ original maximum pre-trained length.
71
+ `original_max_position_embeddings` (`int`, *optional*):
72
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
+ pretraining.
74
+ `attention_factor` (`float`, *optional*):
75
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
77
+ `factor` field to infer the suggested value.
78
+ `beta_fast` (`float`, *optional*):
79
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
+ ramp function. If unspecified, it defaults to 32.
81
+ `beta_slow` (`float`, *optional*):
82
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
+ ramp function. If unspecified, it defaults to 1.
84
+ `short_factor` (`List[float]`, *optional*):
85
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
+ size divided by the number of attention heads divided by 2
88
+ `long_factor` (`List[float]`, *optional*):
89
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
90
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
+ size divided by the number of attention heads divided by 2
92
+ `low_freq_factor` (`float`, *optional*):
93
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
+ `high_freq_factor` (`float`, *optional*):
95
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
97
+ Whether to use sliding window attention.
98
+ sliding_window (`int`, *optional*, defaults to 4096):
99
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
100
+ max_window_layers (`int`, *optional*, defaults to 28):
101
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
102
+ attention_dropout (`float`, *optional*, defaults to 0.0):
103
+ The dropout ratio for the attention probabilities.
104
+
105
+ ```python
106
+ >>> from transformers import Qwen2Model, Qwen2Config
107
+
108
+ >>> # Initializing a Qwen2 style configuration
109
+ >>> configuration = Qwen2Config()
110
+
111
+ >>> # Initializing a model from the Qwen2-7B style configuration
112
+ >>> model = Qwen2Model(configuration)
113
+
114
+ >>> # Accessing the model configuration
115
+ >>> configuration = model.config
116
+ ```"""
117
+
118
+ model_type = "qwen2"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=151936,
124
+ hidden_size=4096,
125
+ intermediate_size=22016,
126
+ num_hidden_layers=32,
127
+ num_attention_heads=32,
128
+ num_key_value_heads=32,
129
+ hidden_act="silu",
130
+ max_position_embeddings=32768,
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ use_sliding_window=False,
138
+ sliding_window=4096,
139
+ max_window_layers=28,
140
+ attention_dropout=0.0,
141
+ is_causal=True,
142
+ _attn_implementation="flash_attention_2",
143
+ **kwargs,
144
+ ):
145
+ self.vocab_size = vocab_size
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.hidden_size = hidden_size
148
+ self.intermediate_size = intermediate_size
149
+ self.num_hidden_layers = num_hidden_layers
150
+ self.num_attention_heads = num_attention_heads
151
+ self.use_sliding_window = use_sliding_window
152
+ self.sliding_window = sliding_window if use_sliding_window else None
153
+ self.max_window_layers = max_window_layers
154
+
155
+ # for backward compatibility
156
+ if num_key_value_heads is None:
157
+ num_key_value_heads = num_attention_heads
158
+
159
+ self.num_key_value_heads = num_key_value_heads
160
+ self.hidden_act = hidden_act
161
+ self.initializer_range = initializer_range
162
+ self.rms_norm_eps = rms_norm_eps
163
+ self.use_cache = use_cache
164
+ self.rope_theta = rope_theta
165
+ self.rope_scaling = rope_scaling
166
+ self.attention_dropout = attention_dropout
167
+ self.is_causal = is_causal
168
+ self._attn_implementation = _attn_implementation
169
+
170
+ # Validate the correctness of rotary position embeddings parameters
171
+ # BC: if there is a 'type' field, move it to 'rope_type'.
172
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
173
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
174
+ rope_config_validation(self)
175
+
176
+ super().__init__(
177
+ tie_word_embeddings=tie_word_embeddings,
178
+ **kwargs,
179
+ )
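Note: the configuration above keeps the upstream Qwen2 fields and additionally records `is_causal` and an `_attn_implementation` default of `"flash_attention_2"`; legacy `rope_scaling` dictionaries that use the old `"type"` key are rewritten to `"rope_type"` before validation. A construction sketch with illustrative field values:

    from modeling.qwen2.configuration_qwen2 import Qwen2Config

    cfg = Qwen2Config(
        hidden_size=1024,
        intermediate_size=2816,
        num_hidden_layers=4,
        num_attention_heads=16,
        num_key_value_heads=4,                           # GQA: 16 query heads share 4 KV heads
        rope_scaling={"type": "linear", "factor": 2.0},  # legacy "type" key
    )
    print(cfg.rope_scaling["rope_type"])                 # "linear" after the backward-compat rewrite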
modeling/qwen2/modeling_qwen2.py ADDED
@@ -0,0 +1,929 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """PyTorch Qwen2 model."""
5
+
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPast,
18
+ CausalLMOutputWithPast,
19
+ )
20
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.utils import (
23
+ add_start_docstrings,
24
+ add_start_docstrings_to_model_forward,
25
+ is_flash_attn_2_available,
26
+ is_flash_attn_greater_or_equal_2_10,
27
+ logging,
28
+ replace_return_docstrings,
29
+ )
30
+ from .configuration_qwen2 import Qwen2Config
31
+
32
+
33
+ if is_flash_attn_2_available():
34
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
41
+ _CONFIG_FOR_DOC = "Qwen2Config"
42
+
43
+
44
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
45
+ class Qwen2RMSNorm(nn.Module):
46
+ def __init__(self, hidden_size, eps=1e-6):
47
+ """
48
+ Qwen2RMSNorm is equivalent to T5LayerNorm
49
+ """
50
+ super().__init__()
51
+ self.weight = nn.Parameter(torch.ones(hidden_size))
52
+ self.variance_epsilon = eps
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ hidden_states = hidden_states.to(torch.float32)
57
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
+ return self.weight * hidden_states.to(input_dtype)
60
+
61
+ def extra_repr(self):
62
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
+
64
+
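Note: `Qwen2RMSNorm` above normalizes by the root-mean-square of the hidden vector (computed in float32) rather than centering it, i.e. y = weight * x / sqrt(mean(x**2) + eps). A tiny equivalence sketch:

    import torch

    x = torch.randn(2, 5)
    eps, weight = 1e-6, torch.ones(5)                    # learned scale, ones at init
    y = weight * (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)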
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
66
+ class Qwen2RotaryEmbedding(nn.Module):
67
+ def __init__(
68
+ self,
69
+ dim=None,
70
+ max_position_embeddings=2048,
71
+ base=10000,
72
+ device=None,
73
+ scaling_factor=1.0,
74
+ rope_type="default",
75
+ config: Optional[Qwen2Config] = None,
76
+ ):
77
+ super().__init__()
78
+ # TODO (joao): remove the `if` below, only used for BC
79
+ self.rope_kwargs = {}
80
+ if config is None:
81
+ logger.warning_once(
82
+ "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
83
+ "`config` argument. All other arguments will be removed in v4.46"
84
+ )
85
+ self.rope_kwargs = {
86
+ "rope_type": rope_type,
87
+ "factor": scaling_factor,
88
+ "dim": dim,
89
+ "base": base,
90
+ "max_position_embeddings": max_position_embeddings,
91
+ }
92
+ self.rope_type = rope_type
93
+ self.max_seq_len_cached = max_position_embeddings
94
+ self.original_max_seq_len = max_position_embeddings
95
+ else:
96
+ # BC: "rope_type" was originally "type"
97
+ if config.rope_scaling is not None:
98
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
99
+ else:
100
+ self.rope_type = "default"
101
+ self.max_seq_len_cached = config.max_position_embeddings
102
+ self.original_max_seq_len = config.max_position_embeddings
103
+
104
+ self.config = config
105
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
106
+
107
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
108
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
109
+ self.original_inv_freq = self.inv_freq
110
+
111
+ def _dynamic_frequency_update(self, position_ids, device):
112
+ """
113
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
114
+ 1 - growing beyond the cached sequence length (allow scaling)
115
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
116
+ """
117
+ seq_len = torch.max(position_ids) + 1
118
+ if seq_len > self.max_seq_len_cached: # growth
119
+ inv_freq, self.attention_scaling = self.rope_init_fn(
120
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
121
+ )
122
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
123
+ self.max_seq_len_cached = seq_len
124
+
125
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
126
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
127
+ self.max_seq_len_cached = self.original_max_seq_len
128
+
129
+ @torch.no_grad()
130
+ def forward(self, x, position_ids):
131
+ if "dynamic" in self.rope_type:
132
+ self._dynamic_frequency_update(position_ids, device=x.device)
133
+
134
+ # Core RoPE block
135
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
136
+ position_ids_expanded = position_ids[:, None, :].float()
137
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
138
+ device_type = x.device.type
139
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
140
+ with torch.autocast(device_type=device_type, enabled=False):
141
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
+ emb = torch.cat((freqs, freqs), dim=-1)
143
+ cos = emb.cos()
144
+ sin = emb.sin()
145
+
146
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
147
+ cos = cos * self.attention_scaling
148
+ sin = sin * self.attention_scaling
149
+
150
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
+ """Applies Rotary Position Embedding to the query and key tensors.
164
+
165
+ Args:
166
+ q (`torch.Tensor`): The query tensor.
167
+ k (`torch.Tensor`): The key tensor.
168
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
170
+ position_ids (`torch.Tensor`, *optional*):
171
+ Deprecated and unused.
172
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
173
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
+ Returns:
180
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
181
+ """
182
+ cos = cos.unsqueeze(unsqueeze_dim)
183
+ sin = sin.unsqueeze(unsqueeze_dim)
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
189
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
190
+ class Qwen2MLP(nn.Module):
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.hidden_size = config.hidden_size
194
+ self.intermediate_size = config.intermediate_size
195
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
+ self.act_fn = ACT2FN[config.hidden_act]
199
+
200
+ def forward(self, hidden_state):
201
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
202
+
203
+
204
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
+
216
+
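Note: `repeat_kv` above expands the grouped key/value heads so that every query head sees a matching KV head (grouped-query attention) before the eager matmul path. A small shape sketch with illustrative sizes:

    import torch

    batch, kv_heads, seq, head_dim, n_rep = 1, 2, 3, 4, 4           # 2 KV heads shared by 8 query heads
    kv = torch.randn(batch, kv_heads, seq, head_dim)
    expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
    out = expanded.reshape(batch, kv_heads * n_rep, seq, head_dim)  # (1, 8, 3, 4) == repeat_kv(kv, 4).shape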
217
+ class Qwen2Attention(nn.Module):
218
+ """
219
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
+ and "Generating Long Sequences with Sparse Transformers".
221
+ """
222
+
223
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ if layer_idx is None:
228
+ logger.warning_once(
229
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
230
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
231
+ "when creating this class."
232
+ )
233
+
234
+ self.hidden_size = config.hidden_size
235
+ self.num_heads = config.num_attention_heads
236
+ self.head_dim = self.hidden_size // self.num_heads
237
+ self.num_key_value_heads = config.num_key_value_heads
238
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
+ self.max_position_embeddings = config.max_position_embeddings
240
+ self.rope_theta = config.rope_theta
241
+ self.is_causal = config.is_causal
242
+ self.attention_dropout = config.attention_dropout
243
+
244
+ if (self.head_dim * self.num_heads) != self.hidden_size:
245
+ raise ValueError(
246
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
+ f" and `num_heads`: {self.num_heads})."
248
+ )
249
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
+
254
+ def forward(
255
+ self,
256
+ hidden_states: torch.Tensor,
257
+ attention_mask: Optional[torch.Tensor] = None,
258
+ position_ids: Optional[torch.LongTensor] = None,
259
+ past_key_value: Optional[Cache] = None,
260
+ output_attentions: bool = False,
261
+ use_cache: bool = False,
262
+ cache_position: Optional[torch.LongTensor] = None,
263
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
264
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
265
+ bsz, q_len, _ = hidden_states.size()
266
+
267
+ query_states = self.q_proj(hidden_states)
268
+ key_states = self.k_proj(hidden_states)
269
+ value_states = self.v_proj(hidden_states)
270
+
271
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
272
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+
275
+ if position_embeddings is None:
276
+ logger.warning_once(
277
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
278
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
279
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
280
+ "removed and `position_embeddings` will be mandatory."
281
+ )
282
+ cos, sin = self.rotary_emb(value_states, position_ids)
283
+ else:
284
+ cos, sin = position_embeddings
285
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
286
+
287
+ if past_key_value is not None:
288
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
289
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
290
+
291
+ # repeat k/v heads if n_kv_heads < n_heads
292
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
293
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
294
+
295
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
296
+ if attention_mask is not None: # no matter the length, we just slice it
297
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
298
+ attn_weights = attn_weights + causal_mask
299
+
300
+ # upcast attention to fp32
301
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
302
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
303
+ attn_output = torch.matmul(attn_weights, value_states)
304
+
305
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
306
+ raise ValueError(
307
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
308
+ f" {attn_output.size()}"
309
+ )
310
+
311
+ attn_output = attn_output.transpose(1, 2).contiguous()
312
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
313
+
314
+ attn_output = self.o_proj(attn_output)
315
+
316
+ if not output_attentions:
317
+ attn_weights = None
318
+
319
+ return attn_output, attn_weights, past_key_value
320
+
321
+
322
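# NOTE (editor): minimal, self-contained sketch -- not part of the diff above. It
# illustrates the grouped-query-attention bookkeeping used by Qwen2Attention: K/V are
# projected with `num_key_value_heads` heads and repeated `num_key_value_groups` times
# so the eager matmul sees `num_heads` heads. All sizes below are made up.
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
groups = num_heads // num_kv_heads                    # num_key_value_groups == 4

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
v = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# same effect as repeat_kv(k, groups): (bsz, num_kv_heads, L, d) -> (bsz, num_heads, L, d)
k = k.repeat_interleave(groups, dim=1)
v = v.repeat_interleave(groups, dim=1)

scores = q @ k.transpose(2, 3) / (head_dim ** 0.5)    # (bsz, num_heads, q_len, q_len)
attn = torch.softmax(scores, dim=-1) @ v              # (bsz, num_heads, q_len, head_dim)
assert attn.shape == (bsz, num_heads, q_len, head_dim)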
+ class Qwen2FlashAttention2(Qwen2Attention):
323
+ """
324
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
325
+ as the weights of the module stay untouched. The only required change would be on the forward pass
326
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
327
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
328
+ config.max_window_layers layers.
329
+ """
330
+
331
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
332
+ def __init__(self, *args, **kwargs):
333
+ super().__init__(*args, **kwargs)
334
+
335
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
336
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
337
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
338
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
339
+
340
+ def forward(
341
+ self,
342
+ hidden_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_value: Optional[Cache] = None,
346
+ output_attentions: bool = False,
347
+ use_cache: bool = False,
348
+ cache_position: Optional[torch.LongTensor] = None,
349
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
350
+ ):
351
+ bsz, q_len, _ = hidden_states.size()
352
+
353
+ query_states = self.q_proj(hidden_states)
354
+ key_states = self.k_proj(hidden_states)
355
+ value_states = self.v_proj(hidden_states)
356
+
357
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
360
+
361
+ if position_embeddings is None:
362
+ logger.warning_once(
363
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
364
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
365
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
366
+ "removed and `position_embeddings` will be mandatory."
367
+ )
368
+ cos, sin = self.rotary_emb(value_states, position_ids)
369
+ else:
370
+ cos, sin = position_embeddings
371
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
372
+
373
+ if past_key_value is not None:
374
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
375
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
376
+
377
+ # repeat k/v heads if n_kv_heads < n_heads
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
381
+
382
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
383
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
384
+ # cast them back in float16 just to be sure everything works as expected.
385
+ input_dtype = query_states.dtype
386
+ if input_dtype == torch.float32:
387
+ if torch.is_autocast_enabled():
388
+ target_dtype = torch.get_autocast_gpu_dtype()
389
+ # Handle the case where the model is quantized
390
+ elif hasattr(self.config, "_pre_quantization_dtype"):
391
+ target_dtype = self.config._pre_quantization_dtype
392
+ else:
393
+ target_dtype = self.q_proj.weight.dtype
394
+
395
+ logger.warning_once(
396
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
397
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
398
+ f" {target_dtype}."
399
+ )
400
+
401
+ query_states = query_states.to(target_dtype)
402
+ key_states = key_states.to(target_dtype)
403
+ value_states = value_states.to(target_dtype)
404
+
405
+ # Reshape to the expected shape for Flash Attention
406
+ query_states = query_states.transpose(1, 2)
407
+ key_states = key_states.transpose(1, 2)
408
+ value_states = value_states.transpose(1, 2)
409
+
410
+ if (
411
+ self.config.use_sliding_window
412
+ and getattr(self.config, "sliding_window", None) is not None
413
+ and self.layer_idx >= self.config.max_window_layers
414
+ ):
415
+ sliding_window = self.config.sliding_window
416
+ else:
417
+ sliding_window = None
418
+
419
+ attn_output = _flash_attention_forward(
420
+ query_states,
421
+ key_states,
422
+ value_states,
423
+ attention_mask,
424
+ q_len,
425
+ position_ids=position_ids,
426
+ dropout=dropout_rate,
427
+ sliding_window=sliding_window,
428
+ is_causal=self.is_causal,
429
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
430
+ )
431
+
432
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
433
+ attn_output = self.o_proj(attn_output)
434
+
435
+ if not output_attentions:
436
+ attn_weights = None
437
+
438
+ return attn_output, attn_weights, past_key_value
439
+
440
+
441
+ QWEN2_ATTENTION_CLASSES = {
442
+ "eager": Qwen2Attention,
443
+ "flash_attention_2": Qwen2FlashAttention2,
444
+ }
445
+
446
+
447
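# NOTE (editor): illustrative sketch, not part of the diff above. It mirrors (a) how
# `config._attn_implementation` picks an entry from QWEN2_ATTENTION_CLASSES and (b) the
# per-layer gating in Qwen2FlashAttention2.forward, where a sliding window is only
# passed to flash attention for layers with index >= max_window_layers. Values are made up.
class _ToyConfig:
    _attn_implementation = "flash_attention_2"
    use_sliding_window = True
    sliding_window = 4096
    max_window_layers = 28
    num_hidden_layers = 32

cfg = _ToyConfig()
impl_names = {"eager": "Qwen2Attention", "flash_attention_2": "Qwen2FlashAttention2"}
print("attention class:", impl_names[cfg._attn_implementation])

windowed_layers = [
    layer_idx
    for layer_idx in range(cfg.num_hidden_layers)
    # same condition as in Qwen2FlashAttention2.forward
    if cfg.use_sliding_window and cfg.sliding_window is not None and layer_idx >= cfg.max_window_layers
]
print("layers using a sliding window:", windowed_layers)  # [28, 29, 30, 31] with these toy values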
+ class Qwen2DecoderLayer(nn.Module):
448
+ def __init__(self, config: Qwen2Config, layer_idx: int):
449
+ super().__init__()
450
+ self.hidden_size = config.hidden_size
451
+
452
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
453
+ logger.warning_once(
454
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
455
+ "unexpected results may be encountered."
456
+ )
457
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
458
+
459
+ self.mlp = Qwen2MLP(config)
460
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
469
+ output_attentions: Optional[bool] = False,
470
+ use_cache: Optional[bool] = False,
471
+ cache_position: Optional[torch.LongTensor] = None,
472
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
473
+ **kwargs,
474
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
475
+ """
476
+ Args:
477
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
478
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
479
+ `(batch, sequence_length)` where padding elements are indicated by 0.
480
+ output_attentions (`bool`, *optional*):
481
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
482
+ returned tensors for more detail.
483
+ use_cache (`bool`, *optional*):
484
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
485
+ (see `past_key_values`).
486
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
487
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
488
+ Indices depicting the position of the input sequence tokens in the sequence.
489
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
490
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
491
+ with `head_dim` being the embedding dimension of each attention head.
492
+ kwargs (`dict`, *optional*):
493
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
494
+ into the model
495
+ """
496
+
497
+ residual = hidden_states
498
+
499
+ hidden_states = self.input_layernorm(hidden_states)
500
+
501
+ # Self Attention
502
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
503
+ hidden_states=hidden_states,
504
+ attention_mask=attention_mask,
505
+ position_ids=position_ids,
506
+ past_key_value=past_key_value,
507
+ output_attentions=output_attentions,
508
+ use_cache=use_cache,
509
+ cache_position=cache_position,
510
+ position_embeddings=position_embeddings,
511
+ )
512
+ hidden_states = residual + hidden_states
513
+
514
+ # Fully Connected
515
+ residual = hidden_states
516
+ hidden_states = self.post_attention_layernorm(hidden_states)
517
+ hidden_states = self.mlp(hidden_states)
518
+ hidden_states = residual + hidden_states
519
+
520
+ outputs = (hidden_states,)
521
+
522
+ if output_attentions:
523
+ outputs += (self_attn_weights,)
524
+
525
+ if use_cache:
526
+ outputs += (present_key_value,)
527
+
528
+ return outputs
529
+
530
+
531
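# NOTE (editor): schematic sketch of the pre-norm residual ordering implemented by
# Qwen2DecoderLayer.forward above -- not part of the diff. `attn` and `mlp` are toy
# stand-ins for the real sub-modules; only the dataflow is the point here.
import torch

def decoder_block(h, norm1, attn, norm2, mlp):
    h = h + attn(norm1(h))   # residual around (RMSNorm -> self-attention)
    h = h + mlp(norm2(h))    # residual around (RMSNorm -> MLP)
    return h

d = 8
h = torch.randn(2, 5, d)
identity = lambda x: x       # stands in for the RMSNorm layers
out = decoder_block(h, identity, torch.nn.Linear(d, d), identity, torch.nn.Linear(d, d))
assert out.shape == h.shape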
+ QWEN2_START_DOCSTRING = r"""
532
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
533
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
534
+ etc.)
535
+
536
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
537
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
538
+ and behavior.
539
+
540
+ Parameters:
541
+ config ([`Qwen2Config`]):
542
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
543
+ load the weights associated with the model, only the configuration. Check out the
544
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
545
+ """
546
+
547
+
548
+ @add_start_docstrings(
549
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
550
+ QWEN2_START_DOCSTRING,
551
+ )
552
+ class Qwen2PreTrainedModel(PreTrainedModel):
553
+ config_class = Qwen2Config
554
+ base_model_prefix = "model"
555
+ supports_gradient_checkpointing = True
556
+ _no_split_modules = ["Qwen2DecoderLayer"]
557
+ _skip_keys_device_placement = "past_key_values"
558
+ _supports_flash_attn_2 = True
559
+ _supports_cache_class = True
560
+ _supports_quantized_cache = True
561
+ _supports_static_cache = True
562
+
563
+ def _init_weights(self, module):
564
+ std = self.config.initializer_range
565
+ if isinstance(module, nn.Linear):
566
+ module.weight.data.normal_(mean=0.0, std=std)
567
+ if module.bias is not None:
568
+ module.bias.data.zero_()
569
+ elif isinstance(module, nn.Embedding):
570
+ module.weight.data.normal_(mean=0.0, std=std)
571
+ if module.padding_idx is not None:
572
+ module.weight.data[module.padding_idx].zero_()
573
+
574
+
575
+ QWEN2_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
+ it.
580
+
581
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
+ [`PreTrainedTokenizer.__call__`] for details.
583
+
584
+ [What are input IDs?](../glossary#input-ids)
585
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
+
588
+ - 1 for tokens that are **not masked**,
589
+ - 0 for tokens that are **masked**.
590
+
591
+ [What are attention masks?](../glossary#attention-mask)
592
+
593
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
594
+ [`PreTrainedTokenizer.__call__`] for details.
595
+
596
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
597
+ `past_key_values`).
598
+
599
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
600
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
601
+ information on the default strategy.
602
+
603
+ - 1 indicates the head is **not masked**,
604
+ - 0 indicates the head is **masked**.
605
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
607
+ config.n_positions - 1]`.
608
+
609
+ [What are position IDs?](../glossary#position-ids)
610
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
611
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
612
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
613
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
614
+
615
+ Two formats are allowed:
616
+ - a [`~cache_utils.Cache`] instance, see our
617
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
618
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
619
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
620
+ cache format.
621
+
622
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
623
+ legacy cache format will be returned.
624
+
625
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
626
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
627
+ of shape `(batch_size, sequence_length)`.
628
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
629
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
630
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
631
+ model's internal embedding lookup matrix.
632
+ use_cache (`bool`, *optional*):
633
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
634
+ `past_key_values`).
635
+ output_attentions (`bool`, *optional*):
636
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
637
+ tensors for more detail.
638
+ output_hidden_states (`bool`, *optional*):
639
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
640
+ more detail.
641
+ return_dict (`bool`, *optional*):
642
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
643
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
644
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
645
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
646
+ the complete sequence length.
647
+ """
648
+
649
+
650
+ @add_start_docstrings(
651
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
652
+ QWEN2_START_DOCSTRING,
653
+ )
654
+ class Qwen2Model(Qwen2PreTrainedModel):
655
+ """
656
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
657
+
658
+ Args:
659
+ config: Qwen2Config
660
+ """
661
+
662
+ def __init__(self, config: Qwen2Config):
663
+ super().__init__(config)
664
+ self.padding_idx = config.pad_token_id
665
+ self.vocab_size = config.vocab_size
666
+
667
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
668
+ self.layers = nn.ModuleList(
669
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
670
+ )
671
+ self._attn_implementation = config._attn_implementation
672
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
673
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
674
+
675
+ self.gradient_checkpointing = False
676
+ # Initialize weights and apply final processing
677
+ self.post_init()
678
+
679
+ def get_input_embeddings(self):
680
+ return self.embed_tokens
681
+
682
+ def set_input_embeddings(self, value):
683
+ self.embed_tokens = value
684
+
685
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
686
+ def forward(
687
+ self,
688
+ input_ids: torch.LongTensor = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ position_ids: Optional[torch.LongTensor] = None,
691
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
692
+ inputs_embeds: Optional[torch.FloatTensor] = None,
693
+ use_cache: Optional[bool] = None,
694
+ output_attentions: Optional[bool] = None,
695
+ output_hidden_states: Optional[bool] = None,
696
+ return_dict: Optional[bool] = None,
697
+ cache_position: Optional[torch.LongTensor] = None,
698
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
699
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
+ output_hidden_states = (
701
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
+ )
703
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
704
+
705
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
706
+
707
+ if (input_ids is None) ^ (inputs_embeds is not None):
708
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
709
+
710
+ if self.gradient_checkpointing and self.training:
711
+ if use_cache:
712
+ logger.warning_once(
713
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
714
+ )
715
+ use_cache = False
716
+
717
+ # kept for BC (non `Cache` `past_key_values` inputs)
718
+ return_legacy_cache = False
719
+ if use_cache and not isinstance(past_key_values, Cache):
720
+ return_legacy_cache = True
721
+ if past_key_values is None:
722
+ past_key_values = DynamicCache()
723
+ else:
724
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
725
+ logger.warning_once(
726
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
727
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
728
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
729
+ )
730
+
731
+ if inputs_embeds is None:
732
+ inputs_embeds = self.embed_tokens(input_ids)
733
+
734
+ if cache_position is None:
735
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
736
+ cache_position = torch.arange(
737
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
738
+ )
739
+ if position_ids is None:
740
+ position_ids = cache_position.unsqueeze(0)
741
+
742
+ if attention_mask is not None and 0.0 in attention_mask:
743
+ causal_mask = attention_mask
744
+ else:
745
+ causal_mask = None
746
+
747
+ hidden_states = inputs_embeds
748
+ # create position embeddings to be shared across the decoder layers
749
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
750
+
751
+ # decoder layers
752
+ all_hidden_states = () if output_hidden_states else None
753
+ all_self_attns = () if output_attentions else None
754
+ next_decoder_cache = None
755
+
756
+ for decoder_layer in self.layers:
757
+ if output_hidden_states:
758
+ all_hidden_states += (hidden_states,)
759
+
760
+ if self.gradient_checkpointing and self.training:
761
+ layer_outputs = self._gradient_checkpointing_func(
762
+ decoder_layer.__call__,
763
+ hidden_states,
764
+ causal_mask,
765
+ position_ids,
766
+ past_key_values,
767
+ output_attentions,
768
+ use_cache,
769
+ cache_position,
770
+ position_embeddings,
771
+ )
772
+ else:
773
+ layer_outputs = decoder_layer(
774
+ hidden_states,
775
+ attention_mask=causal_mask,
776
+ position_ids=position_ids,
777
+ past_key_value=past_key_values,
778
+ output_attentions=output_attentions,
779
+ use_cache=use_cache,
780
+ cache_position=cache_position,
781
+ position_embeddings=position_embeddings,
782
+ )
783
+
784
+ hidden_states = layer_outputs[0]
785
+
786
+ if use_cache:
787
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
788
+
789
+ if output_attentions:
790
+ all_self_attns += (layer_outputs[1],)
791
+
792
+ hidden_states = self.norm(hidden_states)
793
+
794
+ # add hidden states from the last decoder layer
795
+ if output_hidden_states:
796
+ all_hidden_states += (hidden_states,)
797
+
798
+ next_cache = next_decoder_cache if use_cache else None
799
+ if return_legacy_cache:
800
+ next_cache = next_cache.to_legacy_cache()
801
+
802
+ if not return_dict:
803
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
804
+ return BaseModelOutputWithPast(
805
+ last_hidden_state=hidden_states,
806
+ past_key_values=next_cache,
807
+ hidden_states=all_hidden_states,
808
+ attentions=all_self_attns,
809
+ )
810
+
811
+
812
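# NOTE (editor): small sketch -- not part of the diff -- of how Qwen2Model.forward
# derives `cache_position` and `position_ids` when they are not supplied: positions
# simply continue from however many tokens are already in the KV cache.
import torch

past_seen_tokens = 7          # e.g. past_key_values.get_seq_length()
new_tokens = 3                # inputs_embeds.shape[1] for this decoding step

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)   # broadcast over the batch dimension

assert cache_position.tolist() == [7, 8, 9]
assert position_ids.shape == (1, new_tokens)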
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
813
+ _tied_weights_keys = ["lm_head.weight"]
814
+
815
+ def __init__(self, config):
816
+ super().__init__(config)
817
+ self.model = Qwen2Model(config)
818
+ self.vocab_size = config.vocab_size
819
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+
824
+ def get_input_embeddings(self):
825
+ return self.model.embed_tokens
826
+
827
+ def set_input_embeddings(self, value):
828
+ self.model.embed_tokens = value
829
+
830
+ def get_output_embeddings(self):
831
+ return self.lm_head
832
+
833
+ def set_output_embeddings(self, new_embeddings):
834
+ self.lm_head = new_embeddings
835
+
836
+ def set_decoder(self, decoder):
837
+ self.model = decoder
838
+
839
+ def get_decoder(self):
840
+ return self.model
841
+
842
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
843
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
844
+ def forward(
845
+ self,
846
+ input_ids: torch.LongTensor = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.LongTensor] = None,
849
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
850
+ inputs_embeds: Optional[torch.FloatTensor] = None,
851
+ labels: Optional[torch.LongTensor] = None,
852
+ use_cache: Optional[bool] = None,
853
+ output_attentions: Optional[bool] = None,
854
+ output_hidden_states: Optional[bool] = None,
855
+ return_dict: Optional[bool] = None,
856
+ cache_position: Optional[torch.LongTensor] = None,
857
+ num_logits_to_keep: int = 0,
858
+ **loss_kwargs,
859
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
860
+ r"""
861
+ Args:
862
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
863
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
864
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
865
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
866
+
867
+ num_logits_to_keep (`int`, *optional*):
868
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
869
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
870
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
871
+
872
+ Returns:
873
+
874
+ Example:
875
+
876
+ ```python
877
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
878
+
879
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
880
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
881
+
882
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
883
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
884
+
885
+ >>> # Generate
886
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
887
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
888
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
889
+ ```"""
890
+
891
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
+ output_hidden_states = (
893
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
+ )
895
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
896
+
897
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
898
+ outputs = self.model(
899
+ input_ids=input_ids,
900
+ attention_mask=attention_mask,
901
+ position_ids=position_ids,
902
+ past_key_values=past_key_values,
903
+ inputs_embeds=inputs_embeds,
904
+ use_cache=use_cache,
905
+ output_attentions=output_attentions,
906
+ output_hidden_states=output_hidden_states,
907
+ return_dict=return_dict,
908
+ cache_position=cache_position,
909
+ )
910
+
911
+ hidden_states = outputs[0]
912
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
913
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
918
+
919
+ if not return_dict:
920
+ output = (logits,) + outputs[1:]
921
+ return (loss,) + output if loss is not None else output
922
+
923
+ return CausalLMOutputWithPast(
924
+ loss=loss,
925
+ logits=logits,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ )
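# NOTE (editor): tiny demo -- not part of the diff -- of the `num_logits_to_keep`
# slicing used in Qwen2ForCausalLM.forward: `hidden_states[:, -k:, :]` keeps the last
# k positions, and k == 0 degenerates to keeping everything (the documented special case).
import torch

hidden_states = torch.randn(2, 10, 4)
assert hidden_states[:, -1:, :].shape == (2, 1, 4)    # generation: only the last token
assert hidden_states[:, -0:, :].shape == (2, 10, 4)   # k == 0 -> all positions, since -0 == 0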
modeling/qwen2/tokenization_qwen2.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ import json
7
+ import os
8
+ import unicodedata
9
+ from functools import lru_cache
10
+ from typing import Optional, Tuple
11
+
12
+ import regex as re
13
+
14
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
15
+ from transformers.utils import logging
16
+
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ VOCAB_FILES_NAMES = {
21
+ "vocab_file": "vocab.json",
22
+ "merges_file": "merges.txt",
23
+ }
24
+
25
+
26
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
27
+
28
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
29
+
30
+
31
+ @lru_cache()
32
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
33
+ def bytes_to_unicode():
34
+ """
35
+ Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
36
+ characters the bpe code barfs on.
37
+
38
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
39
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
40
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
41
+ tables between utf-8 bytes and unicode strings.
42
+ """
43
+ bs = (
44
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
45
+ )
46
+ cs = bs[:]
47
+ n = 0
48
+ for b in range(2**8):
49
+ if b not in bs:
50
+ bs.append(b)
51
+ cs.append(2**8 + n)
52
+ n += 1
53
+ cs = [chr(n) for n in cs]
54
+ return dict(zip(bs, cs))
55
+
56
+
57
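# NOTE (editor): quick demonstration -- not part of the diff -- of the byte-to-unicode
# table built by `bytes_to_unicode` above: printable bytes map to themselves, while
# bytes the BPE merges would choke on (space, newline, ...) are shifted to unused code points.
byte_encoder = bytes_to_unicode()
assert byte_encoder[ord("a")] == "a"         # printable ASCII is kept as-is
assert byte_encoder[ord(" ")] == "\u0120"    # space   -> 'Ġ' (chr(256 + 32))
assert byte_encoder[ord("\n")] == "\u010a"   # newline -> 'Ċ' (chr(256 + 10))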
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
58
+ def get_pairs(word):
59
+ """
60
+ Return set of symbol pairs in a word.
61
+
62
+ Word is represented as tuple of symbols (symbols being variable-length strings).
63
+ """
64
+ pairs = set()
65
+ prev_char = word[0]
66
+ for char in word[1:]:
67
+ pairs.add((prev_char, char))
68
+ prev_char = char
69
+ return pairs
70
+
71
+
72
+ class Qwen2Tokenizer(PreTrainedTokenizer):
73
+ """
74
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
75
+
76
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
77
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
78
+
79
+ ```python
80
+ >>> from transformers import Qwen2Tokenizer
81
+
82
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
83
+ >>> tokenizer("Hello world")["input_ids"]
84
+ [9707, 1879]
85
+
86
+ >>> tokenizer(" Hello world")["input_ids"]
87
+ [21927, 1879]
88
+ ```
89
+ This is expected.
90
+
91
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
92
+
93
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
94
+ this superclass for more information regarding those methods.
95
+
96
+ Args:
97
+ vocab_file (`str`):
98
+ Path to the vocabulary file.
99
+ merges_file (`str`):
100
+ Path to the merges file.
101
+ errors (`str`, *optional*, defaults to `"replace"`):
102
+ Paradigm to follow when decoding bytes to UTF-8. See
103
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
104
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
105
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
106
+ token instead.
107
+ bos_token (`str`, *optional*):
108
+ The beginning of sequence token. Not applicable for this tokenizer.
109
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
110
+ The end of sequence token.
111
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
112
+ The token used for padding, for example when batching sequences of different lengths.
113
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
114
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
115
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
116
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
117
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
118
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
119
+ ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
120
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
121
+ """
122
+
123
+ vocab_files_names = VOCAB_FILES_NAMES
124
+ model_input_names = ["input_ids", "attention_mask"]
125
+
126
+ def __init__(
127
+ self,
128
+ vocab_file,
129
+ merges_file,
130
+ errors="replace",
131
+ unk_token="<|endoftext|>",
132
+ bos_token=None,
133
+ eos_token="<|endoftext|>",
134
+ pad_token="<|endoftext|>",
135
+ clean_up_tokenization_spaces=False,
136
+ split_special_tokens=False,
137
+ **kwargs,
138
+ ):
139
+ # Qwen vocab does not contain control tokens; added tokens need to be special
140
+ bos_token = (
141
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
142
+ if isinstance(bos_token, str)
143
+ else bos_token
144
+ )
145
+ eos_token = (
146
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
147
+ if isinstance(eos_token, str)
148
+ else eos_token
149
+ )
150
+ unk_token = (
151
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
152
+ if isinstance(unk_token, str)
153
+ else unk_token
154
+ )
155
+ pad_token = (
156
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
157
+ if isinstance(pad_token, str)
158
+ else pad_token
159
+ )
160
+
161
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
162
+ self.encoder = json.load(vocab_handle)
163
+ self.decoder = {v: k for k, v in self.encoder.items()}
164
+ self.errors = errors # how to handle errors in decoding
165
+ self.byte_encoder = bytes_to_unicode()
166
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
167
+ bpe_merges = []
168
+ with open(merges_file, encoding="utf-8") as merges_handle:
169
+ for i, line in enumerate(merges_handle):
170
+ line = line.strip()
171
+ if (i == 0 and line.startswith("#version:")) or not line:
172
+ continue
173
+ bpe_merges.append(tuple(line.split()))
174
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
175
+ # NOTE: the cache can grow without bound and will get really large for long running processes
176
+ # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
177
+ # not a memory leak but appears as one.
178
+ # GPT2Tokenizer has the same problem, so let's be consistent.
179
+ self.cache = {}
180
+
181
+ self.pat = re.compile(PRETOKENIZE_REGEX)
182
+
183
+ if kwargs.get("add_prefix_space", False):
184
+ logger.warning_once(
185
+ f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
186
+ )
187
+
188
+ super().__init__(
189
+ errors=errors,
190
+ bos_token=bos_token,
191
+ eos_token=eos_token,
192
+ pad_token=pad_token,
193
+ unk_token=unk_token,
194
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
195
+ split_special_tokens=split_special_tokens,
196
+ **kwargs,
197
+ )
198
+
199
+ @property
200
+ def vocab_size(self) -> int:
201
+ return len(self.encoder)
202
+
203
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
204
+ def get_vocab(self):
205
+ return dict(self.encoder, **self.added_tokens_encoder)
206
+
207
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
208
+ def bpe(self, token):
209
+ if token in self.cache:
210
+ return self.cache[token]
211
+ word = tuple(token)
212
+ pairs = get_pairs(word)
213
+
214
+ if not pairs:
215
+ return token
216
+
217
+ while True:
218
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
219
+ if bigram not in self.bpe_ranks:
220
+ break
221
+ first, second = bigram
222
+ new_word = []
223
+ i = 0
224
+ while i < len(word):
225
+ try:
226
+ j = word.index(first, i)
227
+ except ValueError:
228
+ new_word.extend(word[i:])
229
+ break
230
+ else:
231
+ new_word.extend(word[i:j])
232
+ i = j
233
+
234
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
235
+ new_word.append(first + second)
236
+ i += 2
237
+ else:
238
+ new_word.append(word[i])
239
+ i += 1
240
+ new_word = tuple(new_word)
241
+ word = new_word
242
+ if len(word) == 1:
243
+ break
244
+ else:
245
+ pairs = get_pairs(word)
246
+ word = " ".join(word)
247
+ self.cache[token] = word
248
+ return word
249
+
250
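# NOTE (editor): toy walk-through -- not part of the diff -- of the merge loop in `bpe`
# above, using a hypothetical two-rule merge table (lower rank = merged earlier) and the
# `get_pairs` helper defined earlier in this file.
toy_ranks = {("l", "o"): 0, ("lo", "w"): 1}

word = ("l", "o", "w")
pairs = get_pairs(word)                               # {("l", "o"), ("o", "w")}
best = min(pairs, key=lambda p: toy_ranks.get(p, float("inf")))
assert best == ("l", "o")                             # rank 0 wins over the unranked pair
# Applying the two merges in rank order turns ("l", "o", "w") -> ("lo", "w") -> ("low",),
# which is what `bpe("low")` would return (as the string "low") if `self.bpe_ranks`
# were this toy table.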
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
251
+ def _tokenize(self, text):
252
+ """Tokenize a string."""
253
+ bpe_tokens = []
254
+ for token in re.findall(self.pat, text):
255
+ token = "".join(
256
+ self.byte_encoder[b] for b in token.encode("utf-8")
257
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
258
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
259
+ return bpe_tokens
260
+
261
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
262
+ def _convert_token_to_id(self, token):
263
+ """Converts a token (str) in an id using the vocab."""
264
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
265
+
266
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
267
+ def _convert_id_to_token(self, index):
268
+ """Converts an index (integer) in a token (str) using the vocab."""
269
+ return self.decoder.get(index)
270
+
271
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
272
+ def convert_tokens_to_string(self, tokens):
273
+ """Converts a sequence of tokens (string) in a single string."""
274
+ text = "".join(tokens)
275
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
276
+ return text
277
+
278
+ def decode(
279
+ self,
280
+ token_ids,
281
+ skip_special_tokens: bool = False,
282
+ clean_up_tokenization_spaces: Optional[bool] = False,
283
+ spaces_between_special_tokens: bool = False,
284
+ **kwargs,
285
+ ) -> str:
286
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
287
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
288
+ return super().decode(
289
+ token_ids,
290
+ skip_special_tokens=skip_special_tokens,
291
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
292
+ spaces_between_special_tokens=spaces_between_special_tokens,
293
+ **kwargs,
294
+ )
295
+
296
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
297
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
298
+ if not os.path.isdir(save_directory):
299
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
300
+ return
301
+ vocab_file = os.path.join(
302
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
303
+ )
304
+ merge_file = os.path.join(
305
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
306
+ )
307
+
308
+ with open(vocab_file, "w", encoding="utf-8") as f:
309
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
310
+
311
+ index = 0
312
+ with open(merge_file, "w", encoding="utf-8") as writer:
313
+ writer.write("#version: 0.2\n")
314
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
315
+ if index != token_index:
316
+ logger.warning(
317
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
318
+ " Please check that the tokenizer is not corrupted!"
319
+ )
320
+ index = token_index
321
+ writer.write(" ".join(bpe_tokens) + "\n")
322
+ index += 1
323
+
324
+ return vocab_file, merge_file
325
+
326
+ def prepare_for_tokenization(self, text, **kwargs):
327
+ text = unicodedata.normalize("NFC", text)
328
+ return (text, kwargs)
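# NOTE (editor): small illustration -- not part of the diff -- of why
# prepare_for_tokenization applies NFC: visually identical strings with different
# code-point sequences are folded to one canonical form before BPE runs.
import unicodedata

decomposed = "e\u0301"                      # 'e' + combining acute accent (2 code points)
composed = unicodedata.normalize("NFC", decomposed)
assert composed == "\u00e9" == "é"          # single precomposed code point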
modeling/qwen2/tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken
9
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
10
+ from transformers.utils import logging
11
+ from .tokenization_qwen2 import Qwen2Tokenizer
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+ VOCAB_FILES_NAMES = {
17
+ "vocab_file": "vocab.json",
18
+ "merges_file": "merges.txt",
19
+ "tokenizer_file": "tokenizer.json",
20
+ }
21
+
22
+
23
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
24
+
25
+
26
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
27
+ """
28
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
29
+ Byte-Pair-Encoding.
30
+
31
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
32
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
33
+
34
+ ```python
35
+ >>> from transformers import Qwen2TokenizerFast
36
+
37
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
38
+ >>> tokenizer("Hello world")["input_ids"]
39
+ [9707, 1879]
40
+
41
+ >>> tokenizer(" Hello world")["input_ids"]
42
+ [21927, 1879]
43
+ ```
44
+ This is expected.
45
+
46
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
47
+ refer to this superclass for more information regarding those methods.
48
+
49
+ Args:
50
+ vocab_file (`str`, *optional*):
51
+ Path to the vocabulary file.
52
+ merges_file (`str`, *optional*):
53
+ Path to the merges file.
54
+ tokenizer_file (`str`, *optional*):
55
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
56
+ contains everything needed to load the tokenizer.
57
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
58
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
59
+ token instead. Not applicable to this tokenizer.
60
+ bos_token (`str`, *optional*):
61
+ The beginning of sequence token. Not applicable for this tokenizer.
62
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
63
+ The end of sequence token.
64
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
65
+ The token used for padding, for example when batching sequences of different lengths.
66
+ """
67
+
68
+ vocab_files_names = VOCAB_FILES_NAMES
69
+ model_input_names = ["input_ids", "attention_mask"]
70
+ slow_tokenizer_class = Qwen2Tokenizer
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file=None,
75
+ merges_file=None,
76
+ tokenizer_file=None,
77
+ unk_token="<|endoftext|>",
78
+ bos_token=None,
79
+ eos_token="<|endoftext|>",
80
+ pad_token="<|endoftext|>",
81
+ **kwargs,
82
+ ):
83
+ # We need to at least pass vocab_file and merges_file to base class
84
+ # in case a slow tokenizer needs to be initialized; others can be
85
+ # configured through files.
86
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
87
+
88
+ bos_token = (
89
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
90
+ if isinstance(bos_token, str)
91
+ else bos_token
92
+ )
93
+ eos_token = (
94
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
95
+ if isinstance(eos_token, str)
96
+ else eos_token
97
+ )
98
+ unk_token = (
99
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
100
+ if isinstance(unk_token, str)
101
+ else unk_token
102
+ )
103
+ pad_token = (
104
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
105
+ if isinstance(pad_token, str)
106
+ else pad_token
107
+ )
108
+
109
+ super().__init__(
110
+ vocab_file=vocab_file,
111
+ merges_file=merges_file,
112
+ tokenizer_file=tokenizer_file,
113
+ unk_token=unk_token,
114
+ bos_token=bos_token,
115
+ eos_token=eos_token,
116
+ pad_token=pad_token,
117
+ **kwargs,
118
+ )
119
+
120
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
121
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
122
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
123
+ return tuple(files)
modeling/siglip/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_sentencepiece_available,
10
+ is_torch_available,
11
+ is_vision_available,
12
+ )
13
+
14
+
15
+ _import_structure = {
16
+ "configuration_siglip": [
17
+ "SiglipConfig",
18
+ "SiglipTextConfig",
19
+ "SiglipVisionConfig",
20
+ ],
21
+ "processing_siglip": ["SiglipProcessor"],
22
+ }
23
+
24
+ try:
25
+ if not is_sentencepiece_available():
26
+ raise OptionalDependencyNotAvailable()
27
+ except OptionalDependencyNotAvailable:
28
+ pass
29
+ else:
30
+ _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
31
+
32
+
33
+ try:
34
+ if not is_vision_available():
35
+ raise OptionalDependencyNotAvailable()
36
+ except OptionalDependencyNotAvailable:
37
+ pass
38
+ else:
39
+ _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
40
+
41
+ try:
42
+ if not is_torch_available():
43
+ raise OptionalDependencyNotAvailable()
44
+ except OptionalDependencyNotAvailable:
45
+ pass
46
+ else:
47
+ _import_structure["modeling_siglip"] = [
48
+ "SiglipModel",
49
+ "SiglipPreTrainedModel",
50
+ "SiglipTextModel",
51
+ "SiglipVisionModel",
52
+ "SiglipForImageClassification",
53
+ ]
54
+
55
+
56
+ if TYPE_CHECKING:
57
+ from .configuration_siglip import (
58
+ SiglipConfig,
59
+ SiglipTextConfig,
60
+ SiglipVisionConfig,
61
+ )
62
+ from .processing_siglip import SiglipProcessor
63
+
64
+ try:
65
+ if not is_sentencepiece_available():
66
+ raise OptionalDependencyNotAvailable()
67
+ except OptionalDependencyNotAvailable:
68
+ pass
69
+ else:
70
+ from .tokenization_siglip import SiglipTokenizer
71
+
72
+ try:
73
+ if not is_vision_available():
74
+ raise OptionalDependencyNotAvailable()
75
+ except OptionalDependencyNotAvailable:
76
+ pass
77
+ else:
78
+ from .image_processing_siglip import SiglipImageProcessor
79
+
80
+ try:
81
+ if not is_torch_available():
82
+ raise OptionalDependencyNotAvailable()
83
+ except OptionalDependencyNotAvailable:
84
+ pass
85
+ else:
86
+ from .modeling_siglip import (
87
+ SiglipForImageClassification,
88
+ SiglipModel,
89
+ SiglipPreTrainedModel,
90
+ SiglipTextModel,
91
+ SiglipVisionModel,
92
+ )
93
+
94
+
95
+ else:
96
+ import sys
97
+
98
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
modeling/siglip/configuration_siglip.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Siglip model configuration"""
5
+
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class SiglipTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
19
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
20
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
21
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 32000):
28
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
29
+ the `input_ids` passed when calling [`SiglipModel`].
30
+ hidden_size (`int`, *optional*, defaults to 768):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 3072):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 12):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 64):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ pad_token_id (`int`, *optional*, defaults to 1):
49
+ The id of the padding token in the vocabulary.
50
+ bos_token_id (`int`, *optional*, defaults to 49406):
51
+ The id of the beginning-of-sequence token in the vocabulary.
52
+ eos_token_id (`int`, *optional*, defaults to 49407):
53
+ The id of the end-of-sequence token in the vocabulary.
54
+
55
+ Example:
56
+
57
+ ```python
58
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
59
+
60
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
61
+ >>> configuration = SiglipTextConfig()
62
+
63
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
64
+ >>> model = SiglipTextModel(configuration)
65
+
66
+ >>> # Accessing the model configuration
67
+ >>> configuration = model.config
68
+ ```"""
69
+
70
+ model_type = "siglip_text_model"
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_size=32000,
75
+ hidden_size=768,
76
+ intermediate_size=3072,
77
+ num_hidden_layers=12,
78
+ num_attention_heads=12,
79
+ max_position_embeddings=64,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
84
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
85
+ pad_token_id=1,
86
+ bos_token_id=49406,
87
+ eos_token_id=49407,
88
+ **kwargs,
89
+ ):
90
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
91
+
92
+ self.vocab_size = vocab_size
93
+ self.hidden_size = hidden_size
94
+ self.intermediate_size = intermediate_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.attention_dropout = attention_dropout
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
104
+ cls._set_token_in_kwargs(kwargs)
105
+
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ # get the text config dict if we are loading from SiglipConfig
109
+ if config_dict.get("model_type") == "siglip":
110
+ config_dict = config_dict["text_config"]
111
+
112
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
113
+ logger.warning(
114
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
115
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
116
+ )
117
+
118
+ return cls.from_dict(config_dict, **kwargs)
119
+
120
+
121
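# NOTE (editor): hedged usage sketch, not part of the diff. As the branch above shows,
# calling from_pretrained on a checkpoint whose saved config is a full SiglipConfig
# (model_type == "siglip") transparently extracts the nested "text_config" section.
# Assumes the checkpoint id from the docstring above is reachable (network access required).
text_cfg = SiglipTextConfig.from_pretrained("google/siglip-base-patch16-224")
print(text_cfg.hidden_size, text_cfg.num_hidden_layers)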
+ class SiglipVisionConfig(PretrainedConfig):
122
+ r"""
123
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
124
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
125
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
126
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
127
+
128
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
+ documentation from [`PretrainedConfig`] for more information.
130
+
131
+ Args:
132
+ hidden_size (`int`, *optional*, defaults to 768):
133
+ Dimensionality of the encoder layers and the pooler layer.
134
+ intermediate_size (`int`, *optional*, defaults to 3072):
135
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
136
+ num_hidden_layers (`int`, *optional*, defaults to 12):
137
+ Number of hidden layers in the Transformer encoder.
138
+ num_attention_heads (`int`, *optional*, defaults to 12):
139
+ Number of attention heads for each attention layer in the Transformer encoder.
140
+ num_channels (`int`, *optional*, defaults to 3):
141
+ Number of channels in the input images.
142
+ image_size (`int`, *optional*, defaults to 224):
143
+ The size (resolution) of each image.
144
+ patch_size (`int`, *optional*, defaults to 16):
145
+ The size (resolution) of each patch.
146
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
147
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
148
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
149
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
150
+ The epsilon used by the layer normalization layers.
151
+ attention_dropout (`float`, *optional*, defaults to 0.0):
152
+ The dropout ratio for the attention probabilities.
153
+
154
+ Example:
155
+
156
+ ```python
157
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
158
+
159
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
160
+ >>> configuration = SiglipVisionConfig()
161
+
162
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
163
+ >>> model = SiglipVisionModel(configuration)
164
+
165
+ >>> # Accessing the model configuration
166
+ >>> configuration = model.config
167
+ ```"""
168
+
169
+ model_type = "siglip_vision_model"
170
+
171
+ def __init__(
172
+ self,
173
+ hidden_size=768,
174
+ intermediate_size=3072,
175
+ num_hidden_layers=12,
176
+ num_attention_heads=12,
177
+ num_channels=3,
178
+ image_size=224,
179
+ patch_size=16,
180
+ hidden_act="gelu_pytorch_tanh",
181
+ layer_norm_eps=1e-6,
182
+ attention_dropout=0.0,
183
+ **kwargs,
184
+ ):
185
+ super().__init__(**kwargs)
186
+
187
+ self.hidden_size = hidden_size
188
+ self.intermediate_size = intermediate_size
189
+ self.num_hidden_layers = num_hidden_layers
190
+ self.num_attention_heads = num_attention_heads
191
+ self.num_channels = num_channels
192
+ self.patch_size = patch_size
193
+ self.image_size = image_size
194
+ self.attention_dropout = attention_dropout
195
+ self.layer_norm_eps = layer_norm_eps
196
+ self.hidden_act = hidden_act
197
+
198
+ @classmethod
199
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
200
+ cls._set_token_in_kwargs(kwargs)
201
+
202
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
203
+
204
+ # get the vision config dict if we are loading from SiglipConfig
205
+ if config_dict.get("model_type") == "siglip":
206
+ config_dict = config_dict["vision_config"]
207
+
208
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
209
+ logger.warning(
210
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
211
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
212
+ )
213
+
214
+ return cls.from_dict(config_dict, **kwargs)
215
+
216
+
217
+ class SiglipConfig(PretrainedConfig):
218
+ r"""
219
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
220
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
221
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
222
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
223
+
224
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
225
+ documentation from [`PretrainedConfig`] for more information.
226
+
227
+ Args:
228
+ text_config (`dict`, *optional*):
229
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
230
+ vision_config (`dict`, *optional*):
231
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
232
+ kwargs (*optional*):
233
+ Dictionary of keyword arguments.
234
+
235
+ Example:
236
+
237
+ ```python
238
+ >>> from transformers import SiglipConfig, SiglipModel
239
+
240
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
+ >>> configuration = SiglipConfig()
242
+
243
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
244
+ >>> model = SiglipModel(configuration)
245
+
246
+ >>> # Accessing the model configuration
247
+ >>> configuration = model.config
248
+
249
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
250
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
251
+
252
+ >>> # Initializing a SiglipText and SiglipVision configuration
253
+ >>> config_text = SiglipTextConfig()
254
+ >>> config_vision = SiglipVisionConfig()
255
+
256
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
257
+ ```"""
258
+
259
+ model_type = "siglip"
260
+
261
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
262
+ super().__init__(**kwargs)
263
+
264
+ if text_config is None:
265
+ text_config = {}
266
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
267
+
268
+ if vision_config is None:
269
+ vision_config = {}
270
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
271
+
272
+ self.text_config = SiglipTextConfig(**text_config)
273
+ self.vision_config = SiglipVisionConfig(**vision_config)
274
+
275
+ self.initializer_factor = 1.0
276
+
277
+ @classmethod
278
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
279
+ r"""
280
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
281
+ model configuration.
282
+
283
+ Returns:
284
+ [`SiglipConfig`]: An instance of a configuration object
285
+ """
286
+
287
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
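A minimal usage sketch of how the three configuration classes above fit together; it assumes the installed `transformers` package exposes the same SigLIP configuration API as the vendored `configuration_siglip.py` in this commit.

```python
# Compose a combined SiglipConfig from its two sub-configs and read them back.
# Assumption: the public transformers SigLIP classes behave like the vendored
# copies above (from_text_vision_configs stores typed sub-configs).
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

text_config = SiglipTextConfig(vocab_size=32000, hidden_size=768)
vision_config = SiglipVisionConfig(image_size=384, patch_size=16)

config = SiglipConfig.from_text_vision_configs(text_config, vision_config)

# The combined config keeps typed sub-configs, which is exactly what the two
# `from_pretrained` overrides above pull back out of a full "siglip" config dict.
assert config.vision_config.image_size == 384
assert config.text_config.hidden_size == 768
```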
modeling/siglip/convert_siglip_to_hf.py ADDED
@@ -0,0 +1,401 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Convert SigLIP checkpoints from the original repository.
5
+
6
+ URL: https://github.com/google-research/big_vision/tree/main
7
+ """
8
+
9
+ import argparse
10
+ import collections
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from huggingface_hub import hf_hub_download
17
+ from numpy import load
18
+ from PIL import Image
19
+
20
+ from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
21
+ from transformers.utils import logging
22
+
23
+
24
+ logging.set_verbosity_info()
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ model_name_to_checkpoint = {
29
+ # base checkpoints
30
+ "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
31
+ "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
32
+ "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
33
+ "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
34
+ # large checkpoints
35
+ "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
36
+ "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
37
+ # multilingual checkpoint
38
+ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
39
+ # so400m checkpoints
40
+ "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
41
+ }
42
+
43
+ model_name_to_image_size = {
44
+ "siglip-base-patch16-224": 224,
45
+ "siglip-base-patch16-256": 256,
46
+ "siglip-base-patch16-384": 384,
47
+ "siglip-base-patch16-512": 512,
48
+ "siglip-large-patch16-256": 256,
49
+ "siglip-large-patch16-384": 384,
50
+ "siglip-base-patch16-256-i18n": 256,
51
+ "siglip-so400m-patch14-384": 384,
52
+ }
53
+
54
+
55
+ def get_siglip_config(model_name):
56
+ config = SiglipConfig()
57
+
58
+ vocab_size = 250000 if "i18n" in model_name else 32000
59
+ image_size = model_name_to_image_size[model_name]
60
+ patch_size = 16 if "patch16" in model_name else 14
61
+
62
+ # size of the architecture
63
+ config.vision_config.image_size = image_size
64
+ config.vision_config.patch_size = patch_size
65
+ config.text_config.vocab_size = vocab_size
66
+
67
+ if "base" in model_name:
68
+ pass
69
+ elif "large" in model_name:
70
+ config.text_config.hidden_size = 1024
71
+ config.text_config.intermediate_size = 4096
72
+ config.text_config.num_hidden_layers = 24
73
+ config.text_config.num_attention_heads = 16
74
+ config.vision_config.hidden_size = 1024
75
+ config.vision_config.intermediate_size = 4096
76
+ config.vision_config.num_hidden_layers = 24
77
+ config.vision_config.num_attention_heads = 16
78
+ elif "so400m" in model_name:
79
+ config.text_config.hidden_size = 1152
80
+ config.text_config.intermediate_size = 4304
81
+ config.text_config.num_hidden_layers = 27
82
+ config.text_config.num_attention_heads = 16
83
+ config.vision_config.hidden_size = 1152
84
+ config.vision_config.intermediate_size = 4304
85
+ config.vision_config.num_hidden_layers = 27
86
+ config.vision_config.num_attention_heads = 16
87
+ else:
88
+ raise ValueError("Model not supported")
89
+
90
+ return config
91
+
92
+
93
+ def create_rename_keys(config):
94
+ rename_keys = []
95
+ # fmt: off
96
+
97
+ # vision encoder
98
+
99
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
100
+ rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
101
+ rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
102
+
103
+ for i in range(config.vision_config.num_hidden_layers):
104
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
105
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
106
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
107
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
108
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
109
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
110
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
111
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
112
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
113
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
114
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
115
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
116
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
117
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
118
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
119
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
120
+
121
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
122
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
123
+
124
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
125
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
126
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
127
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
128
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
129
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
130
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
131
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
132
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
133
+
134
+ # text encoder
135
+
136
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
137
+ rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
138
+
139
+ for i in range(config.text_config.num_hidden_layers):
140
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
141
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
142
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
143
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
144
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
145
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
146
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
147
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
148
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
149
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
150
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
151
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
152
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
153
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
154
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
155
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
156
+
157
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
158
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
159
+ rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
160
+ rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
161
+
162
+ # learned temperature and bias
163
+ rename_keys.append(("params/t", "logit_scale"))
164
+ rename_keys.append(("params/b", "logit_bias"))
165
+
166
+ # fmt: on
167
+ return rename_keys
168
+
169
+
170
+ def rename_key(dct, old, new, config):
171
+ val = dct.pop(old)
172
+
173
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
174
+ val = val.reshape(-1, config.vision_config.hidden_size)
175
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
176
+ val = val.reshape(-1, config.text_config.hidden_size)
177
+
178
+ if "patch_embedding.weight" in new:
179
+ val = val.transpose(3, 2, 0, 1)
180
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
181
+ val = val.T
182
+
183
+ if "position_embedding" in new and "vision" in new:
184
+ val = val.reshape(-1, config.vision_config.hidden_size)
185
+ if "position_embedding" in new and "text" in new:
186
+ val = val.reshape(-1, config.text_config.hidden_size)
187
+
188
+ if new.endswith("bias"):
189
+ val = val.reshape(-1)
190
+
191
+ dct[new] = torch.from_numpy(val)
192
+
193
+
194
+ def read_in_q_k_v_head(state_dict, config):
195
+ # read in individual input projection layers
196
+ key_proj_weight = (
197
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
198
+ .reshape(-1, config.vision_config.hidden_size)
199
+ .T
200
+ )
201
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
202
+ value_proj_weight = (
203
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
204
+ .reshape(-1, config.vision_config.hidden_size)
205
+ .T
206
+ )
207
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
208
+ query_proj_weight = (
209
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
210
+ .reshape(-1, config.vision_config.hidden_size)
211
+ .T
212
+ )
213
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
214
+
215
+ # next, add them to the state dict as a single matrix + vector
216
+ state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
217
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
218
+ )
219
+ state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
220
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
221
+ )
222
+
223
+
224
+ # We will verify our results on an image of cute cats
225
+ def prepare_img():
226
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
227
+ image = Image.open(requests.get(url, stream=True).raw)
228
+ return image
229
+
230
+
231
+ def flatten_nested_dict(params, parent_key="", sep="/"):
232
+ items = []
233
+
234
+ for k, v in params.items():
235
+ new_key = parent_key + sep + k if parent_key else k
236
+
237
+ if isinstance(v, collections.abc.MutableMapping):
238
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
239
+ else:
240
+ items.append((new_key, v))
241
+ return dict(items)
242
+
243
+
244
+ @torch.no_grad()
245
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
246
+ """
247
+ Copy/paste/tweak model's weights to our SigLIP structure.
248
+ """
249
+
250
+ # define default SigLIP configuration
251
+ config = get_siglip_config(model_name)
252
+
253
+ # get checkpoint
254
+ checkpoint = model_name_to_checkpoint[model_name]
255
+
256
+ # get vocab file
257
+ if "i18n" in model_name:
258
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
259
+ else:
260
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
261
+
262
+ # load original state dict
263
+ data = load(checkpoint)
264
+ state_dict = flatten_nested_dict(data)
265
+
266
+ # remove and rename some keys
267
+ rename_keys = create_rename_keys(config)
268
+ for src, dest in rename_keys:
269
+ rename_key(state_dict, src, dest, config)
270
+
271
+ # qkv matrices of attention pooling head need special treatment
272
+ read_in_q_k_v_head(state_dict, config)
273
+
274
+ # load HuggingFace model
275
+ model = SiglipModel(config).eval()
276
+ model.load_state_dict(state_dict)
277
+
278
+ # create processor
279
+ # important: make tokenizer not return attention_mask since original one doesn't require it
280
+ image_size = config.vision_config.image_size
281
+ size = {"height": image_size, "width": image_size}
282
+ image_processor = SiglipImageProcessor(size=size)
283
+ tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
284
+ processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
285
+
286
+ # verify on dummy images and texts
287
+ url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
288
+ image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
289
+ url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
290
+ image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
291
+ texts = ["an apple", "a picture of an apple"]
292
+
293
+ inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
294
+
295
+ # verify input_ids against original ones
296
+ if image_size == 224:
297
+ filename = "siglip_pixel_values.pt"
298
+ elif image_size == 256:
299
+ filename = "siglip_pixel_values_256.pt"
300
+ elif image_size == 384:
301
+ filename = "siglip_pixel_values_384.pt"
302
+ elif image_size == 512:
303
+ filename = "siglip_pixel_values_512.pt"
304
+ else:
305
+ raise ValueError("Image size not supported")
306
+
307
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
308
+ original_pixel_values = torch.load(filepath)
309
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
310
+ original_input_ids = torch.load(filepath)
311
+
312
+ if "i18n" not in model_name:
313
+ assert inputs.input_ids.tolist() == original_input_ids.tolist()
314
+
315
+ print("Mean of original pixel values:", original_pixel_values.mean())
316
+ print("Mean of new pixel values:", inputs.pixel_values.mean())
317
+
318
+ # note: we're testing with original pixel values here since we don't have exact pixel values
319
+ with torch.no_grad():
320
+ outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
321
+
322
+ # with torch.no_grad():
323
+ # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
324
+
325
+ print(outputs.logits_per_image[:3, :3])
326
+
327
+ probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
328
+ print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
329
+ print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
330
+
331
+ if verify_logits:
332
+ if model_name == "siglip-base-patch16-224":
333
+ expected_slice = torch.tensor(
334
+ [[-2.9621, -2.1672], [-0.2713, 0.2910]],
335
+ )
336
+ elif model_name == "siglip-base-patch16-256":
337
+ expected_slice = torch.tensor(
338
+ [[-3.1146, -1.9894], [-0.7312, 0.6387]],
339
+ )
340
+ elif model_name == "siglip-base-patch16-384":
341
+ expected_slice = torch.tensor(
342
+ [[-2.8098, -2.1891], [-0.4242, 0.4102]],
343
+ )
344
+ elif model_name == "siglip-base-patch16-512":
345
+ expected_slice = torch.tensor(
346
+ [[-2.7899, -2.2668], [-0.4295, -0.0735]],
347
+ )
348
+ elif model_name == "siglip-large-patch16-256":
349
+ expected_slice = torch.tensor(
350
+ [[-1.5827, -0.5801], [-0.9153, 0.1363]],
351
+ )
352
+ elif model_name == "siglip-large-patch16-384":
353
+ expected_slice = torch.tensor(
354
+ [[-2.1523, -0.2899], [-0.2959, 0.7884]],
355
+ )
356
+ elif model_name == "siglip-so400m-patch14-384":
357
+ expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
358
+ elif model_name == "siglip-base-patch16-256-i18n":
359
+ expected_slice = torch.tensor(
360
+ [[-0.9064, 0.1073], [-0.0299, 0.5304]],
361
+ )
362
+
363
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
364
+ print("Looks ok!")
365
+
366
+ if pytorch_dump_folder_path is not None:
367
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
368
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
369
+ model.save_pretrained(pytorch_dump_folder_path)
370
+ print(f"Saving processor to {pytorch_dump_folder_path}")
371
+ processor.save_pretrained(pytorch_dump_folder_path)
372
+
373
+ if push_to_hub:
374
+ model.push_to_hub(f"nielsr/{model_name}")
375
+ processor.push_to_hub(f"nielsr/{model_name}")
376
+
377
+
378
+ if __name__ == "__main__":
379
+ parser = argparse.ArgumentParser()
380
+ # Required parameters
381
+ parser.add_argument(
382
+ "--model_name",
383
+ default="siglip-base-patch16-224",
384
+ type=str,
385
+ choices=model_name_to_checkpoint.keys(),
386
+ help="Name of the model you'd like to convert.",
387
+ )
388
+ parser.add_argument(
389
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
390
+ )
391
+ parser.add_argument(
392
+ "--verify_logits",
393
+ action="store_false",
394
+ help="Whether to verify logits against the original implementation. Note that the action is `store_false`, so passing this flag disables verification.",
395
+ )
396
+ parser.add_argument(
397
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
398
+ )
399
+
400
+ args = parser.parse_args()
401
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
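The conversion above hinges on `flatten_nested_dict`, which turns the nested big_vision parameter tree into the flat `params/...` keys that `create_rename_keys` expects. A standalone sketch on toy data (not a real checkpoint):

```python
# Toy illustration of the flatten_nested_dict helper defined in the script above;
# the nested dict stands in for a big_vision checkpoint tree.
import collections.abc


def flatten_nested_dict(params, parent_key="", sep="/"):
    items = []
    for k, v in params.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


tree = {"params": {"img": {"pos_embedding": 0.1, "embedding": {"kernel": 0.2}}}}
print(flatten_nested_dict(tree))
# {'params/img/pos_embedding': 0.1, 'params/img/embedding/kernel': 0.2}
```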
modeling/siglip/image_processing_siglip.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Image processor class for SigLIP."""
5
+
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
9
+ from transformers.image_transforms import (
10
+ convert_to_rgb,
11
+ resize,
12
+ to_channel_dimension_format,
13
+ )
14
+ from transformers.image_utils import (
15
+ IMAGENET_STANDARD_MEAN,
16
+ IMAGENET_STANDARD_STD,
17
+ ChannelDimension,
18
+ ImageInput,
19
+ PILImageResampling,
20
+ infer_channel_dimension_format,
21
+ is_scaled_image,
22
+ make_list_of_images,
23
+ to_numpy_array,
24
+ valid_images,
25
+ validate_preprocess_arguments,
26
+ )
27
+ from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ if is_vision_available():
34
+ import PIL
35
+
36
+
37
+ class SiglipImageProcessor(BaseImageProcessor):
38
+ r"""
39
+ Constructs a SigLIP image processor.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
44
+ `do_resize` in the `preprocess` method.
45
+ size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
46
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
47
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
48
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
49
+ do_rescale (`bool`, *optional*, defaults to `True`):
50
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
51
+ the `preprocess` method.
52
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
53
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
54
+ method.
55
+ do_normalize (`bool`, *optional*, defaults to `True`):
56
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
57
+ `do_normalize` in the `preprocess` method.
58
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
59
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
60
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
61
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
62
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
63
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
64
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
66
+ Whether to convert the image to RGB.
67
+ """
68
+
69
+ model_input_names = ["pixel_values"]
70
+
71
+ def __init__(
72
+ self,
73
+ do_resize: bool = True,
74
+ size: Dict[str, int] = None,
75
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
76
+ do_rescale: bool = True,
77
+ rescale_factor: Union[int, float] = 1 / 255,
78
+ do_normalize: bool = True,
79
+ image_mean: Optional[Union[float, List[float]]] = None,
80
+ image_std: Optional[Union[float, List[float]]] = None,
81
+ do_convert_rgb: bool = None,
82
+ **kwargs,
83
+ ) -> None:
84
+ super().__init__(**kwargs)
85
+ size = size if size is not None else {"height": 224, "width": 224}
86
+ image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
87
+ image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
88
+
89
+ self.do_resize = do_resize
90
+ self.size = size
91
+ self.resample = resample
92
+ self.do_rescale = do_rescale
93
+ self.rescale_factor = rescale_factor
94
+ self.do_normalize = do_normalize
95
+ self.image_mean = image_mean
96
+ self.image_std = image_std
97
+ self.do_convert_rgb = do_convert_rgb
98
+
99
+ @filter_out_non_signature_kwargs()
100
+ def preprocess(
101
+ self,
102
+ images: ImageInput,
103
+ do_resize: bool = None,
104
+ size: Dict[str, int] = None,
105
+ resample: PILImageResampling = None,
106
+ do_rescale: bool = None,
107
+ rescale_factor: float = None,
108
+ do_normalize: bool = None,
109
+ image_mean: Optional[Union[float, List[float]]] = None,
110
+ image_std: Optional[Union[float, List[float]]] = None,
111
+ return_tensors: Optional[Union[str, TensorType]] = None,
112
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
113
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
114
+ do_convert_rgb: bool = None,
115
+ ) -> PIL.Image.Image:
116
+ """
117
+ Preprocess an image or batch of images.
118
+
119
+ Args:
120
+ images (`ImageInput`):
121
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
122
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
123
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
124
+ Whether to resize the image.
125
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
126
+ Size of the image after resizing.
127
+ resample (`int`, *optional*, defaults to `self.resample`):
128
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
129
+ has an effect if `do_resize` is set to `True`.
130
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
131
+ Whether to rescale the image.
132
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
133
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
134
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
135
+ Whether to normalize the image.
136
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
137
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
138
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
139
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
140
+ `True`.
141
+ return_tensors (`str` or `TensorType`, *optional*):
142
+ The type of tensors to return. Can be one of:
143
+ - Unset: Return a list of `np.ndarray`.
144
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
145
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
146
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
147
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
148
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
149
+ The channel dimension format for the output image. Can be one of:
150
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
151
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
152
+ - Unset: Use the channel dimension format of the input image.
153
+ input_data_format (`ChannelDimension` or `str`, *optional*):
154
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
155
+ from the input image. Can be one of:
156
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
157
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
158
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
159
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
160
+ Whether to convert the image to RGB.
161
+ """
162
+ do_resize = do_resize if do_resize is not None else self.do_resize
163
+ size = size if size is not None else self.size
164
+ size = get_size_dict(size, param_name="size", default_to_square=False)
165
+ resample = resample if resample is not None else self.resample
166
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
167
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
168
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
169
+ image_mean = image_mean if image_mean is not None else self.image_mean
170
+ image_std = image_std if image_std is not None else self.image_std
171
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
172
+
173
+ images = make_list_of_images(images)
174
+
175
+ if not valid_images(images):
176
+ raise ValueError(
177
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
178
+ "torch.Tensor, tf.Tensor or jax.ndarray."
179
+ )
180
+ validate_preprocess_arguments(
181
+ do_rescale=do_rescale,
182
+ rescale_factor=rescale_factor,
183
+ do_normalize=do_normalize,
184
+ image_mean=image_mean,
185
+ image_std=image_std,
186
+ do_resize=do_resize,
187
+ size=size,
188
+ resample=resample,
189
+ )
190
+ # All transformations expect numpy arrays.
191
+ images = [to_numpy_array(image) for image in images]
192
+
193
+ if do_convert_rgb:
194
+ images = [convert_to_rgb(image) for image in images]
195
+
196
+ if is_scaled_image(images[0]) and do_rescale:
197
+ logger.warning_once(
198
+ "It looks like you are trying to rescale already rescaled images. If the input"
199
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
200
+ )
201
+
202
+ if input_data_format is None:
203
+ # We assume that all images have the same channel dimension format.
204
+ input_data_format = infer_channel_dimension_format(images[0])
205
+
206
+ if do_resize:
207
+ height, width = size["height"], size["width"]
208
+ images = [
209
+ resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
210
+ for image in images
211
+ ]
212
+
213
+ if do_rescale:
214
+ images = [
215
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
216
+ for image in images
217
+ ]
218
+
219
+ if do_normalize:
220
+ images = [
221
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
222
+ for image in images
223
+ ]
224
+
225
+ images = [
226
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
227
+ ]
228
+
229
+ data = {"pixel_values": images}
230
+ return BatchFeature(data=data, tensor_type=return_tensors)
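A minimal usage sketch for the processor above; it assumes `transformers`, `torch`, and `Pillow` are installed and that the public `SiglipImageProcessor` matches this vendored copy. The output shape follows from the default 224x224 resize, 1/255 rescale, and 0.5 mean/std normalization documented in the class docstring.

```python
# Preprocess a single RGB image with the default SigLIP settings.
# Assumption: the installed transformers SiglipImageProcessor matches the
# vendored class above; a random image stands in for real data.
import numpy as np
from PIL import Image
from transformers import SiglipImageProcessor

processor = SiglipImageProcessor()  # defaults: 224x224, rescale 1/255, mean/std 0.5
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

batch = processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```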
modeling/siglip/modeling_siglip.py ADDED
@@ -0,0 +1,1557 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """PyTorch Siglip model."""
5
+
6
+ import math
7
+ import warnings
8
+ from dataclasses import dataclass
9
+ from typing import Any, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
20
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.utils import (
23
+ ModelOutput,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ is_flash_attn_2_available,
27
+ is_flash_attn_greater_or_equal_2_10,
28
+ logging,
29
+ replace_return_docstrings,
30
+ torch_int,
31
+ )
32
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
33
+
34
+
35
+ if is_flash_attn_2_available():
36
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ # General docstring
42
+ _CONFIG_FOR_DOC = "SiglipConfig"
43
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
44
+
45
+
46
+ def _trunc_normal_(tensor, mean, std, a, b):
47
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
48
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
49
+ def norm_cdf(x):
50
+ # Computes standard normal cumulative distribution function
51
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
52
+
53
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
54
+ warnings.warn(
55
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
56
+ "The distribution of values may be incorrect.",
57
+ stacklevel=2,
58
+ )
59
+
60
+ # Values are generated by using a truncated uniform distribution and
61
+ # then using the inverse CDF for the normal distribution.
62
+ # Get upper and lower cdf values
63
+ l = norm_cdf((a - mean) / std)
64
+ u = norm_cdf((b - mean) / std)
65
+
66
+ # Uniformly fill tensor with values from [l, u], then translate to
67
+ # [2l-1, 2u-1].
68
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
69
+
70
+ # Use inverse cdf transform for normal distribution to get truncated
71
+ # standard normal
72
+ tensor.erfinv_()
73
+
74
+ # Transform to proper mean, std
75
+ tensor.mul_(std * math.sqrt(2.0))
76
+ tensor.add_(mean)
77
+
78
+ # Clamp to ensure it's in the proper range
79
+ tensor.clamp_(min=a, max=b)
80
+
81
+
82
+ def trunc_normal_tf_(
83
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
84
+ ) -> torch.Tensor:
85
+ """Fills the input Tensor with values drawn from a truncated
86
+ normal distribution. The values are effectively drawn from the
87
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
88
+ with values outside :math:`[a, b]` redrawn until they are within
89
+ the bounds. The method used for generating the random values works
90
+ best when :math:`a \\leq \\text{mean} \\leq b`.
91
+
92
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
93
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
94
+ and the result is subsequently scaled and shifted by the mean and std args.
95
+
96
+ Args:
97
+ tensor: an n-dimensional `torch.Tensor`
98
+ mean: the mean of the normal distribution
99
+ std: the standard deviation of the normal distribution
100
+ a: the minimum cutoff value
101
+ b: the maximum cutoff value
102
+ """
103
+ with torch.no_grad():
104
+ _trunc_normal_(tensor, 0, 1.0, a, b)
105
+ tensor.mul_(std).add_(mean)
106
+
107
+
108
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
109
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
110
+ if mode == "fan_in":
111
+ denom = fan_in
112
+ elif mode == "fan_out":
113
+ denom = fan_out
114
+ elif mode == "fan_avg":
115
+ denom = (fan_in + fan_out) / 2
116
+
117
+ variance = scale / denom
118
+
119
+ if distribution == "truncated_normal":
120
+ # constant is stddev of standard normal truncated to (-2, 2)
121
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
122
+ elif distribution == "normal":
123
+ with torch.no_grad():
124
+ tensor.normal_(std=math.sqrt(variance))
125
+ elif distribution == "uniform":
126
+ bound = math.sqrt(3 * variance)
127
+ with torch.no_grad():
128
+ tensor.uniform_(-bound, bound)
129
+ else:
130
+ raise ValueError(f"invalid distribution {distribution}")
131
+
132
+
133
+ def lecun_normal_(tensor):
134
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
135
+
136
+
137
+ def default_flax_embed_init(tensor):
138
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
139
+
140
+
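A small sketch of how the variance-scaling initializers above behave on raw weight tensors; the `modeling.siglip.modeling_siglip` import path is an assumption based on this repository's layout.

```python
# Sketch of the initializers defined above; the import path is assumed from this
# repo's layout (modeling/siglip/modeling_siglip.py).
import math

import torch

from modeling.siglip.modeling_siglip import lecun_normal_, variance_scaling_

weight = torch.empty(768, 768)  # e.g. an attention projection weight
lecun_normal_(weight)           # truncated normal, variance scaled by 1/fan_in
print(round(weight.std().item(), 3), round(math.sqrt(1 / 768), 3))  # both roughly 0.036

embedding = torch.empty(32000, 768)
variance_scaling_(embedding, mode="fan_in", distribution="normal")  # Flax embedding default
```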
141
+ @dataclass
142
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
143
+ class SiglipVisionModelOutput(ModelOutput):
144
+ """
145
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
146
+
147
+ Args:
148
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
149
+ The image embeddings obtained by applying the projection layer to the pooler_output.
150
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
151
+ Sequence of hidden-states at the output of the last layer of the model.
152
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
153
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
154
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
155
+
156
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
157
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
158
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
159
+ sequence_length)`.
160
+
161
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
162
+ heads.
163
+ """
164
+
165
+ image_embeds: Optional[torch.FloatTensor] = None
166
+ last_hidden_state: torch.FloatTensor = None
167
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
168
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
169
+
170
+
171
+ @dataclass
172
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
173
+ class SiglipTextModelOutput(ModelOutput):
174
+ """
175
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
176
+
177
+ Args:
178
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
179
+ The text embeddings obtained by applying the projection layer to the pooler_output.
180
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
181
+ Sequence of hidden-states at the output of the last layer of the model.
182
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
183
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
184
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
185
+
186
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
187
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
188
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
189
+ sequence_length)`.
190
+
191
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
192
+ heads.
193
+ """
194
+
195
+ text_embeds: Optional[torch.FloatTensor] = None
196
+ last_hidden_state: torch.FloatTensor = None
197
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
198
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
199
+
200
+
201
+ @dataclass
202
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
203
+ class SiglipOutput(ModelOutput):
204
+ """
205
+ Args:
206
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
207
+ Contrastive loss for image-text similarity.
208
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
209
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
210
+ similarity scores.
211
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
212
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
213
+ similarity scores.
214
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
215
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
216
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
217
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
218
+ text_model_output (`BaseModelOutputWithPooling`):
219
+ The output of the [`SiglipTextModel`].
220
+ vision_model_output (`BaseModelOutputWithPooling`):
221
+ The output of the [`SiglipVisionModel`].
222
+ """
223
+
224
+ loss: Optional[torch.FloatTensor] = None
225
+ logits_per_image: torch.FloatTensor = None
226
+ logits_per_text: torch.FloatTensor = None
227
+ text_embeds: torch.FloatTensor = None
228
+ image_embeds: torch.FloatTensor = None
229
+ text_model_output: BaseModelOutputWithPooling = None
230
+ vision_model_output: BaseModelOutputWithPooling = None
231
+
232
+ def to_tuple(self) -> Tuple[Any]:
233
+ return tuple(
234
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
235
+ for k in self.keys()
236
+ )
237
+
238
+
239
+ class SiglipVisionEmbeddings(nn.Module):
240
+ def __init__(self, config: SiglipVisionConfig):
241
+ super().__init__()
242
+ self.config = config
243
+ self.embed_dim = config.hidden_size
244
+ self.image_size = config.image_size
245
+ self.patch_size = config.patch_size
246
+
247
+ self.patch_embedding = nn.Conv2d(
248
+ in_channels=config.num_channels,
249
+ out_channels=self.embed_dim,
250
+ kernel_size=self.patch_size,
251
+ stride=self.patch_size,
252
+ padding="valid",
253
+ )
254
+
255
+ self.num_patches = (self.image_size // self.patch_size) ** 2
256
+ self.num_positions = self.num_patches
257
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
258
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
259
+
260
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
261
+ """
262
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
263
+ images. This method is also adapted to support torch.jit tracing and no class embeddings.
264
+
265
+ Adapted from:
266
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
267
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
268
+ """
269
+
270
+ num_patches = embeddings.shape[1]
271
+ num_positions = self.position_embedding.weight.shape[0]
272
+
273
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
274
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
275
+ return self.position_embedding(self.position_ids)
276
+
277
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
278
+
279
+ dim = embeddings.shape[-1]
280
+
281
+ new_height = height // self.patch_size
282
+ new_width = width // self.patch_size
283
+
284
+ sqrt_num_positions = torch_int(num_positions**0.5)
285
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
286
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
287
+
288
+ patch_pos_embed = nn.functional.interpolate(
289
+ patch_pos_embed,
290
+ size=(new_height, new_width),
291
+ mode="bicubic",
292
+ align_corners=False,
293
+ )
294
+
295
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
296
+ return patch_pos_embed
297
+
298
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
299
+ _, _, height, width = pixel_values.shape
300
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
301
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
302
+
303
+ if interpolate_pos_encoding:
304
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
305
+ else:
306
+ embeddings = embeddings + self.position_embedding(self.position_ids)
307
+ return embeddings
308
+
309
+
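Because `interpolate_pos_encoding` resamples the learned position grid, `SiglipVisionEmbeddings` can consume resolutions other than `config.image_size`. A small shape-only sketch with a randomly initialized module; the import paths are assumed from this repository's layout.

```python
# Shape check: a 224px position table is interpolated to serve a 384px input.
# Import paths are assumed from this repo's layout (modeling/siglip/...).
import torch

from modeling.siglip.configuration_siglip import SiglipVisionConfig
from modeling.siglip.modeling_siglip import SiglipVisionEmbeddings

config = SiglipVisionConfig(image_size=224, patch_size=16, hidden_size=768)
embeddings = SiglipVisionEmbeddings(config)

pixels = torch.randn(1, 3, 384, 384)  # larger than the 224x224 the table was built for
tokens = embeddings(pixels, interpolate_pos_encoding=True)
print(tokens.shape)  # torch.Size([1, 576, 768]) -> (384 // 16) ** 2 patches
```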
310
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
311
+ class SiglipTextEmbeddings(nn.Module):
312
+ def __init__(self, config: SiglipTextConfig):
313
+ super().__init__()
314
+ embed_dim = config.hidden_size
315
+
316
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
317
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
318
+
319
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
320
+ self.register_buffer(
321
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ input_ids: Optional[torch.LongTensor] = None,
327
+ position_ids: Optional[torch.LongTensor] = None,
328
+ inputs_embeds: Optional[torch.FloatTensor] = None,
329
+ ) -> torch.Tensor:
330
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
331
+
332
+ if position_ids is None:
333
+ position_ids = self.position_ids[:, :seq_length]
334
+
335
+ if inputs_embeds is None:
336
+ inputs_embeds = self.token_embedding(input_ids)
337
+
338
+ position_embeddings = self.position_embedding(position_ids)
339
+ embeddings = inputs_embeds + position_embeddings
340
+
341
+ return embeddings
342
+
343
+
344
+ class SiglipAttention(nn.Module):
345
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
346
+
347
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
348
+ def __init__(self, config):
349
+ super().__init__()
350
+ self.config = config
351
+ self.embed_dim = config.hidden_size
352
+ self.num_heads = config.num_attention_heads
353
+ self.head_dim = self.embed_dim // self.num_heads
354
+ if self.head_dim * self.num_heads != self.embed_dim:
355
+ raise ValueError(
356
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
357
+ f" {self.num_heads})."
358
+ )
359
+ self.scale = self.head_dim**-0.5
360
+ self.dropout = config.attention_dropout
361
+
362
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
363
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
364
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
365
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
366
+
367
+ def forward(
368
+ self,
369
+ hidden_states: torch.Tensor,
370
+ attention_mask: Optional[torch.Tensor] = None,
371
+ output_attentions: Optional[bool] = False,
372
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
373
+ """Input shape: Batch x Time x Channel"""
374
+
375
+ batch_size, q_len, _ = hidden_states.size()
376
+
377
+ query_states = self.q_proj(hidden_states)
378
+ key_states = self.k_proj(hidden_states)
379
+ value_states = self.v_proj(hidden_states)
380
+
381
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
382
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
383
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
+
385
+ k_v_seq_len = key_states.shape[-2]
386
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
387
+
388
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
389
+ raise ValueError(
390
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
391
+ f" {attn_weights.size()}"
392
+ )
393
+
394
+ if attention_mask is not None:
395
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
396
+ raise ValueError(
397
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
398
+ )
399
+ attn_weights = attn_weights + attention_mask
400
+
401
+ # upcast attention to fp32
402
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
403
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
404
+ attn_output = torch.matmul(attn_weights, value_states)
405
+
406
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
407
+ raise ValueError(
408
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
409
+ f" {attn_output.size()}"
410
+ )
411
+
412
+ attn_output = attn_output.transpose(1, 2).contiguous()
413
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
414
+
415
+ attn_output = self.out_proj(attn_output)
416
+
417
+ return attn_output, attn_weights
418
+
419
+
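The eager path above materializes softmax(QK^T * head_dim**-0.5)V explicitly; the SDPA and flash variants further down compute the same quantity with fused kernels. A standalone sketch (hypothetical shapes) checking that equivalence against `torch.nn.functional.scaled_dot_product_attention`:

```python
# Standalone sketch: verify that the eager attention math above matches torch's fused SDPA.
# Shapes are hypothetical: batch=2, heads=4, seq=5, head_dim=8.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 5, 8) for _ in range(3))
scale = 8 ** -0.5
eager = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)  # default scale is head_dim**-0.5
print(torch.allclose(eager, fused, atol=1e-5))   # True, up to numerical tolerance
```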
420
+ class SiglipFlashAttention2(SiglipAttention):
421
+ """
422
+ SiglipAttention flash attention module. This module inherits from `SiglipAttention`, as the weights of the module stay
423
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
424
+ flash attention and handle padding tokens in case the input contains any of them.
425
+ """
426
+
427
+ is_causal = False
428
+
429
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
430
+ def __init__(self, *args, **kwargs):
431
+ super().__init__(*args, **kwargs)
432
+
433
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
434
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
435
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
436
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
437
+
438
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: Optional[torch.LongTensor] = None,
443
+ output_attentions: bool = False,
444
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
445
+ output_attentions = False
446
+
447
+ batch_size, q_len, _ = hidden_states.size()
448
+
449
+ query_states = self.q_proj(hidden_states)
450
+ key_states = self.k_proj(hidden_states)
451
+ value_states = self.v_proj(hidden_states)
452
+
453
+ # Flash attention requires the input to have the shape
454
+ # batch_size x seq_length x num_heads x head_dim
455
+ # therefore we just need to keep the original shape
456
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
457
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
458
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
459
+
460
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
461
+ # to be able to avoid many of these transpose/reshape/view.
462
+ query_states = query_states.transpose(1, 2)
463
+ key_states = key_states.transpose(1, 2)
464
+ value_states = value_states.transpose(1, 2)
465
+
466
+ dropout_rate = self.dropout if self.training else 0.0
467
+
468
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
469
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
471
+ # cast them back to the correct dtype just to be sure everything works as expected.
472
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
472
+ # in fp32.
473
+
474
+ input_dtype = query_states.dtype
475
+ if input_dtype == torch.float32:
476
+ if torch.is_autocast_enabled():
477
+ target_dtype = torch.get_autocast_gpu_dtype()
478
+ # Handle the case where the model is quantized
479
+ elif hasattr(self.config, "_pre_quantization_dtype"):
480
+ target_dtype = self.config._pre_quantization_dtype
481
+ else:
482
+ target_dtype = self.q_proj.weight.dtype
483
+
484
+ logger.warning_once(
485
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
486
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
487
+ f" {target_dtype}."
488
+ )
489
+
490
+ query_states = query_states.to(target_dtype)
491
+ key_states = key_states.to(target_dtype)
492
+ value_states = value_states.to(target_dtype)
493
+
494
+ attn_output = _flash_attention_forward(
495
+ query_states,
496
+ key_states,
497
+ value_states,
498
+ attention_mask,
499
+ q_len,
500
+ dropout=dropout_rate,
501
+ is_causal=self.is_causal,
502
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
503
+ )
504
+
505
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
506
+ attn_output = self.out_proj(attn_output)
507
+
508
+ if not output_attentions:
509
+ attn_weights = None
510
+
511
+ return attn_output, attn_weights
512
+
513
+
514
+ class SiglipSdpaAttention(SiglipAttention):
515
+ """
516
+ Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
517
+ `SiglipAttention`, as the weights of the module stay untouched. The only changes are in the forward pass to adapt to the
518
+ SDPA API.
519
+ """
520
+
521
+ is_causal = False
522
+
523
+ # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
524
+ def forward(
525
+ self,
526
+ hidden_states: torch.Tensor,
527
+ attention_mask: Optional[torch.Tensor] = None,
528
+ output_attentions: Optional[bool] = False,
529
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
530
+ if output_attentions:
531
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
532
+ logger.warning_once(
533
+ "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
534
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
535
+ )
536
+ return super().forward(
537
+ hidden_states=hidden_states,
538
+ attention_mask=attention_mask,
539
+ output_attentions=output_attentions,
540
+ )
541
+
542
+ batch_size, q_len, _ = hidden_states.size()
543
+
544
+ query_states = self.q_proj(hidden_states)
545
+ key_states = self.k_proj(hidden_states)
546
+ value_states = self.v_proj(hidden_states)
547
+
548
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
549
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
550
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
551
+
552
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
553
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
554
+ if query_states.device.type == "cuda" and attention_mask is not None:
555
+ query_states = query_states.contiguous()
556
+ key_states = key_states.contiguous()
557
+ value_states = value_states.contiguous()
558
+
559
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
560
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
561
+ is_causal = True if self.is_causal and q_len > 1 else False
562
+
563
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
564
+ query_states,
565
+ key_states,
566
+ value_states,
567
+ attn_mask=attention_mask,
568
+ dropout_p=self.dropout if self.training else 0.0,
569
+ is_causal=is_causal,
570
+ )
571
+
572
+ attn_output = attn_output.transpose(1, 2).contiguous()
573
+ attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
574
+
575
+ attn_output = self.out_proj(attn_output)
576
+
577
+ return attn_output, None
578
+
579
+
580
+ SIGLIP_ATTENTION_CLASSES = {
581
+ "eager": SiglipAttention,
582
+ "flash_attention_2": SiglipFlashAttention2,
583
+ "sdpa": SiglipSdpaAttention,
584
+ }
585
+
586
+
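All three classes in the registry above share the same weights and constructor; only the forward kernel differs. A hedged usage sketch: in recent transformers releases the kernel can be chosen at load time via `attn_implementation` (exact availability depends on your installed transformers and flash-attn versions).

```python
# Hedged usage sketch: select the attention kernel that SIGLIP_ATTENTION_CLASSES will
# dispatch to by passing attn_implementation at load time.
from transformers import SiglipVisionModel

model = SiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
print(model.config._attn_implementation)  # "sdpa"
```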
587
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
588
+ class SiglipMLP(nn.Module):
589
+ def __init__(self, config):
590
+ super().__init__()
591
+ self.config = config
592
+ self.activation_fn = ACT2FN[config.hidden_act]
593
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
594
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
595
+
596
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
597
+ hidden_states = self.fc1(hidden_states)
598
+ hidden_states = self.activation_fn(hidden_states)
599
+ hidden_states = self.fc2(hidden_states)
600
+ return hidden_states
601
+
602
+
603
+ class SiglipEncoderLayer(nn.Module):
604
+ def __init__(self, config: SiglipConfig):
605
+ super().__init__()
606
+ self.embed_dim = config.hidden_size
607
+ self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
608
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
609
+ self.mlp = SiglipMLP(config)
610
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
611
+
612
+ # Ignore copy
613
+ def forward(
614
+ self,
615
+ hidden_states: torch.Tensor,
616
+ attention_mask: torch.Tensor,
617
+ output_attentions: Optional[bool] = False,
618
+ ) -> Tuple[torch.FloatTensor]:
619
+ """
620
+ Args:
621
+ hidden_states (`torch.FloatTensor`):
622
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
623
+ attention_mask (`torch.FloatTensor`):
624
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
625
+ output_attentions (`bool`, *optional*, defaults to `False`):
626
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
627
+ returned tensors for more detail.
628
+ """
629
+ residual = hidden_states
630
+
631
+ hidden_states = self.layer_norm1(hidden_states)
632
+ hidden_states, attn_weights = self.self_attn(
633
+ hidden_states=hidden_states,
634
+ attention_mask=attention_mask,
635
+ output_attentions=output_attentions,
636
+ )
637
+ hidden_states = residual + hidden_states
638
+
639
+ residual = hidden_states
640
+ hidden_states = self.layer_norm2(hidden_states)
641
+ hidden_states = self.mlp(hidden_states)
642
+ hidden_states = residual + hidden_states
643
+
644
+ outputs = (hidden_states,)
645
+
646
+ if output_attentions:
647
+ outputs += (attn_weights,)
648
+
649
+ return outputs
650
+
651
+
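The encoder layer above is a standard pre-LayerNorm transformer block: normalize, apply the sublayer, then add the residual, for both the attention and the MLP halves. A standalone sketch with hypothetical stand-in layers:

```python
# Standalone sketch (hypothetical layers and sizes): the pre-LayerNorm residual layout
# used by SiglipEncoderLayer above.
import torch
import torch.nn as nn

dim = 768
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, 12, batch_first=True)
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(approximate="tanh"), nn.Linear(4 * dim, dim))

x = torch.randn(2, 196, dim)
h = norm1(x)
x = x + attn(h, h, h)[0]    # attention sublayer with residual
x = x + mlp(norm2(x))       # MLP sublayer with residual
print(x.shape)              # torch.Size([2, 196, 768])
```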
652
+ class SiglipPreTrainedModel(PreTrainedModel):
653
+ """
654
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
655
+ models.
656
+ """
657
+
658
+ config_class = SiglipConfig
659
+ base_model_prefix = "siglip"
660
+ supports_gradient_checkpointing = True
661
+
662
+ _no_split_modules = [
663
+ "SiglipTextEmbeddings",
664
+ "SiglipEncoderLayer",
665
+ "SiglipVisionEmbeddings",
666
+ "SiglipEncoderLayer",
667
+ "SiglipMultiheadAttentionPoolingHead",
668
+ ]
669
+ _supports_flash_attn_2 = True
670
+ _supports_sdpa = True
671
+
672
+ def _init_weights(self, module):
673
+ """Initialize the weights"""
674
+ if isinstance(module, SiglipVisionEmbeddings):
675
+ width = (
676
+ self.config.vision_config.hidden_size
677
+ if isinstance(self.config, SiglipConfig)
678
+ else self.config.hidden_size
679
+ )
680
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
681
+ elif isinstance(module, nn.Embedding):
682
+ default_flax_embed_init(module.weight)
683
+ elif isinstance(module, SiglipAttention):
684
+ nn.init.xavier_uniform_(module.q_proj.weight)
685
+ nn.init.xavier_uniform_(module.k_proj.weight)
686
+ nn.init.xavier_uniform_(module.v_proj.weight)
687
+ nn.init.xavier_uniform_(module.out_proj.weight)
688
+ nn.init.zeros_(module.q_proj.bias)
689
+ nn.init.zeros_(module.k_proj.bias)
690
+ nn.init.zeros_(module.v_proj.bias)
691
+ nn.init.zeros_(module.out_proj.bias)
692
+ elif isinstance(module, SiglipMLP):
693
+ nn.init.xavier_uniform_(module.fc1.weight)
694
+ nn.init.xavier_uniform_(module.fc2.weight)
695
+ nn.init.normal_(module.fc1.bias, std=1e-6)
696
+ nn.init.normal_(module.fc2.bias, std=1e-6)
697
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
698
+ nn.init.xavier_uniform_(module.probe.data)
699
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
700
+ nn.init.zeros_(module.attention.in_proj_bias.data)
701
+ elif isinstance(module, SiglipModel):
702
+ logit_scale_init = torch.log(torch.tensor(1.0))
703
+ module.logit_scale.data.fill_(logit_scale_init)
704
+ module.logit_bias.data.zero_()
705
+ elif isinstance(module, SiglipForImageClassification):
706
+ nn.init.normal_(
707
+ module.classifier.weight,
708
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
709
+ )
710
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
711
+ lecun_normal_(module.weight)
712
+ if module.bias is not None:
713
+ nn.init.zeros_(module.bias)
714
+ elif isinstance(module, nn.LayerNorm):
715
+ module.bias.data.zero_()
716
+ module.weight.data.fill_(1.0)
717
+
718
+
719
+ SIGLIP_START_DOCSTRING = r"""
720
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
721
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
722
+ etc.)
723
+
724
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
725
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
726
+ and behavior.
727
+
728
+ Parameters:
729
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
730
+ Initializing with a config file does not load the weights associated with the model, only the
731
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
732
+ """
733
+
734
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
735
+ Args:
736
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
737
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
738
+ it.
739
+
740
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
741
+ [`PreTrainedTokenizer.__call__`] for details.
742
+
743
+ [What are input IDs?](../glossary#input-ids)
744
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
746
+
747
+ - 1 for tokens that are **not masked**,
748
+ - 0 for tokens that are **masked**.
749
+
750
+ [What are attention masks?](../glossary#attention-mask)
751
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
752
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
753
+ config.max_position_embeddings - 1]`.
754
+
755
+ [What are position IDs?](../glossary#position-ids)
756
+ output_attentions (`bool`, *optional*):
757
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
758
+ tensors for more detail.
759
+ output_hidden_states (`bool`, *optional*):
760
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
761
+ more detail.
762
+ return_dict (`bool`, *optional*):
763
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
764
+ """
765
+
766
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
767
+ Args:
768
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
769
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
770
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
771
+ output_attentions (`bool`, *optional*):
772
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
773
+ tensors for more detail.
774
+ output_hidden_states (`bool`, *optional*):
775
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
776
+ more detail.
777
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
778
+ Whether to interpolate the pre-trained position encodings.
779
+ return_dict (`bool`, *optional*):
780
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
781
+ """
782
+
783
+ SIGLIP_INPUTS_DOCSTRING = r"""
784
+ Args:
785
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
786
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
787
+ it.
788
+
789
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
790
+ [`PreTrainedTokenizer.__call__`] for details.
791
+
792
+ [What are input IDs?](../glossary#input-ids)
793
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
794
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
795
+
796
+ - 1 for tokens that are **not masked**,
797
+ - 0 for tokens that are **masked**.
798
+
799
+ [What are attention masks?](../glossary#attention-mask)
800
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
801
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
802
+ config.max_position_embeddings - 1]`.
803
+
804
+ [What are position IDs?](../glossary#position-ids)
805
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
806
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
807
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
808
+ return_loss (`bool`, *optional*):
809
+ Whether or not to return the contrastive loss.
810
+ output_attentions (`bool`, *optional*):
811
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
812
+ tensors for more detail.
813
+ output_hidden_states (`bool`, *optional*):
814
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
815
+ more detail.
816
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
817
+ Whether to interpolate the pre-trained position encodings.
818
+ return_dict (`bool`, *optional*):
819
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
820
+ """
821
+
822
+
823
+ # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
824
+ class SiglipEncoder(nn.Module):
825
+ """
826
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
827
+ [`SiglipEncoderLayer`].
828
+
829
+ Args:
830
+ config: SiglipConfig
831
+ """
832
+
833
+ def __init__(self, config: SiglipConfig):
834
+ super().__init__()
835
+ self.config = config
836
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
837
+ self.gradient_checkpointing = False
838
+
839
+ # Ignore copy
840
+ def forward(
841
+ self,
842
+ inputs_embeds,
843
+ attention_mask: Optional[torch.Tensor] = None,
844
+ output_attentions: Optional[bool] = None,
845
+ output_hidden_states: Optional[bool] = None,
846
+ return_dict: Optional[bool] = None,
847
+ ) -> Union[Tuple, BaseModelOutput]:
848
+ r"""
849
+ Args:
850
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
851
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
852
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
853
+ than the model's internal embedding lookup matrix.
854
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
855
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
856
+
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+
860
+ [What are attention masks?](../glossary#attention-mask)
861
+ output_attentions (`bool`, *optional*):
862
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
863
+ returned tensors for more detail.
864
+ output_hidden_states (`bool`, *optional*):
865
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
866
+ for more detail.
867
+ return_dict (`bool`, *optional*):
868
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
869
+ """
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ encoder_states = () if output_hidden_states else None
877
+ all_attentions = () if output_attentions else None
878
+
879
+ hidden_states = inputs_embeds
880
+ for encoder_layer in self.layers:
881
+ if output_hidden_states:
882
+ encoder_states = encoder_states + (hidden_states,)
883
+ if self.gradient_checkpointing and self.training:
884
+ layer_outputs = self._gradient_checkpointing_func(
885
+ encoder_layer.__call__,
886
+ hidden_states,
887
+ attention_mask,
888
+ output_attentions,
889
+ )
890
+ else:
891
+ layer_outputs = encoder_layer(
892
+ hidden_states,
893
+ attention_mask,
894
+ output_attentions=output_attentions,
895
+ )
896
+
897
+ hidden_states = layer_outputs[0]
898
+
899
+ if output_attentions:
900
+ all_attentions = all_attentions + (layer_outputs[1],)
901
+
902
+ if output_hidden_states:
903
+ encoder_states = encoder_states + (hidden_states,)
904
+
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
907
+ return BaseModelOutput(
908
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
909
+ )
910
+
911
+
912
+ class SiglipTextTransformer(nn.Module):
913
+ def __init__(self, config: SiglipTextConfig):
914
+ super().__init__()
915
+ self.config = config
916
+ embed_dim = config.hidden_size
917
+ self.embeddings = SiglipTextEmbeddings(config)
918
+ self.encoder = SiglipEncoder(config)
919
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
920
+
921
+ self.head = nn.Linear(embed_dim, embed_dim)
922
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
923
+
924
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
925
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
926
+ def forward(
927
+ self,
928
+ input_ids: Optional[torch.Tensor] = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ position_ids: Optional[torch.Tensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
935
+ r"""
936
+ Returns:
937
+
938
+ """
939
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
940
+ output_hidden_states = (
941
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
942
+ )
943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
944
+
945
+ if input_ids is None:
946
+ raise ValueError("You have to specify input_ids")
947
+
948
+ input_shape = input_ids.size()
949
+ input_ids = input_ids.view(-1, input_shape[-1])
950
+
951
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
952
+
953
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
954
+ # expand attention_mask
955
+ if attention_mask is not None and not self._use_flash_attention_2:
956
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
957
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
958
+
959
+ encoder_outputs = self.encoder(
960
+ inputs_embeds=hidden_states,
961
+ attention_mask=attention_mask,
962
+ output_attentions=output_attentions,
963
+ output_hidden_states=output_hidden_states,
964
+ return_dict=return_dict,
965
+ )
966
+
967
+ last_hidden_state = encoder_outputs[0]
968
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
969
+
970
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
971
+ pooled_output = last_hidden_state[:, -1, :]
972
+ pooled_output = self.head(pooled_output)
973
+
974
+ if not return_dict:
975
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
976
+
977
+ return BaseModelOutputWithPooling(
978
+ last_hidden_state=last_hidden_state,
979
+ pooler_output=pooled_output,
980
+ hidden_states=encoder_outputs.hidden_states,
981
+ attentions=encoder_outputs.attentions,
982
+ )
983
+
984
+
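The pooling above takes the hidden state of the last position, which is only meaningful because SigLIP tokenizes with `padding="max_length"` so that the final token is always EOS (the "sticky EOS" comment). A tiny sketch with hypothetical shapes:

```python
# Standalone sketch (hypothetical tensors): last-token pooling as done by
# SiglipTextTransformer above, assuming max_length padding so position -1 is EOS.
import torch

last_hidden_state = torch.randn(2, 64, 768)   # (batch, max_length, hidden)
pooled = last_hidden_state[:, -1, :]          # (batch, hidden)
print(pooled.shape)  # torch.Size([2, 768])
```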
985
+ @add_start_docstrings(
986
+ """The text model from SigLIP without any head or projection on top.""",
987
+ SIGLIP_START_DOCSTRING,
988
+ )
989
+ class SiglipTextModel(SiglipPreTrainedModel):
990
+ config_class = SiglipTextConfig
991
+
992
+ def __init__(self, config: SiglipTextConfig):
993
+ super().__init__(config)
994
+ self.text_model = SiglipTextTransformer(config)
995
+ # Initialize weights and apply final processing
996
+ self.post_init()
997
+
998
+ def get_input_embeddings(self) -> nn.Module:
999
+ return self.text_model.embeddings.token_embedding
1000
+
1001
+ def set_input_embeddings(self, value):
1002
+ self.text_model.embeddings.token_embedding = value
1003
+
1004
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1005
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1006
+ def forward(
1007
+ self,
1008
+ input_ids: Optional[torch.Tensor] = None,
1009
+ attention_mask: Optional[torch.Tensor] = None,
1010
+ position_ids: Optional[torch.Tensor] = None,
1011
+ output_attentions: Optional[bool] = None,
1012
+ output_hidden_states: Optional[bool] = None,
1013
+ return_dict: Optional[bool] = None,
1014
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1015
+ r"""
1016
+ Returns:
1017
+
1018
+ Examples:
1019
+
1020
+ ```python
1021
+ >>> from transformers import AutoTokenizer, SiglipTextModel
1022
+
1023
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1024
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1025
+
1026
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1027
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1028
+
1029
+ >>> outputs = model(**inputs)
1030
+ >>> last_hidden_state = outputs.last_hidden_state
1031
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
1032
+ ```"""
1033
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034
+
1035
+ return self.text_model(
1036
+ input_ids=input_ids,
1037
+ attention_mask=attention_mask,
1038
+ position_ids=position_ids,
1039
+ output_attentions=output_attentions,
1040
+ output_hidden_states=output_hidden_states,
1041
+ return_dict=return_dict,
1042
+ )
1043
+
1044
+
1045
+ class SiglipVisionTransformer(nn.Module):
1046
+ def __init__(self, config: SiglipVisionConfig):
1047
+ super().__init__()
1048
+ self.config = config
1049
+ embed_dim = config.hidden_size
1050
+
1051
+ self.embeddings = SiglipVisionEmbeddings(config)
1052
+ self.encoder = SiglipEncoder(config)
1053
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1054
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
1055
+ if self.use_head:
1056
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
1057
+
1058
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1059
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1060
+ def forward(
1061
+ self,
1062
+ pixel_values,
1063
+ output_attentions: Optional[bool] = None,
1064
+ output_hidden_states: Optional[bool] = None,
1065
+ return_dict: Optional[bool] = None,
1066
+ interpolate_pos_encoding: Optional[bool] = False,
1067
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1068
+ r"""
1069
+ Returns:
1070
+
1071
+ """
1072
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1073
+ output_hidden_states = (
1074
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1075
+ )
1076
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1077
+
1078
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1079
+
1080
+ encoder_outputs = self.encoder(
1081
+ inputs_embeds=hidden_states,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ return_dict=return_dict,
1085
+ )
1086
+
1087
+ last_hidden_state = encoder_outputs[0]
1088
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1089
+
1090
+ pooler_output = self.head(last_hidden_state) if self.use_head else None
1091
+ if not return_dict:
1092
+ return (last_hidden_state, pooler_output) + encoder_outputs[1:]
1093
+
1094
+ return BaseModelOutputWithPooling(
1095
+ last_hidden_state=last_hidden_state,
1096
+ pooler_output=pooler_output,
1097
+ hidden_states=encoder_outputs.hidden_states,
1098
+ attentions=encoder_outputs.attentions,
1099
+ )
1100
+
1101
+
1102
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1103
+ """Multihead Attention Pooling."""
1104
+
1105
+ def __init__(self, config: SiglipVisionConfig):
1106
+ super().__init__()
1107
+
1108
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1109
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1110
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1111
+ self.mlp = SiglipMLP(config)
1112
+
1113
+ def forward(self, hidden_state):
1114
+ batch_size = hidden_state.shape[0]
1115
+ probe = self.probe.repeat(batch_size, 1, 1)
1116
+
1117
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
1118
+
1119
+ residual = hidden_state
1120
+ hidden_state = self.layernorm(hidden_state)
1121
+ hidden_state = residual + self.mlp(hidden_state)
1122
+
1123
+ return hidden_state[:, 0]
1124
+
1125
+
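The pooling head above uses a single learned probe vector as the attention query over all patch tokens, collapsing the sequence to one vector per image before a residual MLP. A minimal sketch with hypothetical sizes:

```python
# Minimal sketch (hypothetical sizes): a learned "probe" query attends over the patch
# tokens, turning (batch, num_patches, hidden) into one pooled vector per image.
import torch
import torch.nn as nn

hidden, heads = 768, 12
attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
probe = nn.Parameter(torch.randn(1, 1, hidden))

patch_tokens = torch.randn(2, 196, hidden)                  # encoder output
pooled, _ = attn(probe.repeat(2, 1, 1), patch_tokens, patch_tokens)
print(pooled[:, 0].shape)  # torch.Size([2, 768])
```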
1126
+ @add_start_docstrings(
1127
+ """The vision model from SigLIP without any head or projection on top.""",
1128
+ SIGLIP_START_DOCSTRING,
1129
+ )
1130
+ class SiglipVisionModel(SiglipPreTrainedModel):
1131
+ config_class = SiglipVisionConfig
1132
+ main_input_name = "pixel_values"
1133
+
1134
+ def __init__(self, config: SiglipVisionConfig):
1135
+ super().__init__(config)
1136
+
1137
+ self.vision_model = SiglipVisionTransformer(config)
1138
+
1139
+ # Initialize weights and apply final processing
1140
+ self.post_init()
1141
+
1142
+ def get_input_embeddings(self) -> nn.Module:
1143
+ return self.vision_model.embeddings.patch_embedding
1144
+
1145
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1146
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1147
+ def forward(
1148
+ self,
1149
+ pixel_values,
1150
+ output_attentions: Optional[bool] = None,
1151
+ output_hidden_states: Optional[bool] = None,
1152
+ return_dict: Optional[bool] = None,
1153
+ interpolate_pos_encoding: bool = False,
1154
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1155
+ r"""
1156
+ Returns:
1157
+
1158
+ Examples:
1159
+
1160
+ ```python
1161
+ >>> from PIL import Image
1162
+ >>> import requests
1163
+ >>> from transformers import AutoProcessor, SiglipVisionModel
1164
+
1165
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1166
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1167
+
1168
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1169
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1170
+
1171
+ >>> inputs = processor(images=image, return_tensors="pt")
1172
+
1173
+ >>> outputs = model(**inputs)
1174
+ >>> last_hidden_state = outputs.last_hidden_state
1175
+ >>> pooled_output = outputs.pooler_output # pooled features
1176
+ ```"""
1177
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1178
+
1179
+ return self.vision_model(
1180
+ pixel_values=pixel_values,
1181
+ output_attentions=output_attentions,
1182
+ output_hidden_states=output_hidden_states,
1183
+ return_dict=return_dict,
1184
+ interpolate_pos_encoding=interpolate_pos_encoding,
1185
+ )
1186
+
1187
+
1188
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1189
+ class SiglipModel(SiglipPreTrainedModel):
1190
+ config_class = SiglipConfig
1191
+
1192
+ def __init__(self, config: SiglipConfig):
1193
+ super().__init__(config)
1194
+
1195
+ if not isinstance(config.text_config, SiglipTextConfig):
1196
+ raise TypeError(
1197
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1198
+ f" {type(config.text_config)}."
1199
+ )
1200
+
1201
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1202
+ raise TypeError(
1203
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1204
+ f" {type(config.vision_config)}."
1205
+ )
1206
+
1207
+ text_config = config.text_config
1208
+ vision_config = config.vision_config
1209
+
1210
+ # First, initialize the text and vision models with proper attention implementation
1211
+ text_model = SiglipTextModel._from_config(text_config)
1212
+ vision_model = SiglipVisionModel._from_config(vision_config)
1213
+
1214
+ # Second, get the text and vision submodules (for backward compatibility)
1215
+ self.text_model = text_model.text_model
1216
+ self.vision_model = vision_model.vision_model
1217
+
1218
+ self.logit_scale = nn.Parameter(torch.randn(1))
1219
+ self.logit_bias = nn.Parameter(torch.randn(1))
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1225
+ def get_text_features(
1226
+ self,
1227
+ input_ids: Optional[torch.Tensor] = None,
1228
+ attention_mask: Optional[torch.Tensor] = None,
1229
+ position_ids: Optional[torch.Tensor] = None,
1230
+ output_attentions: Optional[bool] = None,
1231
+ output_hidden_states: Optional[bool] = None,
1232
+ return_dict: Optional[bool] = None,
1233
+ ) -> torch.FloatTensor:
1234
+ r"""
1235
+ Returns:
1236
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1237
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1238
+
1239
+ Examples:
1240
+
1241
+ ```python
1242
+ >>> from transformers import AutoTokenizer, AutoModel
1243
+ >>> import torch
1244
+
1245
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1246
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1247
+
1248
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1249
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1250
+ >>> with torch.no_grad():
1251
+ ... text_features = model.get_text_features(**inputs)
1252
+ ```"""
1253
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1254
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1255
+ output_hidden_states = (
1256
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1257
+ )
1258
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1259
+
1260
+ text_outputs = self.text_model(
1261
+ input_ids=input_ids,
1262
+ attention_mask=attention_mask,
1263
+ position_ids=position_ids,
1264
+ output_attentions=output_attentions,
1265
+ output_hidden_states=output_hidden_states,
1266
+ return_dict=return_dict,
1267
+ )
1268
+
1269
+ pooled_output = text_outputs[1]
1270
+
1271
+ return pooled_output
1272
+
1273
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1274
+ def get_image_features(
1275
+ self,
1276
+ pixel_values: Optional[torch.FloatTensor] = None,
1277
+ output_attentions: Optional[bool] = None,
1278
+ output_hidden_states: Optional[bool] = None,
1279
+ return_dict: Optional[bool] = None,
1280
+ interpolate_pos_encoding: bool = False,
1281
+ ) -> torch.FloatTensor:
1282
+ r"""
1283
+ Returns:
1284
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1285
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1286
+
1287
+ Examples:
1288
+
1289
+ ```python
1290
+ >>> from PIL import Image
1291
+ >>> import requests
1292
+ >>> from transformers import AutoProcessor, AutoModel
1293
+ >>> import torch
1294
+
1295
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1296
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1297
+
1298
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1299
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1300
+
1301
+ >>> inputs = processor(images=image, return_tensors="pt")
1302
+
1303
+ >>> with torch.no_grad():
1304
+ ... image_features = model.get_image_features(**inputs)
1305
+ ```"""
1306
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1307
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1308
+ output_hidden_states = (
1309
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1310
+ )
1311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1312
+
1313
+ vision_outputs = self.vision_model(
1314
+ pixel_values=pixel_values,
1315
+ output_attentions=output_attentions,
1316
+ output_hidden_states=output_hidden_states,
1317
+ return_dict=return_dict,
1318
+ interpolate_pos_encoding=interpolate_pos_encoding,
1319
+ )
1320
+
1321
+ pooled_output = vision_outputs[1]
1322
+
1323
+ return pooled_output
1324
+
1325
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1326
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1327
+ def forward(
1328
+ self,
1329
+ input_ids: Optional[torch.LongTensor] = None,
1330
+ pixel_values: Optional[torch.FloatTensor] = None,
1331
+ attention_mask: Optional[torch.Tensor] = None,
1332
+ position_ids: Optional[torch.LongTensor] = None,
1333
+ return_loss: Optional[bool] = None,
1334
+ output_attentions: Optional[bool] = None,
1335
+ output_hidden_states: Optional[bool] = None,
1336
+ return_dict: Optional[bool] = None,
1337
+ interpolate_pos_encoding: bool = False,
1338
+ ) -> Union[Tuple, SiglipOutput]:
1339
+ r"""
1340
+ Returns:
1341
+
1342
+ Examples:
1343
+
1344
+ ```python
1345
+ >>> from PIL import Image
1346
+ >>> import requests
1347
+ >>> from transformers import AutoProcessor, AutoModel
1348
+ >>> import torch
1349
+
1350
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1351
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1352
+
1353
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1354
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1355
+
1356
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1357
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1358
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1359
+
1360
+ >>> with torch.no_grad():
1361
+ ... outputs = model(**inputs)
1362
+
1363
+ >>> logits_per_image = outputs.logits_per_image
1364
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1365
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1366
+ 31.9% that image 0 is 'a photo of 2 cats'
1367
+ ```"""
1368
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1369
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1370
+ output_hidden_states = (
1371
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1372
+ )
1373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1374
+
1375
+ vision_outputs = self.vision_model(
1376
+ pixel_values=pixel_values,
1377
+ output_attentions=output_attentions,
1378
+ output_hidden_states=output_hidden_states,
1379
+ return_dict=return_dict,
1380
+ interpolate_pos_encoding=interpolate_pos_encoding,
1381
+ )
1382
+
1383
+ text_outputs = self.text_model(
1384
+ input_ids=input_ids,
1385
+ attention_mask=attention_mask,
1386
+ position_ids=position_ids,
1387
+ output_attentions=output_attentions,
1388
+ output_hidden_states=output_hidden_states,
1389
+ return_dict=return_dict,
1390
+ )
1391
+
1392
+ image_embeds = vision_outputs[1]
1393
+ text_embeds = text_outputs[1]
1394
+
1395
+ # normalized features
1396
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1397
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1398
+
1399
+ # cosine similarity as logits
1400
+ logits_per_text = (
1401
+ torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
1402
+ + self.logit_bias
1403
+ )
1404
+ logits_per_image = logits_per_text.t()
1405
+
1406
+ loss = None
1407
+ if return_loss:
1408
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
1409
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
1410
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
1411
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
1412
+ nll = -torch.sum(loglik, dim=-1)
1413
+ loss = nll.mean()
1414
+
1415
+ if not return_dict:
1416
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1417
+ return ((loss,) + output) if loss is not None else output
1418
+
1419
+ return SiglipOutput(
1420
+ loss=loss,
1421
+ logits_per_image=logits_per_image,
1422
+ logits_per_text=logits_per_text,
1423
+ text_embeds=text_embeds,
1424
+ image_embeds=image_embeds,
1425
+ text_model_output=text_outputs,
1426
+ vision_model_output=vision_outputs,
1427
+ )
1428
+
1429
+
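The `return_loss` branch above implements SigLIP's pairwise sigmoid loss: matched text/image pairs on the diagonal get label +1, every other pair gets -1, and the loss is the mean over texts of the summed negative log-sigmoid of label times logit. A standalone sketch with hypothetical logits:

```python
# Standalone sketch of the pairwise sigmoid loss computed in SiglipModel.forward above.
# Logits are hypothetical; matched pairs sit on the diagonal.
import torch
import torch.nn.functional as F

logits_per_text = torch.randn(4, 4)               # (num_texts, num_images)
labels = 2 * torch.eye(4) - torch.ones(4, 4)      # +1 on the diagonal, -1 elsewhere
loss = -F.logsigmoid(labels * logits_per_text).sum(dim=-1).mean()
print(loss)
```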
1430
+ @add_start_docstrings(
1431
+ """
1432
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1433
+ the patch tokens) e.g. for ImageNet.
1434
+ """,
1435
+ SIGLIP_START_DOCSTRING,
1436
+ )
1437
+ class SiglipForImageClassification(SiglipPreTrainedModel):
1438
+ main_input_name = "pixel_values"
1439
+
1440
+ def __init__(self, config: SiglipConfig) -> None:
1441
+ super().__init__(config)
1442
+
1443
+ self.num_labels = config.num_labels
1444
+
1445
+ # Create the vision model with proper attention
1446
+ # and take only vision_model submodule (for backward compatibility)
1447
+ vision_model = SiglipVisionModel._from_config(config.vision_config)
1448
+ self.vision_model = vision_model.vision_model
1449
+
1450
+ # Classifier head
1451
+ self.classifier = (
1452
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1453
+ )
1454
+
1455
+ # Initialize weights and apply final processing
1456
+ self.post_init()
1457
+
1458
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1459
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
1460
+ def forward(
1461
+ self,
1462
+ pixel_values: Optional[torch.Tensor] = None,
1463
+ labels: Optional[torch.Tensor] = None,
1464
+ output_attentions: Optional[bool] = None,
1465
+ output_hidden_states: Optional[bool] = None,
1466
+ return_dict: Optional[bool] = None,
1467
+ interpolate_pos_encoding: bool = False,
1468
+ ) -> Union[tuple, ImageClassifierOutput]:
1469
+ r"""
1470
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1471
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1472
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1473
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1474
+
1475
+ Returns:
1476
+
1477
+ Examples:
1478
+
1479
+ ```python
1480
+ >>> from transformers import AutoImageProcessor, SiglipForImageClassification
1481
+ >>> import torch
1482
+ >>> from PIL import Image
1483
+ >>> import requests
1484
+
1485
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
1486
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1487
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1488
+
1489
+ >>> # note: we are loading a `SiglipModel` from the hub here,
1490
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
1491
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
1492
+ >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
1493
+
1494
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1495
+ >>> outputs = model(**inputs)
1496
+ >>> logits = outputs.logits
1497
+ >>> # model predicts one of the two classes
1498
+ >>> predicted_class_idx = logits.argmax(-1).item()
1499
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
1500
+ Predicted class: LABEL_1
1501
+ ```"""
1502
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1503
+ output_hidden_states = (
1504
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1505
+ )
1506
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1507
+
1508
+ outputs = self.vision_model(
1509
+ pixel_values,
1510
+ output_attentions=output_attentions,
1511
+ output_hidden_states=output_hidden_states,
1512
+ return_dict=return_dict,
1513
+ interpolate_pos_encoding=interpolate_pos_encoding,
1514
+ )
1515
+
1516
+ sequence_output = outputs[0]
1517
+
1518
+ # average pool the patch tokens
1519
+ sequence_output = torch.mean(sequence_output, dim=1)
1520
+ # apply classifier
1521
+ logits = self.classifier(sequence_output)
1522
+
1523
+ loss = None
1524
+ if labels is not None:
1525
+ # move labels to correct device to enable model parallelism
1526
+ labels = labels.to(logits.device)
1527
+ if self.config.problem_type is None:
1528
+ if self.num_labels == 1:
1529
+ self.config.problem_type = "regression"
1530
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1531
+ self.config.problem_type = "single_label_classification"
1532
+ else:
1533
+ self.config.problem_type = "multi_label_classification"
1534
+
1535
+ if self.config.problem_type == "regression":
1536
+ loss_fct = MSELoss()
1537
+ if self.num_labels == 1:
1538
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1539
+ else:
1540
+ loss = loss_fct(logits, labels)
1541
+ elif self.config.problem_type == "single_label_classification":
1542
+ loss_fct = CrossEntropyLoss()
1543
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1544
+ elif self.config.problem_type == "multi_label_classification":
1545
+ loss_fct = BCEWithLogitsLoss()
1546
+ loss = loss_fct(logits, labels)
1547
+
1548
+ if not return_dict:
1549
+ output = (logits,) + outputs[2:]
1550
+ return ((loss,) + output) if loss is not None else output
1551
+
1552
+ return ImageClassifierOutput(
1553
+ loss=loss,
1554
+ logits=logits,
1555
+ hidden_states=outputs.hidden_states,
1556
+ attentions=outputs.attentions,
1557
+ )
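The classification head above simply mean-pools the patch tokens (SigLIP's vision tower has no CLS token) and applies a linear layer over the labels. A minimal sketch with hypothetical sizes:

```python
# Minimal sketch (hypothetical sizes) of the SiglipForImageClassification path above:
# mean-pool the patch tokens, then project to num_labels logits.
import torch
import torch.nn as nn

hidden, num_labels = 768, 10
classifier = nn.Linear(hidden, num_labels)

sequence_output = torch.randn(2, 196, hidden)     # vision tower output
logits = classifier(sequence_output.mean(dim=1))
print(logits.shape)  # torch.Size([2, 10])
```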
modeling/siglip/processing_siglip.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Image/Text processor class for SigLIP.
6
+ """
7
+
8
+ from typing import List, Optional, Union
9
+
10
+ from transformers.feature_extraction_utils import BatchFeature
11
+ from transformers.image_utils import ImageInput
12
+ from transformers.processing_utils import ProcessorMixin
13
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
14
+ from transformers.utils import TensorType
15
+
16
+
17
+ class SiglipProcessor(ProcessorMixin):
18
+ r"""
19
+ Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
20
+
21
+ [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
22
+ [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
23
+
24
+ Args:
25
+ image_processor ([`SiglipImageProcessor`]):
26
+ The image processor is a required input.
27
+ tokenizer ([`SiglipTokenizer`]):
28
+ The tokenizer is a required input.
29
+ """
30
+
31
+ attributes = ["image_processor", "tokenizer"]
32
+ image_processor_class = "SiglipImageProcessor"
33
+ tokenizer_class = "SiglipTokenizer"
34
+
35
+ def __init__(self, image_processor, tokenizer):
36
+ super().__init__(image_processor, tokenizer)
37
+
38
+ def __call__(
39
+ self,
40
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
41
+ images: ImageInput = None,
42
+ padding: Union[bool, str, PaddingStrategy] = False,
43
+ truncation: Union[bool, str, TruncationStrategy] = None,
44
+ max_length: int = None,
45
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
46
+ ) -> BatchFeature:
47
+ """
48
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
49
+ and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
50
+ the text. To prepare the image(s), this method forwards the `images` argument to
51
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
52
+ of the above two methods for more information.
53
+
54
+ Args:
55
+ text (`str`, `List[str]`, `List[List[str]]`):
56
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
57
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
58
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
59
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
60
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
61
+ tensor. Both channels-first and channels-last formats are supported.
62
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
63
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
64
+ index) among:
65
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
66
+ sequence is provided).
67
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
68
+ acceptable input length for the model if that argument is not provided.
69
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
70
+ lengths).
71
+ max_length (`int`, *optional*):
72
+ Maximum length of the returned list and optionally padding length (see above).
73
+ truncation (`bool`, *optional*):
74
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
75
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
76
+ If set, will return tensors of a particular framework. Acceptable values are:
77
+
78
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
79
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
80
+ - `'np'`: Return NumPy `np.ndarray` objects.
81
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
82
+
83
+ Returns:
84
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
85
+
86
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
87
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
88
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
89
+ `None`).
90
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
91
+ """
92
+
93
+ if text is None and images is None:
94
+ raise ValueError("You have to specify either text or images. Both cannot be None.")
95
+
96
+ if text is not None:
97
+ encoding = self.tokenizer(
98
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
99
+ )
100
+
101
+ if images is not None:
102
+ image_features = self.image_processor(images, return_tensors=return_tensors)
103
+
104
+ if text is not None and images is not None:
105
+ encoding["pixel_values"] = image_features.pixel_values
106
+ return encoding
107
+ elif text is not None:
108
+ return encoding
109
+ else:
110
+ return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
111
+
112
+ def decode(self, *args, **kwargs):
113
+ """
114
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
115
+ the docstring of this method for more information.
116
+ """
117
+ return self.tokenizer.decode(*args, **kwargs)
118
+
119
+ def batch_decode(self, *args, **kwargs):
120
+ """
121
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
122
+ refer to the docstring of this method for more information.
123
+ """
124
+ return self.tokenizer.batch_decode(*args, **kwargs)
125
+
126
+ @property
127
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
128
+ def model_input_names(self):
129
+ tokenizer_input_names = self.tokenizer.model_input_names
130
+ image_processor_input_names = self.image_processor.model_input_names
131
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
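A minimal usage sketch for the processor above (not part of the committed diff). It assumes a local SentencePiece file `spiece.model`, uses one of the repo's `test_images/`, and needs the sentencepiece and protobuf backends installed; paths and defaults are illustrative assumptions rather than values taken from this commit.

# Hypothetical sketch: wire the vendored image processor and tokenizer into
# SiglipProcessor and encode one caption plus one image in a single call.
from PIL import Image

from modeling.siglip.image_processing_siglip import SiglipImageProcessor
from modeling.siglip.tokenization_siglip import SiglipTokenizer
from modeling.siglip.processing_siglip import SiglipProcessor

image_processor = SiglipImageProcessor()                # default resize/normalize settings (assumed)
tokenizer = SiglipTokenizer(vocab_file="spiece.model")  # assumed local SentencePiece vocab
processor = SiglipProcessor(image_processor, tokenizer)

inputs = processor(
    text=["a photo of a cat"],
    images=Image.open("test_images/meme.jpg").convert("RGB"),
    padding="max_length",
    return_tensors="pt",
)
# __call__ returns a BatchFeature: input_ids from the tokenizer, plus
# pixel_values whenever `images` is not None.
print(inputs.input_ids.shape, inputs.pixel_values.shape)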
modeling/siglip/tokenization_siglip.py ADDED
@@ -0,0 +1,364 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization class for SigLIP model."""
5
+
6
+ import os
7
+ import re
8
+ import string
9
+ import warnings
10
+ from shutil import copyfile
11
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
12
+
13
+ import sentencepiece as spm
14
+
15
+ from transformers.convert_slow_tokenizer import import_protobuf
16
+ from transformers.tokenization_utils import PreTrainedTokenizer
17
+ from transformers.tokenization_utils_base import AddedToken
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.tokenization_utils_base import TextInput
22
+ from transformers.utils import logging, requires_backends
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
28
+
29
+
30
+ SPIECE_UNDERLINE = "▁"
31
+
32
+
33
+ class SiglipTokenizer(PreTrainedTokenizer):
34
+ """
35
+ Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
+
37
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
38
+ this superclass for more information regarding those methods.
39
+
40
+ Args:
41
+ vocab_file (`str`):
42
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
43
+ contains the vocabulary necessary to instantiate a tokenizer.
44
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
45
+ The end of sequence token.
46
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ pad_token (`str`, *optional*, defaults to `"</s>"`):
50
+ The token used for padding, for example when batching sequences of different lengths.
51
+ additional_special_tokens (`List[str]`, *optional*):
52
+ Additional special tokens used by the tokenizer.
53
+ sp_model_kwargs (`dict`, *optional*):
54
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
55
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
56
+ to set:
57
+
58
+ - `enable_sampling`: Enable subword regularization.
59
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
60
+
61
+ - `nbest_size = {0,1}`: No sampling is performed.
62
+ - `nbest_size > 1`: samples from the nbest_size results.
63
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
64
+ using forward-filtering-and-backward-sampling algorithm.
65
+
66
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
67
+ BPE-dropout.
68
+ model_max_length (`int`, *optional*, defaults to 64):
69
+ The maximum length (in number of tokens) for model inputs.
70
+ do_lower_case (`bool`, *optional*, defaults to `True`):
71
+ Whether or not to lowercase the input when tokenizing.
72
+ """
73
+
74
+ vocab_files_names = VOCAB_FILES_NAMES
75
+ model_input_names = ["input_ids", "attention_mask"]
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_file,
80
+ eos_token="</s>",
81
+ unk_token="<unk>",
82
+ pad_token="</s>",
83
+ additional_special_tokens=None,
84
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
85
+ model_max_length=64,
86
+ do_lower_case=True,
87
+ **kwargs,
88
+ ) -> None:
89
+ requires_backends(self, "protobuf")
90
+
91
+ pad_token = (
92
+ AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
93
+ if isinstance(pad_token, str)
94
+ else pad_token
95
+ )
96
+ unk_token = (
97
+ AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
98
+ if isinstance(unk_token, str)
99
+ else unk_token
100
+ )
101
+ eos_token = (
102
+ AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
103
+ if isinstance(eos_token, str)
104
+ else eos_token
105
+ )
106
+
107
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
108
+
109
+ self.do_lower_case = do_lower_case
110
+ self.vocab_file = vocab_file
111
+
112
+ self.sp_model = self.get_spm_processor()
113
+ self.vocab_file = vocab_file
114
+
115
+ super().__init__(
116
+ eos_token=eos_token,
117
+ unk_token=unk_token,
118
+ pad_token=pad_token,
119
+ additional_special_tokens=additional_special_tokens,
120
+ sp_model_kwargs=self.sp_model_kwargs,
121
+ model_max_length=model_max_length,
122
+ do_lower_case=do_lower_case,
123
+ **kwargs,
124
+ )
125
+
126
+ def get_spm_processor(self):
127
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
128
+ with open(self.vocab_file, "rb") as f:
129
+ sp_model = f.read()
130
+ model_pb2 = import_protobuf()
131
+ model = model_pb2.ModelProto.FromString(sp_model)
132
+ normalizer_spec = model_pb2.NormalizerSpec()
133
+ normalizer_spec.add_dummy_prefix = False
134
+ model.normalizer_spec.MergeFrom(normalizer_spec)
135
+ sp_model = model.SerializeToString()
136
+ tokenizer.LoadFromSerializedProto(sp_model)
137
+ return tokenizer
138
+
139
+ @property
140
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
141
+ def vocab_size(self):
142
+ return self.sp_model.get_piece_size()
143
+
144
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
145
+ def get_vocab(self):
146
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
147
+ vocab.update(self.added_tokens_encoder)
148
+ return vocab
149
+
150
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
151
+ def get_special_tokens_mask(
152
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
153
+ ) -> List[int]:
154
+ """
155
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
156
+ special tokens using the tokenizer `prepare_for_model` method.
157
+
158
+ Args:
159
+ token_ids_0 (`List[int]`):
160
+ List of IDs.
161
+ token_ids_1 (`List[int]`, *optional*):
162
+ Optional second list of IDs for sequence pairs.
163
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
164
+ Whether or not the token list is already formatted with special tokens for the model.
165
+
166
+ Returns:
167
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
168
+ """
169
+ if already_has_special_tokens:
170
+ return super().get_special_tokens_mask(
171
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
172
+ )
173
+
174
+ # normal case: some special tokens
175
+ if token_ids_1 is None:
176
+ return ([0] * len(token_ids_0)) + [1]
177
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
178
+
179
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
180
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
181
+ """Do not add eos again if user already added it."""
182
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
183
+ warnings.warn(
184
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
185
+ " eos tokens being added."
186
+ )
187
+ return token_ids
188
+ else:
189
+ return token_ids + [self.eos_token_id]
190
+
191
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
192
+ def create_token_type_ids_from_sequences(
193
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
194
+ ) -> List[int]:
195
+ """
196
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
197
+ use of token type ids, therefore a list of zeros is returned.
198
+
199
+ Args:
200
+ token_ids_0 (`List[int]`):
201
+ List of IDs.
202
+ token_ids_1 (`List[int]`, *optional*):
203
+ Optional second list of IDs for sequence pairs.
204
+
205
+ Returns:
206
+ `List[int]`: List of zeros.
207
+ """
208
+ eos = [self.eos_token_id]
209
+
210
+ if token_ids_1 is None:
211
+ return len(token_ids_0 + eos) * [0]
212
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
213
+
214
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
215
+ def build_inputs_with_special_tokens(
216
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
+ ) -> List[int]:
218
+ """
219
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
220
+ adding special tokens. A sequence has the following format:
221
+
222
+ - single sequence: `X </s>`
223
+ - pair of sequences: `A </s> B </s>`
224
+
225
+ Args:
226
+ token_ids_0 (`List[int]`):
227
+ List of IDs to which the special tokens will be added.
228
+ token_ids_1 (`List[int]`, *optional*):
229
+ Optional second list of IDs for sequence pairs.
230
+
231
+ Returns:
232
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
233
+ """
234
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
235
+ if token_ids_1 is None:
236
+ return token_ids_0
237
+ else:
238
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
239
+ return token_ids_0 + token_ids_1
240
+
241
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
242
+ def __getstate__(self):
243
+ state = self.__dict__.copy()
244
+ state["sp_model"] = None
245
+ return state
246
+
247
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
248
+ def __setstate__(self, d):
249
+ self.__dict__ = d
250
+
251
+ # for backward compatibility
252
+ if not hasattr(self, "sp_model_kwargs"):
253
+ self.sp_model_kwargs = {}
254
+
255
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
256
+ self.sp_model.Load(self.vocab_file)
257
+
258
+ def remove_punctuation(self, text: str) -> str:
259
+ return text.translate(str.maketrans("", "", string.punctuation))
260
+
261
+ # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
262
+ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
263
+ """Returns canonicalized `text` (puncuation removed).
264
+
265
+ Args:
266
+ text (`str`):
267
+ String to be canonicalized.
268
+ keep_punctuation_exact_string (`str`, *optional*):
269
+ If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
270
+ (but will still remove '{' and '}' that appear separately).
271
+ """
272
+ if keep_punctuation_exact_string:
273
+ text = keep_punctuation_exact_string.join(
274
+ self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
275
+ )
276
+ else:
277
+ text = self.remove_punctuation(text)
278
+ text = re.sub(r"\s+", " ", text)
279
+ text = text.strip()
280
+
281
+ return text
282
+
283
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
284
+ """
285
+ Converts a string to a list of tokens.
286
+ """
287
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
288
+
289
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
290
+ tokens = tokens[1:]
291
+ return tokens
292
+
293
+ @property
294
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
295
+ def unk_token_length(self):
296
+ return len(self.sp_model.encode(str(self.unk_token)))
297
+
298
+ def _tokenize(self, text, **kwargs):
299
+ """
300
+ Returns a tokenized string.
301
+
302
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
303
+ SPIECE_UNDERLINE.
304
+
305
+ For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
306
+
307
+ Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
308
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
309
+ """
310
+ text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
311
+ tokens = self.sp_model.encode(text, out_type=str)
312
+
313
+ # 1. Encode string + prefix ex: "<unk> Hey"
314
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
315
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
316
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
317
+
318
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
319
+ def _convert_token_to_id(self, token):
320
+ """Converts a token (str) in an id using the vocab."""
321
+ return self.sp_model.piece_to_id(token)
322
+
323
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
324
+ def _convert_id_to_token(self, index):
325
+ """Converts an index (integer) in a token (str) using the vocab."""
326
+ token = self.sp_model.IdToPiece(index)
327
+ return token
328
+
329
+ def convert_tokens_to_string(self, tokens):
330
+ """Converts a sequence of tokens (string) in a single string."""
331
+ current_sub_tokens = []
332
+ out_string = ""
333
+ prev_is_special = False
334
+ for token in tokens:
335
+ # make sure that special tokens are not decoded using sentencepiece model
336
+ if token in self.all_special_tokens:
337
+ if not prev_is_special:
338
+ out_string += " "
339
+ out_string += self.sp_model.decode(current_sub_tokens) + token
340
+ prev_is_special = True
341
+ current_sub_tokens = []
342
+ else:
343
+ current_sub_tokens.append(token)
344
+ prev_is_special = False
345
+ out_string += self.sp_model.decode(current_sub_tokens)
346
+ return out_string.strip()
347
+
348
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
349
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
350
+ if not os.path.isdir(save_directory):
351
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
352
+ return
353
+ out_vocab_file = os.path.join(
354
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
+ )
356
+
357
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
358
+ copyfile(self.vocab_file, out_vocab_file)
359
+ elif not os.path.isfile(self.vocab_file):
360
+ with open(out_vocab_file, "wb") as fi:
361
+ content_spiece_model = self.sp_model.serialized_model_proto()
362
+ fi.write(content_spiece_model)
363
+
364
+ return (out_vocab_file,)
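A short, illustrative sketch of the tokenizer's preprocessing path documented above (not part of the diff; `spiece.model` is an assumed local SentencePiece file, the sentencepiece/protobuf backends are required, and the exact pieces depend on that vocabulary).

# Hypothetical sketch: canonicalize_text strips punctuation and collapses
# whitespace before SentencePiece encoding; a plain call then appends </s>
# via build_inputs_with_special_tokens.
from modeling.siglip.tokenization_siglip import SiglipTokenizer

tokenizer = SiglipTokenizer(vocab_file="spiece.model")  # assumed path

print(tokenizer.canonicalize_text("A photo, of a cat!!"))
# -> "A photo of a cat"  (punctuation removed, whitespace collapsed)
print(tokenizer.canonicalize_text("keep {} here", keep_punctuation_exact_string="{}"))
# -> "keep {} here"      (the exact string "{}" survives)

ids = tokenizer("a photo of a cat")["input_ids"]
assert ids[-1] == tokenizer.eos_token_id  # single sequences are encoded as `X </s>`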
run.err ADDED
@@ -0,0 +1,150 @@
1
+ W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793]
2
+ W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W1025 21:14:01.211000 2808260 site-packages/torch/distributed/run.py:793] *****************************************
5
+ wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id h200-zebra-cot-20251025_211359-run0.
6
+ [rank2]:[W1025 21:14:15.369652204 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
7
+ [rank7]:[W1025 21:14:15.502472578 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 7] using GPU 7 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
8
+ [rank5]:[W1025 21:14:15.521361526 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 5] using GPU 5 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
9
+ [rank4]:[W1025 21:14:15.539230512 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 4] using GPU 4 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
10
+ [rank1]:[W1025 21:14:15.559660446 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
11
+ [rank3]:[W1025 21:14:15.636618409 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
12
+ [rank6]:[W1025 21:14:15.814060558 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 6] using GPU 6 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
13
+ wandb: Tracking run with wandb version 0.22.2
14
+ wandb: W&B syncing is set to `offline` in this directory. Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
15
+ wandb: Run data is saved locally in /scratch/by2593/Bagel-Zebra-CoT-origin/wandb/offline-run-20251025_211414-h200-zebra-cot-20251025_211359-run0
16
+ wandb: Detected [huggingface_hub.inference] in use.
17
+ wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
18
+ wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/
19
+ [rank0]:[W1025 21:14:16.181889866 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
20
+ [2025-10-25 21:14:20] Training arguments TrainingArguments(visual_gen=True, visual_und=True, results_dir='results/', checkpoint_dir='results/checkpoints_smm_semantic_part1_v1_origin/', wandb_project='zebra-cot', wandb_name='h200-zebra-cot-20251025_211359', wandb_runid='0', wandb_resume='allow', wandb_offline=True, global_seed=4396, auto_resume=True, resume_from='/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8', resume_model_only=True, finetune_from_ema=False, finetune_from_hf=True, log_every=1, save_every=50, total_steps=5000, warmup_steps=50, lr_scheduler='cosine', lr=2e-05, min_lr=1e-06, beta1=0.9, beta2=0.95, eps=1e-08, ema=0.9999, max_grad_norm=1.0, timestep_shift=1.0, mse_weight=1.0, ce_weight=1.0, ce_loss_reweighting=False, expected_num_tokens=40000, num_replicate=1, num_shard=8, sharding_strategy='HYBRID_SHARD', backward_prefetch='BACKWARD_PRE', cpu_offload=True, freeze_llm=False, freeze_vit=False, freeze_vae=True, freeze_und=False, copy_init_moe=True, use_flex=False)
21
+ [2025-10-25 21:14:20] Model arguments ModelArguments(model_path='/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8', llm_path='hf/Qwen2.5-0.5B-Instruct/', llm_qk_norm=True, tie_word_embeddings=False, layer_module='Qwen2MoTDecoderLayer', vae_path='flux/vae/ae.safetensors', vit_path='hf/siglip-so400m-14-980-flash-attn2-navit/', max_latent_size=64, latent_patch_size=2, vit_patch_size=14, vit_max_num_patch_per_side=70, connector_act='gelu_pytorch_tanh', interpolate_pos=False, vit_select_layer=-2, vit_rope=False, text_cond_dropout_prob=0.1, vae_cond_dropout_prob=0.3, vit_cond_dropout_prob=0.3)
22
+ [2025-10-25 21:14:20] Data arguments DataArguments(dataset_config_file='./data/configs/example_smm_semantic.yaml', prefetch_factor=2, num_workers=1, max_num_tokens_per_sample=40000, max_num_tokens=40000, prefer_buffer_before=10000, max_buffer_size=50, data_seed=42)
23
+ [2025-10-25 21:16:50] Loading checkpoint from /scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8.
24
+ [2025-10-25 21:18:10] _IncompatibleKeys(missing_keys=['latent_pos_embed.pos_embed', 'vit_pos_embed.pos_embed'], unexpected_keys=[])
25
+ [2025-10-25 21:18:10] replicaing ema model from /scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8/model_bf16.safetensors.
26
+ [2025-10-25 21:18:20] _IncompatibleKeys(missing_keys=['latent_pos_embed.pos_embed', 'vit_pos_embed.pos_embed'], unexpected_keys=[])
27
+ [2025-10-25 21:18:51] Training for 5000 steps, starting at 0...
28
+ [2025-10-25 21:20:20] (step=0000000) Train Loss mse: 0.0185, Train Loss ce: 1.8625, Train Steps/Sec: 0.01,
29
+ [2025-10-25 21:20:57] (step=0000001) Train Loss mse: 0.0168, Train Loss ce: 1.8560, Train Steps/Sec: 0.03,
30
+ [2025-10-25 21:21:32] (step=0000002) Train Loss mse: 0.0208, Train Loss ce: 1.8139, Train Steps/Sec: 0.03,
31
+ [2025-10-25 21:22:13] (step=0000003) Train Loss mse: 0.0200, Train Loss ce: 1.6772, Train Steps/Sec: 0.02,
32
+ [2025-10-25 21:22:49] (step=0000004) Train Loss mse: 0.0164, Train Loss ce: 1.7684, Train Steps/Sec: 0.03,
33
+ [2025-10-25 21:23:31] (step=0000005) Train Loss mse: 0.0199, Train Loss ce: 1.8439, Train Steps/Sec: 0.02,
34
+ [2025-10-25 21:24:04] (step=0000006) Train Loss mse: 0.0166, Train Loss ce: 1.6152, Train Steps/Sec: 0.03,
35
+ [2025-10-25 21:24:40] (step=0000007) Train Loss mse: 0.0181, Train Loss ce: 1.7539, Train Steps/Sec: 0.03,
36
+ [2025-10-25 21:25:15] (step=0000008) Train Loss mse: 0.0164, Train Loss ce: 1.7400, Train Steps/Sec: 0.03,
37
+ [2025-10-25 21:25:49] (step=0000009) Train Loss mse: 0.0167, Train Loss ce: 1.8076, Train Steps/Sec: 0.03,
38
+ [2025-10-25 21:26:25] (step=0000010) Train Loss mse: 0.0233, Train Loss ce: 1.4616, Train Steps/Sec: 0.03,
39
+ [2025-10-25 21:26:56] (step=0000011) Train Loss mse: 0.0168, Train Loss ce: 1.6259, Train Steps/Sec: 0.03,
40
+ [2025-10-25 21:27:37] (step=0000012) Train Loss mse: 0.0170, Train Loss ce: 1.5824, Train Steps/Sec: 0.02,
41
+ [2025-10-25 21:28:08] (step=0000013) Train Loss mse: 0.0189, Train Loss ce: 1.5811, Train Steps/Sec: 0.03,
42
+ [2025-10-25 21:28:42] (step=0000014) Train Loss mse: 0.0221, Train Loss ce: 1.2260, Train Steps/Sec: 0.03,
43
+ [2025-10-25 21:29:16] (step=0000015) Train Loss mse: 0.0140, Train Loss ce: 1.1394, Train Steps/Sec: 0.03,
44
+ [2025-10-25 21:29:49] (step=0000016) Train Loss mse: 0.0163, Train Loss ce: 1.1381, Train Steps/Sec: 0.03,
45
+ [2025-10-25 21:30:26] (step=0000017) Train Loss mse: 0.0229, Train Loss ce: 1.0493, Train Steps/Sec: 0.03,
46
+ [2025-10-25 21:31:02] (step=0000018) Train Loss mse: 0.0169, Train Loss ce: 1.0484, Train Steps/Sec: 0.03,
47
+ [2025-10-25 21:31:43] (step=0000019) Train Loss mse: 0.0187, Train Loss ce: 0.5945, Train Steps/Sec: 0.02,
48
+ [2025-10-25 21:32:19] (step=0000020) Train Loss mse: 0.0158, Train Loss ce: 0.6128, Train Steps/Sec: 0.03,
49
+ [2025-10-25 21:33:00] (step=0000021) Train Loss mse: 0.0157, Train Loss ce: 0.4668, Train Steps/Sec: 0.02,
50
+ [2025-10-25 21:33:33] (step=0000022) Train Loss mse: 0.0181, Train Loss ce: 0.4042, Train Steps/Sec: 0.03,
51
+ [2025-10-25 21:34:07] (step=0000023) Train Loss mse: 0.0209, Train Loss ce: 0.2930, Train Steps/Sec: 0.03,
52
+ [2025-10-25 21:34:40] (step=0000024) Train Loss mse: 0.0190, Train Loss ce: 0.2934, Train Steps/Sec: 0.03,
53
+ [2025-10-25 21:35:16] (step=0000025) Train Loss mse: 0.0144, Train Loss ce: 0.2189, Train Steps/Sec: 0.03,
54
+ [2025-10-25 21:35:49] (step=0000026) Train Loss mse: 0.0185, Train Loss ce: 0.1414, Train Steps/Sec: 0.03,
55
+ [2025-10-25 21:36:22] (step=0000027) Train Loss mse: 0.0166, Train Loss ce: 0.1090, Train Steps/Sec: 0.03,
56
+ [2025-10-25 21:36:59] (step=0000028) Train Loss mse: 0.0202, Train Loss ce: 0.1350, Train Steps/Sec: 0.03,
57
+ [2025-10-25 21:37:36] (step=0000029) Train Loss mse: 0.0175, Train Loss ce: 0.1263, Train Steps/Sec: 0.03,
58
+ [2025-10-25 21:38:11] (step=0000030) Train Loss mse: 0.0165, Train Loss ce: 0.0860, Train Steps/Sec: 0.03,
59
+ [2025-10-25 21:38:47] (step=0000031) Train Loss mse: 0.0169, Train Loss ce: 0.0864, Train Steps/Sec: 0.03,
60
+ [2025-10-25 21:39:20] (step=0000032) Train Loss mse: 0.0218, Train Loss ce: 0.0792, Train Steps/Sec: 0.03,
61
+ [2025-10-25 21:39:57] (step=0000033) Train Loss mse: 0.0203, Train Loss ce: 0.0852, Train Steps/Sec: 0.03,
62
+ [2025-10-25 21:40:30] (step=0000034) Train Loss mse: 0.0200, Train Loss ce: 0.0734, Train Steps/Sec: 0.03,
63
+ [2025-10-25 21:41:07] (step=0000035) Train Loss mse: 0.0166, Train Loss ce: 0.0830, Train Steps/Sec: 0.03,
64
+ [2025-10-25 21:41:42] (step=0000036) Train Loss mse: 0.0167, Train Loss ce: 0.0776, Train Steps/Sec: 0.03,
65
+ [2025-10-25 21:42:14] (step=0000037) Train Loss mse: 0.0175, Train Loss ce: 0.0556, Train Steps/Sec: 0.03,
66
+ [2025-10-25 21:42:51] (step=0000038) Train Loss mse: 0.0176, Train Loss ce: 0.0520, Train Steps/Sec: 0.03,
67
+ [2025-10-25 21:43:23] (step=0000039) Train Loss mse: 0.0144, Train Loss ce: 0.0607, Train Steps/Sec: 0.03,
68
+ [2025-10-25 21:43:59] (step=0000040) Train Loss mse: 0.0151, Train Loss ce: 0.0683, Train Steps/Sec: 0.03,
69
+ [2025-10-25 21:44:32] (step=0000041) Train Loss mse: 0.0180, Train Loss ce: 0.0456, Train Steps/Sec: 0.03,
70
+ [2025-10-25 21:45:08] (step=0000042) Train Loss mse: 0.0157, Train Loss ce: 0.0620, Train Steps/Sec: 0.03,
71
+ [2025-10-25 21:45:51] (step=0000043) Train Loss mse: 0.0167, Train Loss ce: 0.0552, Train Steps/Sec: 0.02,
72
+ [2025-10-25 21:46:28] (step=0000044) Train Loss mse: 0.0143, Train Loss ce: 0.0522, Train Steps/Sec: 0.03,
73
+ [2025-10-25 21:47:08] (step=0000045) Train Loss mse: 0.0159, Train Loss ce: 0.0494, Train Steps/Sec: 0.02,
74
+ [2025-10-25 21:47:41] (step=0000046) Train Loss mse: 0.0160, Train Loss ce: 0.0484, Train Steps/Sec: 0.03,
75
+ [2025-10-25 21:48:14] (step=0000047) Train Loss mse: 0.0187, Train Loss ce: 0.0599, Train Steps/Sec: 0.03,
76
+ [2025-10-25 21:48:52] (step=0000048) Train Loss mse: 0.0173, Train Loss ce: 0.0629, Train Steps/Sec: 0.03,
77
+ [2025-10-25 21:49:26] (step=0000049) Train Loss mse: 0.0167, Train Loss ce: 0.0466, Train Steps/Sec: 0.03,
78
+ [2025-10-25 21:50:00] (step=0000050) Train Loss mse: 0.0150, Train Loss ce: 0.0540, Train Steps/Sec: 0.03,
79
+ [2025-10-25 21:50:01] Saving checkpoint to results/checkpoints_smm_semantic_part1_v1_origin/0000050.
80
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
81
+ warnings.warn(
82
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
83
+ warnings.warn(
84
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
85
+ warnings.warn(
86
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
87
+ warnings.warn(
88
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
89
+ warnings.warn(
90
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
91
+ warnings.warn(
92
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
93
+ warnings.warn(
94
+ /scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
95
+ warnings.warn(
96
+ [2025-10-25 21:55:05] Sorted checkpoint directories: ['0000050']
97
+ [2025-10-25 21:55:40] (step=0000051) Train Loss mse: 0.0139, Train Loss ce: 0.0539, Train Steps/Sec: 0.00,
98
+ [2025-10-25 21:56:13] (step=0000052) Train Loss mse: 0.0176, Train Loss ce: 0.0495, Train Steps/Sec: 0.03,
99
+ [2025-10-25 21:56:51] (step=0000053) Train Loss mse: 0.0168, Train Loss ce: 0.0485, Train Steps/Sec: 0.03,
100
+ [2025-10-25 21:57:23] (step=0000054) Train Loss mse: 0.0151, Train Loss ce: 0.0446, Train Steps/Sec: 0.03,
101
+ [2025-10-25 21:58:00] (step=0000055) Train Loss mse: 0.0144, Train Loss ce: 0.0490, Train Steps/Sec: 0.03,
102
+ [2025-10-25 21:58:37] (step=0000056) Train Loss mse: 0.0143, Train Loss ce: 0.0461, Train Steps/Sec: 0.03,
103
+ [2025-10-25 21:59:11] (step=0000057) Train Loss mse: 0.0152, Train Loss ce: 0.0459, Train Steps/Sec: 0.03,
104
+ [2025-10-25 21:59:48] (step=0000058) Train Loss mse: 0.0152, Train Loss ce: 0.0402, Train Steps/Sec: 0.03,
105
+ [2025-10-25 22:00:22] (step=0000059) Train Loss mse: 0.0145, Train Loss ce: 0.0566, Train Steps/Sec: 0.03,
106
+ [2025-10-25 22:00:59] (step=0000060) Train Loss mse: 0.0174, Train Loss ce: 0.0509, Train Steps/Sec: 0.03,
107
+ [rank6]: Traceback (most recent call last):
108
+ [rank6]: File "/scratch/by2593/Bagel-Zebra-CoT-origin/train/pretrain_unified_navit.py", line 727, in <module>
109
+ [rank6]: main()
110
+ [rank6]: File "/scratch/by2593/Bagel-Zebra-CoT-origin/train/pretrain_unified_navit.py", line 609, in main
111
+ [rank6]: assert not training_args.visual_und
112
+ [rank6]: AssertionError
113
+ [rank6]:[W1025 22:01:04.973896433 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
114
+ W1025 22:01:11.227000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808294 closing signal SIGTERM
115
+ W1025 22:01:11.264000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808295 closing signal SIGTERM
116
+ W1025 22:01:11.265000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808296 closing signal SIGTERM
117
+ W1025 22:01:11.271000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808297 closing signal SIGTERM
118
+ W1025 22:01:11.314000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808298 closing signal SIGTERM
119
+ W1025 22:01:11.332000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808299 closing signal SIGTERM
120
+ W1025 22:01:11.357000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 2808301 closing signal SIGTERM
121
+ E1025 22:01:37.654000 2808260 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 6 (pid: 2808300) of binary: /scratch/by2593/miniconda3/envs/bagel/bin/python3.10
122
+ Traceback (most recent call last):
123
+ File "/scratch/by2593/miniconda3/envs/bagel/bin/torchrun", line 7, in <module>
124
+ sys.exit(main())
125
+ File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
126
+ return f(*args, **kwargs)
127
+ File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
128
+ run(args)
129
+ File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
130
+ elastic_launch(
131
+ File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
132
+ return launch_agent(self._config, self._entrypoint, list(args))
133
+ File "/scratch/by2593/miniconda3/envs/bagel/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
134
+ raise ChildFailedError(
135
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
136
+ ============================================================
137
+ train/pretrain_unified_navit.py FAILED
138
+ ------------------------------------------------------------
139
+ Failures:
140
+ <NO_OTHER_FAILURES>
141
+ ------------------------------------------------------------
142
+ Root Cause (first observed failure):
143
+ [0]:
144
+ time : 2025-10-25_22:01:11
145
+ host : gh129.hpc.nyu.edu
146
+ rank : 6 (local_rank: 6)
147
+ exitcode : 1 (pid: 2808300)
148
+ error_file: <N/A>
149
+ traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
150
+ ============================================================
run.out ADDED
@@ -0,0 +1,871 @@
1
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
2
+ rank-3 worker-0 dataset-block_dataset: resuming data at row#0
3
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
4
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
5
+ rank-6 worker-0 dataset-block_dataset: resuming data at row#0
6
+ rank-4 worker-0 dataset-block_dataset: resuming data at row#0
7
+ FullyShardedDataParallel(
8
+ (_fsdp_wrapped_module): Bagel(
9
+ (language_model): Qwen2ForCausalLM(
10
+ (model): Qwen2Model(
11
+ (embed_tokens): Embedding(152064, 3584)
12
+ (layers): ModuleList(
13
+ (0-27): 28 x FullyShardedDataParallel(
14
+ (_fsdp_wrapped_module): CheckpointWrapper(
15
+ (_checkpoint_wrapped_module): Qwen2MoTDecoderLayer(
16
+ (self_attn): PackedAttentionMoT(
17
+ (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
18
+ (k_proj): Linear(in_features=3584, out_features=512, bias=True)
19
+ (v_proj): Linear(in_features=3584, out_features=512, bias=True)
20
+ (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
21
+ (q_norm): Qwen2RMSNorm((128,), eps=1e-06)
22
+ (k_norm): Qwen2RMSNorm((128,), eps=1e-06)
23
+ (q_norm_moe_gen): Qwen2RMSNorm((128,), eps=1e-06)
24
+ (k_norm_moe_gen): Qwen2RMSNorm((128,), eps=1e-06)
25
+ (q_proj_moe_gen): Linear(in_features=3584, out_features=3584, bias=True)
26
+ (k_proj_moe_gen): Linear(in_features=3584, out_features=512, bias=True)
27
+ (v_proj_moe_gen): Linear(in_features=3584, out_features=512, bias=True)
28
+ (o_proj_moe_gen): Linear(in_features=3584, out_features=3584, bias=False)
29
+ )
30
+ (mlp): Qwen2MLP(
31
+ (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
32
+ (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
33
+ (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
34
+ (act_fn): SiLU()
35
+ )
36
+ (mlp_moe_gen): Qwen2MLP(
37
+ (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
38
+ (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
39
+ (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
40
+ (act_fn): SiLU()
41
+ )
42
+ (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
43
+ (input_layernorm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
44
+ (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
45
+ (post_attention_layernorm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
46
+ )
47
+ )
48
+ )
49
+ )
50
+ (norm): Qwen2RMSNorm((3584,), eps=1e-06)
51
+ (norm_moe_gen): Qwen2RMSNorm((3584,), eps=1e-06)
52
+ (rotary_emb): Qwen2RotaryEmbedding()
53
+ )
54
+ (lm_head): Linear(in_features=3584, out_features=152064, bias=False)
55
+ )
56
+ (time_embedder): FullyShardedDataParallel(
57
+ (_fsdp_wrapped_module): TimestepEmbedder(
58
+ (mlp): Sequential(
59
+ (0): Linear(in_features=256, out_features=3584, bias=True)
60
+ (1): SiLU()
61
+ (2): Linear(in_features=3584, out_features=3584, bias=True)
62
+ )
63
+ )
64
+ )
65
+ (vae2llm): Linear(in_features=64, out_features=3584, bias=True)
66
+ (llm2vae): Linear(in_features=3584, out_features=64, bias=True)
67
+ (latent_pos_embed): FullyShardedDataParallel(
68
+ (_fsdp_wrapped_module): PositionEmbedding()
69
+ )
70
+ (vit_model): SiglipVisionModel(
71
+ (vision_model): FullyShardedDataParallel(
72
+ (_fsdp_wrapped_module): SiglipVisionTransformer(
73
+ (embeddings): SiglipVisionEmbeddings(
74
+ (position_embedding): Embedding(4900, 1152)
75
+ (patch_embedding): Linear(in_features=588, out_features=1152, bias=True)
76
+ )
77
+ (encoder): SiglipEncoder(
78
+ (layers): ModuleList(
79
+ (0-25): 26 x FullyShardedDataParallel(
80
+ (_fsdp_wrapped_module): CheckpointWrapper(
81
+ (_checkpoint_wrapped_module): SiglipEncoderLayer(
82
+ (self_attn): SiglipFlashAttention2(
83
+ (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
84
+ (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
85
+ (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
86
+ (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
87
+ )
88
+ (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
89
+ (mlp): SiglipMLP(
90
+ (activation_fn): PytorchGELUTanh()
91
+ (fc1): Linear(in_features=1152, out_features=4304, bias=True)
92
+ (fc2): Linear(in_features=4304, out_features=1152, bias=True)
93
+ )
94
+ (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
95
+ )
96
+ )
97
+ )
98
+ )
99
+ )
100
+ (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
101
+ )
102
+ )
103
+ )
104
+ (connector): FullyShardedDataParallel(
105
+ (_fsdp_wrapped_module): CheckpointWrapper(
106
+ (_checkpoint_wrapped_module): MLPconnector(
107
+ (activation_fn): PytorchGELUTanh()
108
+ (fc1): Linear(in_features=1152, out_features=3584, bias=True)
109
+ (fc2): Linear(in_features=3584, out_features=3584, bias=True)
110
+ )
111
+ )
112
+ )
113
+ (vit_pos_embed): FullyShardedDataParallel(
114
+ (_fsdp_wrapped_module): PositionEmbedding()
115
+ )
116
+ )
117
+ )
118
+ _flat_param True
119
+ language_model.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
120
+ language_model.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
121
+ language_model.model.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
122
+ language_model.model.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
123
+ language_model.model.layers.4._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
124
+ language_model.model.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
125
+ language_model.model.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
126
+ language_model.model.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
127
+ language_model.model.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
128
+ language_model.model.layers.9._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
129
+ language_model.model.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
130
+ language_model.model.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
131
+ language_model.model.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
132
+ language_model.model.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
133
+ language_model.model.layers.14._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
134
+ language_model.model.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
135
+ language_model.model.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
136
+ language_model.model.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
137
+ language_model.model.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
138
+ language_model.model.layers.19._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
139
+ language_model.model.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
140
+ language_model.model.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
141
+ language_model.model.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
142
+ language_model.model.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
143
+ language_model.model.layers.24._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
144
+ language_model.model.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
145
+ language_model.model.layers.26._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
146
+ language_model.model.layers.27._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
147
+ time_embedder._fsdp_wrapped_module._flat_param True
148
+ latent_pos_embed._fsdp_wrapped_module._flat_param False
149
+ vit_model.vision_model._fsdp_wrapped_module._flat_param True
150
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
151
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
152
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
153
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
154
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.4._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
155
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
156
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
157
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
158
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
159
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.9._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
160
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
161
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
162
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
163
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
164
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.14._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
165
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
166
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
167
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
168
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
169
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.19._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
170
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
171
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
172
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
173
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
174
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.24._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
175
+ vit_model.vision_model._fsdp_wrapped_module.encoder.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
176
+ connector._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param True
177
+ vit_pos_embed._fsdp_wrapped_module._flat_param False
178
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
179
+ Preparing Dataset block_dataset/block_dataset
180
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
181
+ rank-0 worker-0 dataset-block_dataset: resuming data at row#0
182
+ rank-7 worker-0 dataset-block_dataset: resuming data at row#0
183
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
184
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
185
+ rank-2 worker-0 dataset-block_dataset: resuming data at row#0
186
+ rank-5 worker-0 dataset-block_dataset: resuming data at row#0
187
+ {'block_dataset': {'dataset_names': ['block_dataset'], 'jsonl_path_list': ['/scratch/by2593/project/SMM/SMM_data/semantic_block_train_part1.jsonl'], 'num_used_data': 'None', 'image_prefix_dir': '/scratch/by2593/project/SMM/semantic_blocks_part1', 'image_transform_args': {'image_stride': 16, 'max_image_size': 512, 'min_image_size': 512}, 'vit_image_transform_args': {'image_stride': 14, 'max_image_size': 512, 'min_image_size': 512}, 'weight': 1.0, 'is_mandatory': True}}
188
+ rank-1 worker-0 dataset-block_dataset: resuming data at row#0
189
+ skip a sample with length 43202
190
+ skip a sample with length 48060
191
+ skip a sample with length 41094
192
+ skip a sample with length 43245
193
+ skip a sample with length 57756
194
+ skip a sample with length 41160
195
+ skip a sample with length 44611
196
+ skip a sample with length 41094
197
+ skip a sample with length 48060
198
+ skip a sample with length 50787
199
+ skip a sample with length 44611
200
+ skip a sample with length 43245
201
+ skip a sample with length 41106
202
+ skip a sample with length 41160
203
+ skip a sample with length 57756
204
+ skip a sample with length 42480
205
+ skip a sample with length 42486
206
+ skip a sample with length 42486
207
+ skip a sample with length 50787
208
+ skip a sample with length 43202
209
+ skip a sample with length 42480
210
+ block_dataset repeat in rank-3 worker-0
211
+ block_dataset repeat in rank-4 worker-0
212
+ block_dataset repeat in rank-6 worker-0
213
+ block_dataset repeat in rank-7 worker-0
214
+ block_dataset repeat in rank-0 worker-0
215
+ block_dataset repeat in rank-5 worker-0
216
+ block_dataset repeat in rank-2 worker-0
217
+ skip a sample with length 41106
218
+ skip a sample with length 48060
219
+ skip a sample with length 43202
220
+ block_dataset repeat in rank-1 worker-0
221
+ skip a sample with length 41094
222
+ skip a sample with length 57756
223
+ Yielding data with length 31517
224
+ skip a sample with length 43245
225
+ skip a sample with length 41160
226
+ skip a sample with length 44611
227
+ Yielding data with length 33637
228
+ Yielding data with length 33154
229
+ Yielding data with length 15542
230
+ skip a sample with length 50787
231
+ Yielding data with length 35486
232
+ Yielding data with length 12716
233
+ skip a sample with length 48060
234
+ skip a sample with length 41094
235
+ skip a sample with length 43245
236
+ skip a sample with length 41160
237
+ skip a sample with length 44611
238
+ Yielding data with length 26172
239
+ Yielding data with length 23933
240
+ skip a sample with length 41106
241
+ skip a sample with length 57756
242
+ Yielding data with length 32737
243
+ Yielding data with length 27691
244
+ Yielding data with length 31628
245
+ Yielding data with length 36149
246
+ skip a sample with length 42486
247
+ Yielding data with length 30708
248
+ Yielding data with length 13411
249
+ Yielding data with length 18973
250
+ Yielding data with length 27959
251
+ Yielding data with length 23821
252
+ skip a sample with length 50787
253
+ Yielding data with length 27474
254
+ Yielding data with length 7870
255
+ skip a sample with length 42486
256
+ Yielding data with length 37241
257
+ Yielding data with length 27998
258
+ Yielding data with length 13811
259
+ Yielding data with length 20795
260
+ Yielding data with length 32169
261
+ Yielding data with length 16921
262
+ Yielding data with length 16202
263
+ Yielding data with length 21081
264
+ Yielding data with length 21217
265
+ Yielding data with length 26994
266
+ Yielding data with length 17856
267
+ Yielding data with length 33309
268
+ Yielding data with length 31064
269
+ Yielding data with length 23492
270
+ Yielding data with length 20761
271
+ Yielding data with length 31378
272
+ Yielding data with length 23451
273
+ Yielding data with length 25220
274
+ Yielding data with length 26611
275
+ Yielding data with length 27250
276
+ Yielding data with length 35216
277
+ skip a sample with length 42480
278
+ Yielding data with length 13720
279
+ Yielding data with length 19578
280
+ Yielding data with length 25498
281
+ Yielding data with length 22109
282
+ Yielding data with length 19619
283
+ Yielding data with length 23415
284
+ Yielding data with length 30332
285
+ Yielding data with length 34858
286
+ block_dataset repeat in rank-0 worker-0
287
+ block_dataset repeat in rank-6 worker-0
288
+ block_dataset repeat in rank-5 worker-0
289
+ Yielding data with length 19720
290
+ Yielding data with length 25991
291
+ Yielding data with length 29387
292
+ Yielding data with length 21979
293
+ skip a sample with length 43202
294
+ skip a sample with length 41106
295
+ Yielding data with length 23402
296
+ Yielding data with length 22465
297
+ Yielding data with length 21998
298
+ Yielding data with length 25679
299
+ block_dataset repeat in rank-3 worker-0
300
+ Yielding data with length 17957
301
+ Yielding data with length 22013
302
+ Yielding data with length 20711
303
+ Yielding data with length 23461
304
+ Yielding data with length 24469
305
+ Yielding data with length 24915
306
+ Yielding data with length 27691
307
+ Yielding data with length 37262
308
+ skip a sample with length 43202
309
+ block_dataset repeat in rank-7 worker-0
310
+ block_dataset repeat in rank-1 worker-0
311
+ Yielding data with length 17288
312
+ Yielding data with length 20687
313
+ Yielding data with length 20361
314
+ Yielding data with length 28560
315
+ Yielding data with length 31247
316
+ Yielding data with length 17983
317
+ block_dataset repeat in rank-2 worker-0
318
+ Yielding data with length 27946
319
+ Yielding data with length 27631
320
+ skip a sample with length 41094
321
+ skip a sample with length 42480
322
+ Yielding data with length 10650
323
+ Yielding data with length 14641
324
+ Yielding data with length 23037
325
+ skip a sample with length 43245
326
+ Yielding data with length 16219
327
+ Yielding data with length 35530
328
+ Yielding data with length 16208
329
+ Yielding data with length 26188
330
+ Yielding data with length 27937
331
+ block_dataset repeat in rank-4 worker-0
332
+ Yielding data with length 11424
333
+ Yielding data with length 12453
334
+ Yielding data with length 16146
335
+ Yielding data with length 18287
336
+ Yielding data with length 20791
337
+ Yielding data with length 24236
338
+ Yielding data with length 25579
339
+ Yielding data with length 28956
340
+ Yielding data with length 14121
341
+ Yielding data with length 14781
342
+ Yielding data with length 15221
343
+ Yielding data with length 15921
344
+ skip a sample with length 41160
345
+ Yielding data with length 28466
346
+ skip a sample with length 48060
347
+ Yielding data with length 17646
348
+ Yielding data with length 31256
349
+ Yielding data with length 26792
350
+ Yielding data with length 17122
351
+ Yielding data with length 20057
352
+ skip a sample with length 48060
353
+ Yielding data with length 31691
354
+ Yielding data with length 32761
355
+ Yielding data with length 23701
356
+ Yielding data with length 23722
357
+ Yielding data with length 27340
358
+ Yielding data with length 33869
359
+ skip a sample with length 44611
360
+ skip a sample with length 57756
361
+ skip a sample with length 41094
362
+ Yielding data with length 15091
363
+ Yielding data with length 16206
364
+ Yielding data with length 13157
365
+ Yielding data with length 26843
366
+ Yielding data with length 21094
367
+ Yielding data with length 24549
368
+ Yielding data with length 20404
369
+ Yielding data with length 25400
370
+ skip a sample with length 41106
371
+ Yielding data with length 18332
372
+ Yielding data with length 20708
373
+ Yielding data with length 21310
374
+ skip a sample with length 50787
375
+ Yielding data with length 27881
376
+ Yielding data with length 25557
377
+ skip a sample with length 57756
378
+ Yielding data with length 24894
379
+ Yielding data with length 28219
380
+ Yielding data with length 24140
381
+ skip a sample with length 43245
382
+ skip a sample with length 44611
383
+ skip a sample with length 41160
384
+ Yielding data with length 27592
385
+ Yielding data with length 26168
386
+ Yielding data with length 20709
387
+ Yielding data with length 23581
388
+ skip a sample with length 42486
389
+ Yielding data with length 29274
390
+ Yielding data with length 24805
391
+ Yielding data with length 31112
392
+ Yielding data with length 36407
393
+ skip a sample with length 50787
394
+ Yielding data with length 18262
395
+ Yielding data with length 26439
396
+ Yielding data with length 18322
397
+ Yielding data with length 33505
398
+ Yielding data with length 29023
399
+ Yielding data with length 25487
400
+ Yielding data with length 31643
401
+ Yielding data with length 27712
402
+ skip a sample with length 42486
403
+ Yielding data with length 15735
404
+ Yielding data with length 17616
405
+ Yielding data with length 13811
406
+ Yielding data with length 19365
407
+ Yielding data with length 19566
408
+ Yielding data with length 24227
409
+ Yielding data with length 28214
410
+ Yielding data with length 30026
411
+ Yielding data with length 18195
412
+ Yielding data with length 18206
413
+ Yielding data with length 19699
414
+ Yielding data with length 23103
415
+ Yielding data with length 33474
416
+ Yielding data with length 29109
417
+ Yielding data with length 36518
418
+ Yielding data with length 27659
419
+ Yielding data with length 21031
420
+ Yielding data with length 27532
421
+ Yielding data with length 21080
422
+ Yielding data with length 20740
423
+ Yielding data with length 24066
424
+ Yielding data with length 26959
425
+ Yielding data with length 32162
426
+ skip a sample with length 42480
427
+ Yielding data with length 31373
428
+ block_dataset repeat in rank-0 worker-0
429
+ Yielding data with length 9629
430
+ block_dataset repeat in rank-6 worker-0
431
+ Yielding data with length 12734
432
+ Yielding data with length 20622
433
+ Yielding data with length 31650
434
+ Yielding data with length 23291
435
+ Yielding data with length 25245
436
+ Yielding data with length 27515
437
+ Yielding data with length 28296
438
+ Yielding data with length 20698
439
+ Yielding data with length 21726
440
+ skip a sample with length 43202
441
+ Yielding data with length 21768
442
+ Yielding data with length 18011
443
+ Yielding data with length 23070
444
+ Yielding data with length 19691
445
+ Yielding data with length 25171
446
+ Yielding data with length 33860
447
+ block_dataset repeat in rank-5 worker-0
448
+ Yielding data with length 8964
449
+ skip a sample with length 43202
450
+ block_dataset repeat in rank-3 worker-0
451
+ Yielding data with length 19248
452
+ Yielding data with length 16262
453
+ Yielding data with length 29186
454
+ skip a sample with length 41106
455
+ Yielding data with length 19245
456
+ Yielding data with length 24191
457
+ Yielding data with length 23133
458
+ Yielding data with length 35614
459
+ block_dataset repeat in rank-2 worker-0
460
+ Yielding data with length 13769
461
+ Yielding data with length 24400
462
+ Yielding data with length 31113
463
+ Yielding data with length 25652
464
+ Yielding data with length 25500
465
+ Yielding data with length 26979
466
+ block_dataset repeat in rank-7 worker-0
467
+ Yielding data with length 24263
468
+ Yielding data with length 27393
469
+ skip a sample with length 41094
470
+ Yielding data with length 23188
471
+ Yielding data with length 19658
472
+ block_dataset repeat in rank-1 worker-0
473
+ Yielding data with length 24787
474
+ Yielding data with length 26221
475
+ Yielding data with length 21409
476
+ Yielding data with length 32059
477
+ skip a sample with length 43245
478
+ Yielding data with length 26058
479
+ Yielding data with length 24507
480
+ Yielding data with length 8292
481
+ skip a sample with length 42480
482
+ Yielding data with length 12746
483
+ Yielding data with length 17288
484
+ Yielding data with length 20793
485
+ Yielding data with length 17252
486
+ Yielding data with length 25240
487
+ Yielding data with length 25304
488
+ Yielding data with length 30376
489
+ block_dataset repeat in rank-4 worker-0
490
+ Yielding data with length 14138
491
+ skip a sample with length 48060
492
+ skip a sample with length 41160
493
+ skip a sample with length 48060
494
+ Yielding data with length 19684
495
+ Yielding data with length 14748
496
+ Yielding data with length 21158
497
+ Yielding data with length 21425
498
+ Yielding data with length 30781
499
+ Yielding data with length 33027
500
+ Yielding data with length 33537
501
+ Yielding data with length 14837
502
+ Yielding data with length 12766
503
+ Yielding data with length 14115
504
+ Yielding data with length 15474
505
+ Yielding data with length 21749
506
+ Yielding data with length 33147
507
+ Yielding data with length 25621
508
+ skip a sample with length 57756
509
+ Yielding data with length 22466
510
+ skip a sample with length 41094
511
+ Yielding data with length 22413
512
+ skip a sample with length 50787
513
+ Yielding data with length 27913
514
+ Yielding data with length 25090
515
+ Yielding data with length 25551
516
+ Yielding data with length 25335
517
+ skip a sample with length 57756
518
+ skip a sample with length 43245
519
+ Yielding data with length 25947
520
+ skip a sample with length 41160
521
+ Yielding data with length 31872
522
+ Yielding data with length 36109
523
+ skip a sample with length 44611
524
+ Yielding data with length 16234
525
+ Yielding data with length 19945
526
+ Yielding data with length 19685
527
+ Yielding data with length 34186
528
+ Yielding data with length 36943
529
+ Yielding data with length 23090
530
+ Yielding data with length 29034
531
+ Yielding data with length 30067
532
+ Yielding data with length 8489
533
+ skip a sample with length 41106
534
+ skip a sample with length 44611
535
+ skip a sample with length 50787
536
+ Yielding data with length 10008
537
+ Yielding data with length 32829
538
+ Yielding data with length 23593
539
+ Yielding data with length 29907
540
+ skip a sample with length 42486
541
+ Yielding data with length 25500
542
+ Yielding data with length 34717
543
+ Yielding data with length 29714
544
+ Yielding data with length 16266
545
+ Yielding data with length 17271
546
+ Yielding data with length 20547
547
+ Yielding data with length 22351
548
+ Yielding data with length 26637
549
+ Yielding data with length 32390
550
+ Yielding data with length 30503
551
+ Yielding data with length 29728
552
+ skip a sample with length 42486
553
+ Yielding data with length 21561
554
+ Yielding data with length 16923
555
+ Yielding data with length 19642
556
+ Yielding data with length 20198
557
+ Yielding data with length 22735
558
+ Yielding data with length 32930
559
+ Yielding data with length 24262
560
+ Yielding data with length 34823
561
+ Yielding data with length 28608
562
+ Yielding data with length 28122
563
+ Yielding data with length 24532
564
+ Yielding data with length 26210
565
+ Yielding data with length 36308
566
+ Yielding data with length 27414
567
+ Yielding data with length 30425
568
+ Yielding data with length 30774
569
+ block_dataset repeat in rank-6 worker-0
570
+ block_dataset repeat in rank-0 worker-0
571
+ skip a sample with length 42480
572
+ Yielding data with length 15870
573
+ Yielding data with length 15590
574
+ Yielding data with length 18509
575
+ Yielding data with length 23812
576
+ Yielding data with length 18170
577
+ Yielding data with length 32514
578
+ Yielding data with length 24814
579
+ Yielding data with length 28298
580
+ Yielding data with length 9988
581
+ Yielding data with length 18332
582
+ Yielding data with length 21420
583
+ Yielding data with length 23903
584
+ Yielding data with length 25120
585
+ Yielding data with length 28991
586
+ Yielding data with length 30114
587
+ Yielding data with length 30128
588
+ skip a sample with length 43202
589
+ skip a sample with length 43202
590
+ Yielding data with length 17850
591
+ Yielding data with length 18166
592
+ Yielding data with length 22663
593
+ Yielding data with length 20751
594
+ Yielding data with length 19273
595
+ Yielding data with length 17552
596
+ Yielding data with length 26616
597
+ Yielding data with length 28527
598
+ block_dataset repeat in rank-3 worker-0
599
+ block_dataset repeat in rank-5 worker-0
600
+ Yielding data with length 13466
601
+ Yielding data with length 14852
602
+ block_dataset repeat in rank-7 worker-0
603
+ Yielding data with length 20760
604
+ Yielding data with length 22448
605
+ Yielding data with length 20269
606
+ Yielding data with length 27307
607
+ Yielding data with length 31128
608
+ Yielding data with length 23848
609
+ block_dataset repeat in rank-2 worker-0
610
+ Yielding data with length 8948
611
+ skip a sample with length 41106
612
+ Yielding data with length 10367
613
+ Yielding data with length 12612
614
+ Yielding data with length 18632
615
+ Yielding data with length 32428
616
+ Yielding data with length 25651
617
+ Yielding data with length 22117
618
+ Yielding data with length 30468
619
+ Yielding data with length 12051
620
+ Yielding data with length 13346
621
+ Yielding data with length 15726
622
+ Yielding data with length 11383
623
+ skip a sample with length 41094
624
+ Yielding data with length 19358
625
+ Yielding data with length 31964
626
+ skip a sample with length 43245
627
+ Yielding data with length 34359
628
+ Yielding data with length 25146
629
+ block_dataset repeat in rank-1 worker-0
630
+ Yielding data with length 12528
631
+ Yielding data with length 14445
632
+ Yielding data with length 21808
633
+ skip a sample with length 48060
634
+ Yielding data with length 24973
635
+ Yielding data with length 24141
636
+ Yielding data with length 35965
637
+ Yielding data with length 29665
638
+ Yielding data with length 28975
639
+ skip a sample with length 42480
640
+ skip a sample with length 48060
641
+ Yielding data with length 15182
642
+ Yielding data with length 19712
643
+ skip a sample with length 41160
644
+ Yielding data with length 19698
645
+ Yielding data with length 18255
646
+ Yielding data with length 30749
647
+ Yielding data with length 34841
648
+ Yielding data with length 22848
649
+ Yielding data with length 28618
650
+ block_dataset repeat in rank-4 worker-0
651
+ Yielding data with length 12071
652
+ Yielding data with length 15527
653
+ Yielding data with length 19227
654
+ Yielding data with length 19199
655
+ skip a sample with length 50787
656
+ Yielding data with length 25207
657
+ Yielding data with length 26500
658
+ skip a sample with length 57756
659
+ Yielding data with length 25915
660
+ skip a sample with length 57756
661
+ Yielding data with length 29886
662
+ skip a sample with length 43245
663
+ skip a sample with length 41160
664
+ skip a sample with length 41094
665
+ Yielding data with length 18659
666
+ Yielding data with length 23460
667
+ Yielding data with length 29942
668
+ Yielding data with length 30289
669
+ Yielding data with length 27297
670
+ Yielding data with length 28034
671
+ Yielding data with length 29025
672
+ Yielding data with length 36590
673
+ skip a sample with length 44611
674
+ skip a sample with length 50787
675
+ skip a sample with length 44611
676
+ Yielding data with length 24792
677
+ Yielding data with length 20748
678
+ Yielding data with length 23187
679
+ Yielding data with length 19037
680
+ Yielding data with length 31561
681
+ Yielding data with length 34200
682
+ Yielding data with length 26330
683
+ Yielding data with length 30027
684
+ skip a sample with length 41106
685
+ Yielding data with length 12095
686
+ Yielding data with length 15214
687
+ Yielding data with length 17243
688
+ Yielding data with length 23097
689
+ Yielding data with length 24142
690
+ Yielding data with length 28934
691
+ Yielding data with length 29052
692
+ skip a sample with length 42486
693
+ Yielding data with length 34556
694
+ Yielding data with length 14895
695
+ Yielding data with length 19552
696
+ Yielding data with length 22053
697
+ Yielding data with length 29467
698
+ Yielding data with length 23444
699
+ Yielding data with length 26636
700
+ Yielding data with length 33801
701
+ Yielding data with length 34191
702
+ skip a sample with length 42486
703
+ Yielding data with length 20414
704
+ Yielding data with length 21739
705
+ Yielding data with length 23877
706
+ Yielding data with length 26520
707
+ Yielding data with length 24877
708
+ Yielding data with length 27696
709
+ Yielding data with length 27597
710
+ Yielding data with length 32703
711
+ block_dataset repeat in rank-0 worker-0
712
+ block_dataset repeat in rank-6 worker-0
713
+ Yielding data with length 18005
714
+ Yielding data with length 26527
715
+ Yielding data with length 20791
716
+ Yielding data with length 20719
717
+ Yielding data with length 22114
718
+ Yielding data with length 22512
719
+ Yielding data with length 29336
720
+ Yielding data with length 31527
721
+ Yielding data with length 9284
722
+ skip a sample with length 42480
723
+ Yielding data with length 17316
724
+ Yielding data with length 19314
725
+ Yielding data with length 25239
726
+ Yielding data with length 19703
727
+ Yielding data with length 21232
728
+ Yielding data with length 17268
729
+ Yielding data with length 26931
730
+ skip a sample with length 43202
731
+ Yielding data with length 16230
732
+ Yielding data with length 19692
733
+ Yielding data with length 23196
734
+ Yielding data with length 22444
735
+ Yielding data with length 29708
736
+ Yielding data with length 20680
737
+ Yielding data with length 30765
738
+ Yielding data with length 27917
739
+ skip a sample with length 43202
740
+ Yielding data with length 17207
741
+ Yielding data with length 17853
742
+ Yielding data with length 23427
743
+ block_dataset repeat in rank-5 worker-0
744
+ Yielding data with length 27646
745
+ Yielding data with length 25169
746
+ Yielding data with length 26475
747
+ Yielding data with length 25127
748
+ Yielding data with length 27339
749
+ block_dataset repeat in rank-3 worker-0
750
+ block_dataset repeat in rank-7 worker-0
751
+ Yielding data with length 16546
752
+ Yielding data with length 16256
753
+ Yielding data with length 22339
754
+ Yielding data with length 17919
755
+ Yielding data with length 23138
756
+ Yielding data with length 19676
757
+ Yielding data with length 24070
758
+ Yielding data with length 25924
759
+ block_dataset repeat in rank-2 worker-0
760
+ Yielding data with length 14569
761
+ Yielding data with length 31705
762
+ Yielding data with length 24120
763
+ Yielding data with length 33709
764
+ Yielding data with length 26245
765
+ Yielding data with length 39397
766
+ Yielding data with length 31035
767
+ Yielding data with length 15921
768
+ skip a sample with length 41094
769
+ skip a sample with length 41106
770
+ Yielding data with length 11323
771
+ Yielding data with length 20758
772
+ Yielding data with length 24109
773
+ skip a sample with length 43245
774
+ skip a sample with length 48060
775
+ Yielding data with length 21739
776
+ Yielding data with length 22062
777
+ Yielding data with length 11069
778
+ Yielding data with length 33774
779
+ Yielding data with length 24783
780
+ Yielding data with length 13348
781
+ Yielding data with length 13218
782
+ Yielding data with length 17288
783
+ Yielding data with length 26493
784
+ Yielding data with length 24246
785
+ Yielding data with length 26920
786
+ Yielding data with length 28599
787
+ Yielding data with length 31042
788
+ block_dataset repeat in rank-1 worker-0
789
+ skip a sample with length 48060
790
+ skip a sample with length 41160
791
+ Yielding data with length 25722
792
+ Yielding data with length 33186
793
+ skip a sample with length 50787
794
+ Yielding data with length 19367
795
+ Yielding data with length 26598
796
+ Yielding data with length 18672
797
+ Yielding data with length 27291
798
+ Yielding data with length 33105
799
+ skip a sample with length 57756
800
+ Yielding data with length 31380
801
+ skip a sample with length 43245
802
+ skip a sample with length 42480
803
+ skip a sample with length 41160
804
+ Yielding data with length 22996
805
+ Yielding data with length 18896
806
+ Yielding data with length 19621
807
+ Yielding data with length 24453
808
+ Yielding data with length 37227
809
+ Yielding data with length 28758
810
+ skip a sample with length 57756
811
+ Yielding data with length 31736
812
+ Yielding data with length 26241
813
+ block_dataset repeat in rank-4 worker-0
814
+ skip a sample with length 44611
815
+ skip a sample with length 50787
816
+ Yielding data with length 8502
817
+ skip a sample with length 41094
818
+ Yielding data with length 23339
819
+ Yielding data with length 26828
820
+ Yielding data with length 22141
821
+ Yielding data with length 27917
822
+ Yielding data with length 30731
823
+ Yielding data with length 35152
824
+ Yielding data with length 32504
825
+ Yielding data with length 14524
826
+ Yielding data with length 21770
827
+ Yielding data with length 23021
828
+ Yielding data with length 31645
829
+ Yielding data with length 34056
830
+ skip a sample with length 44611
831
+ Yielding data with length 24506
832
+ Yielding data with length 27457
833
+ Yielding data with length 28513
834
+ Yielding data with length 15147
835
+ skip a sample with length 41106
836
+ Yielding data with length 16968
837
+ Yielding data with length 13491
838
+ Yielding data with length 22125
839
+ Yielding data with length 21138
840
+ Yielding data with length 24903
841
+ Yielding data with length 28043
842
+ skip a sample with length 42486
843
+ Yielding data with length 31782
844
+ Yielding data with length 6701
845
+ Yielding data with length 13494
846
+ Yielding data with length 15875
847
+ Yielding data with length 17545
848
+ Yielding data with length 21060
849
+ Yielding data with length 22115
850
+ Yielding data with length 29729
851
+ Yielding data with length 31752
852
+ skip a sample with length 42486
853
+ Yielding data with length 17316
854
+ block_dataset repeat in rank-0 worker-0
855
+ Yielding data with length 24478
856
+ Yielding data with length 24714
857
+ skip a sample with length 42480
858
+ Yielding data with length 24145
859
+ Yielding data with length 25188
860
+ Yielding data with length 21724
861
+ block_dataset repeat in rank-6 worker-0
862
+ Yielding data with length 28652
863
+ Yielding data with length 31606
864
+ Yielding data with length 8619
865
+ Yielding data with length 16608
866
+ Yielding data with length 21134
867
+ Yielding data with length 28671
868
+ Yielding data with length 24139
869
+ Yielding data with length 34737
870
+ Yielding data with length 28959
871
+ Yielding data with length 30967
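The skip/yield messages in the log above reflect length-based packing in the data loader: every sample whose packed token count exceeds the per-sample budget is dropped (all skipped lengths are above 40,000, matching `--max_num_tokens_per_sample 40000` in `scripts/train_smm.sh` below), and shorter samples are yielded. A minimal sketch of that behaviour, assuming a simple generator and the flag name from the training scripts (not the repository's actual loader):

```python
# Minimal sketch of the skip/yield pattern in run.out above (not the
# repository's actual code). Assumption: samples whose token count exceeds
# max_num_tokens_per_sample are dropped, the rest are yielded as-is.
def yield_packed_samples(samples, max_num_tokens_per_sample=40000):
    for sample in samples:
        length = sample["num_tokens"]  # hypothetical field holding the packed length
        if length > max_num_tokens_per_sample:
            print(f"skip a sample with length {length}")
            continue
        print(f"Yielding data with length {length}")
        yield sample
```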
scripts/eval/eval_vlm.sh ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Check if enough arguments are provided
5
+ if [ $# -lt 2 ]; then
6
+ echo "Error: PREFIX_DIR and MODEL_PATH are required as the first and second arguments respectively."
7
+ exit 1
8
+ fi
9
+
10
+ LOG_PATH=$1
11
+ if [ ! -d "$LOG_PATH" ]; then
12
+ mkdir -p "$LOG_PATH"
13
+ fi
14
+ shift 1
15
+ ARGS=("$@")
16
+ export MASTER_PORT=10042
17
+
18
+ # FULL_MODEL_PATH="$PREFIX_DIR/$MODEL_PATH"  # unused: PREFIX_DIR and MODEL_PATH are never set in this script
19
+
20
+ IFS=' ' read -r -a DATASETS <<< "$DATASETS_STR"
21
+
22
+ for DATASET in "${DATASETS[@]}"; do
23
+ bash eval/vlm/evaluate.sh \
24
+ "$DATASET" \
25
+ --out-dir "$LOG_PATH/$DATASET" \
26
+ "${ARGS[@]}"
27
+ done
scripts/eval/run_eval_vlm.sh ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ # Set proxy and API key
7
+ export OPENAI_API_KEY=$openai_api_key
8
+
9
+ export GPUS=1
10
+
11
+ DATASETS=("mme" "mmbench-dev-en" "mmvet" "mmmu-val" "mathvista-testmini" "mmvp")
12
+ # DATASETS=("mmmu-val_cot")
13
+
14
+ DATASETS_STR="${DATASETS[*]}"
15
+ export DATASETS_STR
16
+
17
+ bash scripts/eval/eval_vlm.sh \
18
+ $output_path \
19
+ --model-path $model_path
scripts/eval/run_gedit.sh ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # run this script at the root of the project folder
5
+ pip install httpx==0.23.0
6
+ pip install openai==1.87.0
7
+ pip install datasets
8
+ pip install megfile
9
+
10
+
11
+ N_GPU=8 # Number of GPU used in for the evaluation
12
+ MODEL_PATH="/Path/to/BAGEL-7B-MoT"
13
+ OUTPUT_DIR="/Path/to/save/results"
14
+ GEN_DIR="$OUTPUT_DIR/gen_image"
15
+ LOG_DIR="$OUTPUT_DIR/logs"
16
+
17
+ AZURE_ENDPOINT="https://azure_endpoint_url_you_use" # set up the azure openai endpoint url
18
+ AZURE_OPENAI_KEY="" # set up the azure openai key
19
+ N_GPT_PARALLEL=10
20
+
21
+
22
+ mkdir -p "$OUTPUT_DIR"
23
+ mkdir -p "$GEN_DIR"
24
+ mkdir -p "$LOG_DIR"
25
+
26
+
27
+ # # ----------------------------
28
+ # # Download GEdit Dataset
29
+ # # ----------------------------
30
+ python -c "from datasets import load_dataset; dataset = load_dataset('stepfun-ai/GEdit-Bench')"
31
+ echo "Dataset Downloaded"
32
+
33
+
34
+ # # ---------------------
35
+ # # Generate Images
36
+ # # ---------------------
37
+ for ((i=0; i<$N_GPU; i++)); do
38
+ nohup python3 eval/gen/gedit/gen_images_gedit.py --model_path "$MODEL_PATH" --output_dir "$GEN_DIR" --shard_id $i --total_shards "$N_GPU" --device $i 2>&1 | tee "$LOG_DIR"/request_$(($N_GPU + i)).log &
39
+ done
40
+
41
+ wait
42
+ echo "Image Generation Done"
43
+
44
+
45
+ # # ---------------------
46
+ # # GPT Evaluation
47
+ # # ---------------------
48
+ cd eval/gen/gedit
49
+ python test_gedit_score.py --save_path "$OUTPUT_DIR" --azure_endpoint "$AZURE_ENDPOINT" --gpt_keys "$AZURE_OPENAI_KEY" --max_workers "$N_GPT_PARALLEL"
50
+ echo "Evaluation Done"
51
+
52
+
53
+ # # --------------------
54
+ # # Print Results
55
+ # # --------------------
56
+ python calculate_statistics.py --save_path "$OUTPUT_DIR" --language en
57
+
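gen_images_gedit.py is launched once per GPU with `--shard_id`/`--total_shards`, so each process handles a disjoint slice of the benchmark. A round-robin split is the usual way to do this; the sketch below is an assumption about that splitting, not the script's actual code:

```python
# Hypothetical round-robin sharding matching the --shard_id / --total_shards
# flags above; the real gen_images_gedit.py may slice its workload differently.
def shard(items, shard_id, total_shards):
    return items[shard_id::total_shards]

# Example: with total_shards=8, shard_id=3 handles items 3, 11, 19, ...
print(shard(list(range(20)), shard_id=3, total_shards=8))  # [3, 11, 19]
```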
scripts/eval/run_geneval.sh ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ GPUS=8
7
+
8
+
9
+ # generate images
10
+ torchrun \
11
+ --nnodes=1 \
12
+ --node_rank=0 \
13
+ --nproc_per_node=$GPUS \
14
+ --master_addr=127.0.0.1 \
15
+ --master_port=12345 \
16
+ ./eval/gen/gen_images_mp.py \
17
+ --output_dir $output_path/images \
18
+ --metadata_file ./eval/gen/geneval/prompts/evaluation_metadata_long.jsonl \
19
+ --batch_size 1 \
20
+ --num_images 4 \
21
+ --resolution 1024 \
22
+ --max_latent_size 64 \
23
+ --model-path $model_path \
24
+ # --metadata_file ./eval/gen/geneval/prompts/evaluation_metadata.jsonl \
25
+
26
+
27
+ # calculate score
28
+ torchrun \
29
+ --nnodes=1 \
30
+ --node_rank=0 \
31
+ --nproc_per_node=$GPUS \
32
+ --master_addr=127.0.0.1 \
33
+ --master_port=12345 \
34
+ ./eval/gen/geneval/evaluation/evaluate_images_mp.py \
35
+ $output_path/images \
36
+ --outfile $output_path/results.jsonl \
37
+ --model-path ./eval/gen/geneval/model
38
+
39
+
40
+ # summarize score
41
+ python ./eval/gen/geneval/evaluation/summary_scores.py $output_path/results.jsonl
scripts/eval/run_imgedit.sh ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ export OPENAI_API_KEY=$openai_api_key
7
+
8
+ GPUS=8
9
+
10
+
11
+ # generate images
12
+ torchrun \
13
+ --nnodes=1 \
14
+ --node_rank=0 \
15
+ --nproc_per_node=$GPUS \
16
+ --master_addr=127.0.0.1 \
17
+ --master_port=12345 \
18
+ ./eval/gen/gen_images_mp_imgedit.py \
19
+ --output_dir $output_path/bagel \
20
+ --metadata_file ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
21
+ --max_latent_size 64 \
22
+ --model-path $model_path
23
+
24
+
25
+ # calculate score
26
+ python ./eval/gen/imgedit/basic_bench.py \
27
+ --result_img_folder $output_path/bagel \
28
+ --edit_json ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
29
+ --origin_img_root ./eval/gen/imgedit/Benchmark/singleturn \
30
+ --num_processes 4 \
31
+ --prompts_json ./eval/gen/imgedit/Benchmark/singleturn/judge_prompt.json
32
+
33
+
34
+ # summarize score
35
+ python ./eval/gen/imgedit/step1_get_avgscore.py \
36
+ --result_json $output_path/bagel/result.json \
37
+ --average_score_json $output_path/bagel/average_score.json
38
+
39
+ python ./eval/gen/imgedit/step2_typescore.py \
40
+ --average_score_json $output_path/bagel/average_score.json \
41
+ --edit_json ./eval/gen/imgedit/Benchmark/singleturn/singleturn.json \
42
+ --typescore_json $output_path/bagel/typescore.json
scripts/eval/run_kris.sh ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ export OPENAI_API_KEY=$openai_api_key
7
+
8
+ GPUS=8
9
+
10
+
11
+ # generate images
12
+ torchrun \
13
+ --nnodes=1 \
14
+ --node_rank=0 \
15
+ --nproc_per_node=$GPUS \
16
+ --master_addr=127.0.0.1 \
17
+ --master_port=12345 \
18
+ ./eval/gen/gen_images_mp_kris.py \
19
+ --output_dir $output_path/bagel \
20
+ --metadata_file ./eval/gen/kris/final_data.json \
21
+ --max_latent_size 64 \
22
+ --model-path $model_path \
23
+ --think
24
+
25
+
26
+ # calculate score
27
+ python ./eval/gen/kris/metrics_common.py \
28
+ --results_dir $output_path \
29
+ --max_workers 8
30
+
31
+ python ./eval/gen/kris/metrics_knowledge.py \
32
+ --results_dir $output_path \
33
+ --max_workers 8
34
+
35
+ python ./eval/gen/kris/metrics_multi_element.py \
36
+ --results_dir $output_path \
37
+ --max_workers 8
38
+
39
+ python ./eval/gen/kris/metrics_temporal_prediction.py \
40
+ --results_dir $output_path \
41
+ --max_workers 8
42
+
43
+ python ./eval/gen/kris/metrics_view_change.py \
44
+ --results_dir $output_path \
45
+ --max_workers 8
46
+
47
+
48
+ # summarize score
49
+ python ./eval/gen/kris/summarize.py \
50
+ --results_dir $output_path/bagel
scripts/eval/run_rise.sh ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ export OPENAI_API_KEY=$openai_api_key
7
+
8
+ GPUS=8
9
+
10
+
11
+ # generate images
12
+ torchrun \
13
+ --nnodes=1 \
14
+ --node_rank=0 \
15
+ --nproc_per_node=$GPUS \
16
+ --master_addr=127.0.0.1 \
17
+ --master_port=12345 \
18
+ ./eval/gen/gen_images_mp_rise.py \
19
+ --output_dir $output_path/bagel \
20
+ --metadata_file ./eval/gen/rise/data/datav2_total_w_subtask.json \
21
+ --max_latent_size 64 \
22
+ --model-path $model_path \
23
+ --think
24
+
25
+
26
+ # calculate score
27
+ python ./eval/gen/rise/gpt_eval.py \
28
+ --data ./eval/gen/rise/data/datav2_total_w_subtask.json \
29
+ --input ./eval/gen/rise/data \
30
+ --output $output_path/bagel
scripts/eval/run_wise.sh ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ set -x
5
+
6
+ export OPENAI_API_KEY=$openai_api_key
7
+
8
+ GPUS=8
9
+
10
+
11
+ # generate images
12
+ torchrun \
13
+ --nnodes=1 \
14
+ --node_rank=0 \
15
+ --nproc_per_node=$GPUS \
16
+ --master_addr=127.0.0.1 \
17
+ --master_port=12345 \
18
+ ./eval/gen/gen_images_mp_wise.py \
19
+ --output_dir $output_path/images \
20
+ --metadata-file ./eval/gen/wise/final_data.json \
21
+ --resolution 1024 \
22
+ --max-latent_size 64 \
23
+ --model-path $model_path \
24
+ --think
25
+
26
+
27
+ # calculate score
28
+ python3 eval/gen/wise/gpt_eval_mp.py \
29
+ --json_path eval/gen/wise/data/cultural_common_sense.json \
30
+ --image_dir $output_path/images \
31
+ --output_dir $output_path
32
+
33
+ python3 eval/gen/wise/gpt_eval_mp.py \
34
+ --json_path eval/gen/wise/data/spatio-temporal_reasoning.json \
35
+ --image_dir $output_path/images \
36
+ --output_dir $output_path
37
+
38
+ python3 eval/gen/wise/gpt_eval_mp.py \
39
+ --json_path eval/gen/wise/data/natural_science.json \
40
+ --image_dir $output_path/images \
41
+ --output_dir $output_path
42
+
43
+ python3 eval/gen/wise/cal_score.py \
44
+ --output_dir $output_path
scripts/train.sh ADDED
@@ -0,0 +1,48 @@
1
+ #!/bin/bash
2
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ export HF_HOME=/dev/shm/
5
+ NUM_NODES=1
6
+ NODE_RANK=0
7
+ MASTER_ADDR=localhost
8
+ MASTER_PORT=29500
9
+ NPROC_PER_NODE=8
10
+ MODEL_PATH=/dev/shm/models/BAGEL-7B-MoT
11
+
12
+ # replace the variables with your own
13
+ torchrun \
14
+ --nnodes=$NUM_NODES \
15
+ --node_rank=$NODE_RANK \
16
+ --nproc_per_node=$NPROC_PER_NODE \
17
+ --master_addr=$MASTER_ADDR \
18
+ --master_port=$MASTER_PORT \
19
+ train/pretrain_unified_navit.py \
20
+ --dataset_config_file ./data/configs/example.yaml \
21
+ --model_path $MODEL_PATH \
22
+ --layer_module Qwen2MoTDecoderLayer \
23
+ --max_latent_size 64 \
24
+ --resume-from $MODEL_PATH \
25
+ --finetune_from_hf True \
26
+ --auto_resume True \
27
+ --resume-model-only True \
28
+ --finetune-from-ema True \
29
+ --log_every 1 \
30
+ --lr 2e-5 \
31
+ --lr_scheduler cosine \
32
+ --min_lr 1e-6 \
33
+ --num_worker 1 \
34
+ --expected_num_tokens 60000 \
35
+ --max_num_tokens 60000 \
36
+ --max_num_tokens_per_sample 60000 \
37
+ --prefer_buffer_before 30000 \
38
+ --num_shard=$NPROC_PER_NODE \
39
+ --sharding_strategy="HYBRID_SHARD" \
40
+ --wandb_project "zebra-cot" \
41
+ --wandb_name "h200-zebra-cot-$(date +%Y%m%d_%H%M%S)" \
42
+ --save_every 50 \
43
+ --warmup_steps 50 \
44
+ --total_steps 5000 \
45
+ --results_dir results/ \
46
+ --checkpoint_dir results/checkpoints/ > run.out 2> run.err
47
+
48
+ # --cpu_offload True \
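scripts/train.sh expects the base checkpoint at `/dev/shm/models/BAGEL-7B-MoT`. One way to stage it there, assuming the publicly released BAGEL-7B-MoT weights on the Hugging Face Hub (the repo id below is an assumption; point it at your own copy if it differs):

```python
# Stage the base weights where MODEL_PATH in scripts/train.sh expects them.
# Assumption: the public ByteDance-Seed/BAGEL-7B-MoT repo; adjust if needed.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ByteDance-Seed/BAGEL-7B-MoT",
    local_dir="/dev/shm/models/BAGEL-7B-MoT",
)
```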
scripts/train_smm.sh ADDED
@@ -0,0 +1,57 @@
1
+ #!/bin/bash
2
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # Change to the project directory
6
+ cd /scratch/by2593/Bagel-Zebra-CoT-origin
7
+
8
+ export HF_HOME=/dev/shm/
9
+ export PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin:$PYTHONPATH
10
+ export WANDB_MODE=offline
11
+ export WANDB_ANONYMOUS=must
12
+ NUM_NODES=1
13
+ NODE_RANK=0
14
+ MASTER_ADDR=localhost
15
+ MASTER_PORT=29500
16
+ NPROC_PER_NODE=8
17
+ MODEL_PATH=/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8
18
+
19
+ # replace the variables with your own
20
+ torchrun \
21
+ --nnodes=$NUM_NODES \
22
+ --node_rank=$NODE_RANK \
23
+ --nproc_per_node=$NPROC_PER_NODE \
24
+ --master_addr=$MASTER_ADDR \
25
+ --master_port=$MASTER_PORT \
26
+ train/pretrain_unified_navit.py \
27
+ --dataset_config_file ./data/configs/example_smm_semantic.yaml \
28
+ --model_path $MODEL_PATH \
29
+ --layer_module Qwen2MoTDecoderLayer \
30
+ --max_latent_size 64 \
31
+ --resume-from $MODEL_PATH \
32
+ --finetune_from_hf True \
33
+ --auto_resume True \
34
+ --resume-model-only True \
35
+ --finetune-from-ema False \
36
+ --log_every 1 \
37
+ --lr 2e-5 \
38
+ --lr_scheduler cosine \
39
+ --min_lr 1e-6 \
40
+ --num_worker 1 \
41
+ --expected_num_tokens 40000 \
42
+ --max_num_tokens 40000 \
43
+ --max_num_tokens_per_sample 40000 \
44
+ --prefer_buffer_before 10000 \
45
+ --num_shard=$NPROC_PER_NODE \
46
+ --sharding_strategy="HYBRID_SHARD" \
47
+ --wandb_project "zebra-cot" \
48
+ --wandb_name "h200-zebra-cot-$(date +%Y%m%d_%H%M%S)" \
49
+ --save_every 100 \
50
+ --warmup_steps 50 \
51
+ --total_steps 5000 \
52
+ --results_dir results/ \
53
+ --checkpoint_dir results/checkpoints_smm_semantic_part1_v1_origin/ \
54
+ --cpu_offload True > run.out 2> run.err
55
+
56
+
57
+ # bash scripts/train_smm.sh
scripts/train_smm_sbatch.sh ADDED
@@ -0,0 +1,85 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=bagel-zebra-cot-smm
3
+ #SBATCH --partition=h200_tandon
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks-per-node=1
6
+ #SBATCH --cpus-per-task=32
7
+ #SBATCH --gres=gpu:h200:8
8
+ #SBATCH --mem=1600G
9
+ #SBATCH --time=48:00:00
10
+ #SBATCH --output=slurm_logs/train_smm_%j.out
11
+ #SBATCH --error=slurm_logs/train_smm_%j.err
12
+
13
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
14
+ # SPDX-License-Identifier: Apache-2.0
15
+
16
+ # Load any necessary modules (adjust as needed for your cluster)
17
+ # module load cuda/12.1
18
+ # module load conda
19
+
20
+ # Activate conda environment
21
+ source /scratch/by2593/miniconda3/etc/profile.d/conda.sh
22
+ conda activate bagel
23
+
24
+ # Change to the project directory
25
+ cd /scratch/by2593/Bagel-Zebra-CoT-origin
26
+
27
+ # Set environment variables
28
+ export HF_HOME=/dev/shm/
29
+ export PYTHONPATH=/scratch/by2593/Bagel-Zebra-CoT-origin:$PYTHONPATH
30
+ export WANDB_MODE=offline
31
+ export WANDB_ANONYMOUS=must
32
+
33
+ # SLURM variables
34
+ NUM_NODES=1
35
+ NODE_RANK=0
36
+ MASTER_ADDR=$(hostname)
37
+ MASTER_PORT=29500
38
+ NPROC_PER_NODE=8
39
+ MODEL_PATH=/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8
40
+
41
+ echo "Starting SMM training on node: $SLURM_JOB_NODELIST"
42
+ echo "Job ID: $SLURM_JOB_ID"
43
+ echo "Number of GPUs: $NPROC_PER_NODE"
44
+
45
+ # Run training
46
+ torchrun \
47
+ --nnodes=$NUM_NODES \
48
+ --node_rank=$NODE_RANK \
49
+ --nproc_per_node=$NPROC_PER_NODE \
50
+ --master_addr=$MASTER_ADDR \
51
+ --master_port=$MASTER_PORT \
52
+ train/pretrain_unified_navit.py \
53
+ --dataset_config_file ./data/configs/example_smm_random.yaml \
54
+ --model_path $MODEL_PATH \
55
+ --layer_module Qwen2MoTDecoderLayer \
56
+ --max_latent_size 64 \
57
+ --visual_und True \
58
+ --finetune_from_hf True \
59
+ --auto_resume True \
60
+ --resume-model-only False \
61
+ --finetune-from-ema False \
62
+ --log_every 1 \
63
+ --lr 2e-5 \
64
+ --lr_scheduler cosine \
65
+ --min_lr 1e-6 \
66
+ --num_worker 1 \
67
+ --expected_num_tokens 50000 \
68
+ --max_num_tokens 50000 \
69
+ --max_num_tokens_per_sample 50000 \
70
+ --prefer_buffer_before 10000 \
71
+ --num_shard=$NPROC_PER_NODE \
72
+ --sharding_strategy="HYBRID_SHARD" \
73
+ --wandb_project "smm" \
74
+ --wandb_name "h200-zebra-cot-smm-sbatch-$(date +%Y%m%d_%H%M%S)" \
75
+ --save_every 100 \
76
+ --warmup_steps 50 \
77
+ --total_steps 5000 \
78
+ --results_dir results/ \
79
+ --checkpoint_dir /scratch/by2593/Bagel-Zebra-CoT-origin/results/checkpoints_smm_random_20251026_033448/ \
80
+ --cpu_offload True \
81
+ --max_checkpoints 2
82
+
83
+ echo "SMM training completed on $(date)"
84
+
85
+ # sbatch scripts/train_smm_sbatch.sh
test_images/image.png ADDED

Git LFS Details

  • SHA256: 8e402e7927312911bc35200f70ef5ce98d8efb4f715b10c768a5018b330d12d4
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
test_images/meme.jpg ADDED
test_images/octupusy.jpg ADDED
test_images/women.jpg ADDED