hezhihui committed
Commit
6d7ce17
1 Parent(s): b352d20

multi-images

Files changed (1):
  1. modeling_minicpmv.py +26 -5
modeling_minicpmv.py CHANGED
@@ -3,6 +3,7 @@ import json
 import torch
 from threading import Thread
 from copy import deepcopy
+from PIL import Image
 from torchvision import transforms
 from transformers import LlamaPreTrainedModel, LlamaForCausalLM, TextIteratorStreamer
 from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
@@ -291,17 +292,37 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         msgs = json.loads(msgs)
         copy_msgs = deepcopy(msgs)
 
-        assert len(msgs) > 0, 'msgs is empty'
-        assert sampling or not stream, 'if use stream mode, make sure sampling=True'
+        assert len(msgs) > 0, "msgs is empty"
+        assert sampling or not stream, "if use stream mode, make sure sampling=True"
+
+        if image is not None and isinstance(copy_msgs[0]["content"], str):
+            # copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
+            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+        images = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    images.append(c)
+                    cur_msgs.append("(<image>./</image>)")
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+            msg["content"] = "\n".join(cur_msgs)
 
-        if image is not None and isinstance(msgs[0]['content'], str):
-            copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
         if system_prompt:
             sys_msg = {'role': 'system', 'content': system_prompt}
             copy_msgs = [sys_msg] + copy_msgs
 
         prompt = processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
-        inputs = processor(prompt, [image], return_tensors="pt", max_length=max_inp_length).to(self.device)
+        inputs = processor(prompt, images, return_tensors="pt", max_length=max_inp_length).to(self.device)
 
         if sampling:
             generation_config = {
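
In effect, a message's content field may now be a list mixing PIL.Image.Image objects and strings: each image is replaced by a (<image>./</image>) placeholder in the prompt text and collected into the images list handed to the processor, so one conversation can carry several images. Below is a minimal usage sketch; the checkpoint name and the exact chat() keyword arguments are assumptions drawn from the MiniCPM-V model card, not from this diff.

```python
# Minimal multi-image chat sketch. Assumed (not from this commit): the
# openbmb/MiniCPM-Llama3-V-2_5 checkpoint and the chat() signature from the
# model card; only the list-style 'content' handling comes from this diff.
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)

img_a = Image.open('page1.jpg').convert('RGB')   # hypothetical example files
img_b = Image.open('page2.jpg').convert('RGB')

# 'content' is a list mixing PIL images and strings; each image becomes a
# (<image>./</image>) placeholder and is passed along to the processor.
msgs = [{'role': 'user', 'content': [img_a, img_b, 'What differs between these two pages?']}]

answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
print(answer)
```

Passing image=None while embedding the images in content exercises the new path; the old single-image call (image=img with a plain string content) still works through the fallback branch at the top of the hunk, which rewrites it into the list form.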