zhibinlan committed on
Commit 9250899 · 1 Parent(s): c8305ac
Files changed (4)
  1. .gitattributes +2 -0
  2. README.md +320 -0
  3. figures/main_result.png +3 -0
  4. figures/scaling.png +3 -0
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 tokenizer.json filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
---
language:
- en
library_name: transformers
license: apache-2.0
pipeline_tag: image-text-to-text
tags:
- Sentence Similarity
- Embedding
- zero-shot-image-classification
- video-text-to-text
---

# UME-R1-7B

## Model Summary

UME-R1-7B is trained in two stages, a cold-start SFT stage followed by an RL stage, and can embed text, single images, multiple images, and videos. In particular, UME-R1 can produce either discriminative or generative embeddings as needed, and its generative embeddings show potential for test-time scaling.

- **Repository:** [UME-R1](https://github.com/DeepLearnXMU/UME-R1)
- **Paper:** [UME-R1]()

## Train/Eval Data
- Train data: https://huggingface.co/datasets/zhibinlan/UME-sft-train
- Eval data: https://huggingface.co/datasets/TIGER-Lab/MMEB-V2
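
The training data can be inspected with the 🤗 Datasets library. A minimal sketch (the `train` split name is an assumption; check the dataset card):

```python
from datasets import load_dataset

# Load the UME-R1 SFT training data from the Hub
# (the "train" split name is an assumption; see the dataset card).
ds = load_dataset("zhibinlan/UME-sft-train", split="train")
print(ds)      # number of examples and column names
print(ds[0])   # one training example
```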

## Model Performance
UME-R1 significantly outperforms discriminative embeddings and can provide discriminative or generative representations as needed. Its oracle performance, i.e., selecting the better of the discriminative and generative embeddings for each query, far exceeds using either mode alone.

<img src="./figures/main_result.png" alt="MMEB-V2" width="1200" height="auto">
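
To make the oracle notion concrete, here is a toy sketch with hypothetical retrieval scores (not the repository's evaluation code): for each query, the oracle counts a hit if either embedding mode ranks the ground-truth candidate first.

```python
import torch

# Hypothetical similarity scores of one query against 4 candidates,
# where index 0 is the ground-truth candidate (illustration only).
disc_scores = torch.tensor([0.62, 0.70, 0.55, 0.48])  # discriminative embedding
gen_scores = torch.tensor([0.81, 0.66, 0.59, 0.50])   # generative embedding

disc_hit = disc_scores.argmax().item() == 0   # False: a distractor is ranked first
gen_hit = gen_scores.argmax().item() == 0     # True: the ground truth is ranked first
oracle_hit = disc_hit or gen_hit              # the oracle takes whichever mode succeeds
print(disc_hit, gen_hit, oracle_hit)
```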

In addition, UME-R1 can produce improved embedding representations through repeated sampling, indicating that generative embeddings also hold strong promise for inference-time scaling.

<img src="./figures/scaling.png" alt="pass@k" width="1200" height="auto">
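
As a rough illustration of the pass@k curve above (hypothetical tensors and helper, not the repository's evaluation code): the query is embedded k times with sampling, and it counts as solved if any of the k generative embeddings ranks the correct candidate first.

```python
import torch
import torch.nn.functional as F

def pass_at_k(query_embs: torch.Tensor, cand_embs: torch.Tensor, gold: int) -> bool:
    """query_embs: (k, d) embeddings from k sampled generations; cand_embs: (n, d)
    candidate embeddings; gold: index of the correct candidate. All L2-normalized."""
    sims = query_embs @ cand_embs.T      # (k, n) cosine similarities
    best = sims.argmax(dim=-1)           # top-ranked candidate for each sample
    return bool((best == gold).any())    # hit if any sample retrieves the gold candidate

# Toy example with random, normalized embeddings (illustration only).
k, n, d = 4, 8, 16
query_embs = F.normalize(torch.randn(k, d), dim=-1)
cand_embs = F.normalize(torch.randn(n, d), dim=-1)
print(pass_at_k(query_embs, cand_embs, gold=0))
```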

## Quick Start

First, clone our GitHub repository and run the setup script:
```bash
git clone https://github.com/DeepLearnXMU/UME-R1
cd UME-R1
bash setup.sh
```

Below, we provide simple examples to show how to use UME-R1 with 🤗 Transformers.

Example of obtaining generative embeddings:

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "zhibinlan/UME-R1-7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)

processor = AutoProcessor.from_pretrained("zhibinlan/UME-R1-7B")

prompt = '''Represent the above input text, images, videos, or any combination of the three as embeddings.
First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence.
Finally, use the <gen_emb> tag to represent the entire input.'''

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n" + prompt},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Inference: generate the reasoning followed by the <gen_emb> token
generated_output = model.generate(**inputs, max_new_tokens=8192, output_hidden_states=True, return_dict_in_generate=True, use_cache=True)
# Post-process the output
generated_ids = generated_output.sequences
hidden_states = generated_output.hidden_states

generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]

def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
    # Search each generated sequence backwards for the last <gen_emb> token
    embedding_idx = []
    for i, out_ids in enumerate(generated_ids_trimmed):
        embed_exist = False
        for j in range(len(out_ids) - 1, -1, -1):
            if out_ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j + 1)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)

    return embedding_idx

def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

# Get the last hidden state of the <gen_emb> token
embedding_idx = get_embedding_idx(generated_ids_trimmed, processor.tokenizer.get_vocab()["<gen_emb>"])
embedding_reps = hidden_states[embedding_idx[0]][-1].squeeze(1)

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)

output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
```
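
Embeddings obtained this way can then be compared by cosine similarity for retrieval or ranking. A minimal sketch with placeholder tensors standing in for a query embedding and candidate embeddings extracted as above (the dimension is arbitrary):

```python
import torch

# Placeholder tensors: in practice these are <gen_emb> hidden states extracted as above.
query_reps = torch.nn.functional.normalize(torch.randn(1, 4096), p=2, dim=-1)
cand_reps = torch.nn.functional.normalize(torch.randn(5, 4096), p=2, dim=-1)

scores = query_reps @ cand_reps.T    # (1, 5) cosine similarities (embeddings are normalized)
best = scores.argmax(dim=-1)         # index of the best-matching candidate
print(scores, best)
```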

<details>
<summary>Example of obtaining discriminative embeddings</summary>

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

pretrained_path = "zhibinlan/UME-R1-7B"

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    pretrained_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)

# default processor
processor = AutoProcessor.from_pretrained(pretrained_path)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "UME-R1/assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

def get_embedding_idx(input_ids, EMBEDDING_TOKEN_ID):
    # Search backwards from the last token for the <disc_emb> token
    embedding_idx = []
    for i, ids in enumerate(input_ids):
        embed_exist = False
        for j in range(len(ids) - 1, -1, -1):
            if ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)

    return embedding_idx

def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

output = model(**inputs, output_hidden_states=True, return_dict=True)
hidden_states = output.hidden_states[-1][0]
embedding_idx = get_embedding_idx(inputs['input_ids'], processor.tokenizer.get_vocab()["<disc_emb>"])

# Get the last hidden state of the <disc_emb> token
embedding_reps = hidden_states[embedding_idx[0]]

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)
```

</details>

<details>
<summary>Multi-image inference</summary>

```python
# Messages containing multiple images and a text query
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/image1.jpg"},
            {"type": "image", "image": "file:///path/to/image2.jpg"},
            {"type": "text", "text": "Represent the given images."},
        ],
    }
]
```

</details>
245
+
246
+ <details>
247
+ <summary>Video inference</summary>
248
+
249
+ ```python
250
+ # Messages containing a images list as a video and a text query
251
+ messages = [
252
+ {
253
+ "role": "user",
254
+ "content": [
255
+ {
256
+ "type": "video",
257
+ "video": [
258
+ "file:///path/to/frame1.jpg",
259
+ "file:///path/to/frame2.jpg",
260
+ "file:///path/to/frame3.jpg",
261
+ "file:///path/to/frame4.jpg",
262
+ ],
263
+ },
264
+ {"type": "text", "text": "Represent this video."},
265
+ ],
266
+ }
267
+ ]
268
+
269
+ # Messages containing a local video path and a text query
270
+ messages = [
271
+ {
272
+ "role": "user",
273
+ "content": [
274
+ {
275
+ "type": "video",
276
+ "video": "file:///path/to/video1.mp4",
277
+ "max_pixels": 360 * 420,
278
+ "fps": 1.0,
279
+ },
280
+ {"type": "text", "text": "Represent this video."},
281
+ ],
282
+ }
283
+ ]
284
+
285
+ # Messages containing a video url and a text query
286
+ messages = [
287
+ {
288
+ "role": "user",
289
+ "content": [
290
+ {
291
+ "type": "video",
292
+ "video": "https://path/to/video.mp4",
293
+ "min_pixels": 4 * 28 * 28,
294
+ "max_pixels": 256 * 28 * 28,
295
+ "total_pixels": 20480 * 28 * 28,
296
+ },
297
+ {"type": "text", "text": "Represent this video."},
298
+ ],
299
+ }
300
+ ]
301
+ image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
302
+ inputs = processor(
303
+ text=[text],
304
+ images=image_inputs,
305
+ videos=video_inputs,
306
+ fps=fps,
307
+ padding=True,
308
+ return_tensors="pt",
309
+ **video_kwargs,
310
+ )
311
+ ```
312
+
313
+ </details>
314
+
315
+
316
+ For more usage tips, please refer to our [Github page](https://github.com/DeepLearnXMU/UME-R1).
317
+
318
+
319
+ ## Citation
320
+ If you find our work helpful, feel free to give us a cite.
321
+
322
+ ```
323
+ ```
figures/main_result.png ADDED

Git LFS Details

  • SHA256: 6e1f7225c42487f587fec7fd03db9e1818a6193db83ab196bf49fa8ab8da17af
  • Pointer size: 131 Bytes
  • Size of remote file: 271 kB
figures/scaling.png ADDED

Git LFS Details

  • SHA256: a1ebb391fd4aa2e9c44718b464c2a80051de8fdbd34c0fdab159bc390add8133
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB