hezhihui committed · Commit 6d7ce17 · 1 Parent(s): b352d20

multi-images

modeling_minicpmv.py +26 -5
modeling_minicpmv.py CHANGED

@@ -3,6 +3,7 @@ import json
 import torch
 from threading import Thread
 from copy import deepcopy
+from PIL import Image
 from torchvision import transforms
 from transformers import LlamaPreTrainedModel, LlamaForCausalLM, TextIteratorStreamer
 from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer

@@ -291,17 +292,37 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             msgs = json.loads(msgs)
         copy_msgs = deepcopy(msgs)

-        assert len(msgs) > 0,
-        assert sampling or not stream,
+        assert len(msgs) > 0, "msgs is empty"
+        assert sampling or not stream, "if use stream mode, make sure sampling=True"
+
+        if image is not None and isinstance(copy_msgs[0]["content"], str):
+            # copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
+            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+        images = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    images.append(c)
+                    cur_msgs.append("(<image>./</image>)")
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+            msg["content"] = "\n".join(cur_msgs)

-        if image is not None and isinstance(msgs[0]['content'], str):
-            copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
         if system_prompt:
             sys_msg = {'role': 'system', 'content': system_prompt}
             copy_msgs = [sys_msg] + copy_msgs

         prompt = processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
-        inputs = processor(prompt,
+        inputs = processor(prompt, images, return_tensors="pt", max_length=max_inp_length).to(self.device)

         if sampling:
             generation_config = {
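For reference, a minimal usage sketch of the multi-image message format this commit introduces. The model/tokenizer/processor setup and the image paths are illustrative assumptions, not part of the commit; only the shape of msgs and the collection into "(<image>./</image>)" placeholders come from the diff above.

from PIL import Image

# Hypothetical setup: model, tokenizer, and processor loaded from a
# MiniCPM-V checkpoint; these names are not defined by this commit.
img1 = Image.open("photo_a.jpg").convert("RGB")
img2 = Image.open("photo_b.jpg").convert("RGB")

# After this commit, a message's "content" may be a list mixing PIL images
# and strings; chat() replaces each image with a "(<image>./</image>)"
# placeholder in the prompt and passes the collected images to the processor.
msgs = [{"role": "user", "content": [img1, img2, "Compare these two photos."]}]

res = model.chat(
    image=None,   # legacy single-image argument; see compatibility note below
    msgs=msgs,
    tokenizer=tokenizer,
    processor=processor,
)

Old-style calls keep working: a bare string content is wrapped in a one-element list, and a non-None image argument is prepended to the first message's content, so it flows through the same placeholder-collection loop as the new list form.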