zhibinlan committed
Commit 4b44664 · verified · Parent(s): 391930a

Update README.md

README.md (updated):

---
language:
- en
library_name: transformers
license: apache-2.0
pipeline_tag: image-text-to-text
tags:
- Sentence Similarity
- Embedding
- zero-shot-image-classification
- video-text-to-text
---

# UME-R1-7B

## Model Summary

UME-R1-7B is trained in two stages, a cold-start supervised fine-tuning (SFT) stage followed by a reinforcement learning (RL) stage, and can embed text, images, multiple images, and videos. In particular, UME-R1 can produce either discriminative or generative embeddings as needed, and its generative embeddings show potential for test-time scaling.

- **Repository:** [UME-R1](https://github.com/DeepLearnXMU/UME-R1)
- **Paper:** [UME-R1](https://arxiv.org/abs/2511.00405)

## Train/Eval Data
- Train data: https://huggingface.co/datasets/zhibinlan/UME-sft-train
- Eval data: https://huggingface.co/datasets/TIGER-Lab/MMEB-V2

## Model Performance
UME-R1 significantly outperforms discriminative embeddings and can provide discriminative or generative representations as needed. Its oracle performance (selecting the better of the discriminative and generative embeddings for each case) far exceeds using either mode alone.

<img src="./figures/main_result.png" alt="MMEB-V2" width="1200" height="auto">

In addition, UME-R1 can produce improved embedding representations through repeated sampling, indicating that generative embeddings also hold strong promise for inference-time scaling (a rough pass@k sketch follows the figure below).

<img src="./figures/scaling.png" alt="pass@k" width="1200" height="auto">
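As a rough illustration of this pass@k-style evaluation (not the official evaluation code), the sketch below samples the same query k times and counts the retrieval as solved if any sample ranks the target first. `encode_generative` is a hypothetical helper assumed to wrap the generative-embedding snippet from the Quick Start section below and to return one L2-normalized query embedding per call, with sampling enabled in `model.generate` (e.g. `do_sample=True`).

```python
import torch

def pass_at_k(encode_generative, query_messages, candidate_reps, target_idx, k=8):
    """Illustrative pass@k: does any of k sampled query embeddings rank the target first?

    encode_generative: hypothetical helper wrapping the Quick Start snippet below
    (returns an L2-normalized 1-D tensor).
    candidate_reps: (num_candidates, hidden_dim) tensor of L2-normalized embeddings.
    """
    for _ in range(k):
        query_rep = encode_generative(query_messages)   # shape: (hidden_dim,)
        scores = candidate_reps @ query_rep             # cosine similarities
        if int(torch.argmax(scores)) == target_idx:
            return True
    return False
```

With k = 1 this is ordinary top-1 retrieval accuracy; increasing k traces out a scaling curve like the one in the figure above.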
### Quick Start

First, clone our GitHub repository and run the setup script:
```bash
git clone https://github.com/DeepLearnXMU/UME-R1
cd UME-R1
bash setup.sh
```

Below, we provide simple examples to show how to use UME-R1 with 🤗 Transformers.

Example of obtaining generative embeddings:

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "zhibinlan/UME-R1-7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)

processor = AutoProcessor.from_pretrained("zhibinlan/UME-R1-7B")

prompt = '''Represent the above input text, images, videos, or any combination of the three as embeddings.
First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence.
Finally, use the <gen_emb> tag to represent the entire input.'''

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n" + prompt},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Inference: generate the reasoning trace and the <gen_emb> token, keeping hidden states
generated_output = model.generate(**inputs, max_new_tokens=8192, output_hidden_states=True, return_dict_in_generate=True, use_cache=True)
# Post-process the output
generated_ids = generated_output.sequences
hidden_states = generated_output.hidden_states

generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]

def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
    # For each sequence, find the generation step whose hidden states correspond to the
    # embedding token (searching backward from the last generated token); -1 if absent.
    embedding_idx = []
    for i, out_ids in enumerate(generated_ids_trimmed):
        embed_exist = False
        for j in range(len(out_ids) - 1, -1, -1):
            if out_ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j + 1)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)

    return embedding_idx

def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

# Get the last-layer hidden state at the <gen_emb> token
embedding_idx = get_embedding_idx(generated_ids_trimmed, processor.tokenizer.get_vocab()["<gen_emb>"])
embedding_reps = hidden_states[embedding_idx[0]][-1].squeeze(1)

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)

output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
```
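Once the normalized embedding is extracted, scoring a query against candidates is a plain dot product, which equals cosine similarity because both sides are L2-normalized. The snippet below is a small standalone illustration; the random tensors and the `hidden_dim` value are placeholders standing in for embeddings produced by the code above.

```python
import torch
import torch.nn.functional as F

# Placeholder embeddings standing in for outputs of the snippet above
# (hidden_dim is illustrative; use the actual embedding size of the model).
hidden_dim = 3584
query_rep = F.normalize(torch.randn(hidden_dim), dim=-1)
candidate_reps = F.normalize(torch.randn(5, hidden_dim), dim=-1)

# For L2-normalized vectors, cosine similarity reduces to a dot product.
scores = candidate_reps @ query_rep          # shape: (5,)
best_candidate = int(torch.argmax(scores))
```

The same scoring applies to the discriminative embeddings extracted in the example below.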

<details>
<summary>Example of obtaining discriminative embeddings</summary>

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

pretrained_path = "zhibinlan/UME-R1-7B"

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    pretrained_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)

# Default processor
processor = AutoProcessor.from_pretrained(pretrained_path)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "UME-R1/assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

def get_embedding_idx(input_ids, EMBEDDING_TOKEN_ID):
    # For each sequence, return the position of the last embedding token,
    # searching backward from the end; -1 if absent.
    embedding_idx = []
    for i, ids in enumerate(input_ids):
        embed_exist = False
        for j in range(len(ids) - 1, -1, -1):
            if ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)

    return embedding_idx

def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

# Single forward pass (no generation) to obtain the hidden states
output = model(**inputs, output_hidden_states=True, return_dict=True)
hidden_states = output.hidden_states[-1][0]  # last layer, first sequence in the batch
embedding_idx = get_embedding_idx(inputs['input_ids'], processor.tokenizer.get_vocab()["<disc_emb>"])

# Get the last hidden state of the <disc_emb> token
embedding_reps = hidden_states[embedding_idx[0]]

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)
```

</details>

<details>
<summary>Multi-image inference</summary>

```python
# Messages containing multiple images and a text query
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/image1.jpg"},
            {"type": "image", "image": "file:///path/to/image2.jpg"},
            {"type": "text", "text": "Represent the given images."},
        ],
    }
]
```

</details>

<details>
<summary>Video inference</summary>

```python
# Messages containing an image list as a video and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": [
                    "file:///path/to/frame1.jpg",
                    "file:///path/to/frame2.jpg",
                    "file:///path/to/frame3.jpg",
                    "file:///path/to/frame4.jpg",
                ],
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

# Messages containing a local video path and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "file:///path/to/video1.mp4",
                "max_pixels": 360 * 420,
                "fps": 1.0,
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

# Messages containing a video url and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "https://path/to/video.mp4",
                "min_pixels": 4 * 28 * 28,
                "max_pixels": 256 * 28 * 28,
                "total_pixels": 20480 * 28 * 28,
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

# Preparation for inference (video_kwargs carries the sampled frame-rate information)
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
    **video_kwargs,
)
```

</details>
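UME-R1 can also embed plain text (for example, a caption used as a retrieval query). The fragment below only sketches how such a query might be formed by reusing the `<disc_emb>` pattern from the examples above; the instruction wording is hypothetical and may not match the templates used during training, so refer to the GitHub repository for the official prompts.

```python
# Hypothetical text-only query reusing the <disc_emb> pattern; the instruction text is
# illustrative and may differ from the official prompt templates.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Find an image that matches the given caption: A cat sitting on a windowsill.\n<disc_emb>\n"},
        ],
    }
]
# From here, proceed as in the discriminative example: apply the chat template, run the
# processor without images or videos, and take the hidden state at the <disc_emb> position.
```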

For more usage tips, please refer to our [GitHub page](https://github.com/DeepLearnXMU/UME-R1).

<!-- ## Citation
If you find our work helpful, feel free to give us a cite. -->