shwu committed
Commit
d6e13c5
1 Parent(s): 51e5ad2

feat: new generation code

Files changed (2):
  1. modeling_blip2chatglm.py +90 -82
  2. modeling_chatglm.py +2 -4
modeling_blip2chatglm.py CHANGED
@@ -189,9 +189,10 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def prepare_inputs_for_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        queries: List[Union[str, Tuple[str, torch.Tensor]]],
-        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
         max_length: int,
+        user_role: str = "问",
+        bot_role: str = "答",
     ):
         device = self.device
         nvtokens = self.config.num_query_tokens
@@ -199,80 +200,76 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
         all_images = []
         all_image_slots = []
         all_input_ids = []
-        for query, history in zip(queries, histories):
+        for messages in batch_messages:
+            images = []
             image_slots = []
-
-            if history:
-                input_ids = tokenizer(
-                    f"[Round {len(history)}]\n问:", add_special_tokens=False
-                ).input_ids
-                slot_offset = len(input_ids)
-                if isinstance(query, tuple):
-                    qtext, qimg = query
-                    # image slot, embedding will be replaced by image embeddings
-                    input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-                else:
-                    qtext = query
-                    qimg = None
-                input_ids += tokenizer(qtext + f"\n答:").input_ids
-                if qimg is not None:
-                    all_images.append(qimg)
-                    image_slots.append(
-                        len(input_ids) - slot_offset
-                    )  # count from backward
-
-                for ri, (q, r) in enumerate(reversed(history)):
-                    if len(input_ids) >= max_length:
-                        break
-                    i = len(history) - ri - 1
-                    cur_input_ids: List[int] = tokenizer(
-                        f"[Round {i}]\n问:", add_special_tokens=False
-                    ).input_ids
-                    slot_offset = len(cur_input_ids)
-                    if isinstance(q, tuple):
-                        qtext, qimg = q
-                        # image slot, embedding will be replaced by image embeddings
-                        cur_input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-                    else:
-                        qtext = q
-                        qimg = None
-                    cur_input_ids += tokenizer(
-                        qtext + f"\n答:{r}\n", add_special_tokens=False
-                    ).input_ids
-                    input_ids = cur_input_ids + input_ids
-                    if qimg is not None:
-                        all_images.append(qimg)
-                        image_slots.append(
-                            len(input_ids) - slot_offset
-                        )  # count from backward
-            else:
-                input_ids = []
-                if isinstance(query, tuple):
-                    qtext, qimg = query
-                    # image slot, embedding will be replaced by image embeddings
-                    input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-                else:
-                    qtext = query
-                    qimg = None
-                input_ids += tokenizer(qtext).input_ids
-                if qimg is not None:
-                    all_images.append(qimg)
-                    image_slots.append(len(input_ids))  # count from backward
+            input_ids = []
+
+            round_roles = [set()]
+            for role, qtext, qimgs in messages:
+                if role in round_roles[-1]:
+                    # a new round (not the first round)
+                    input_ids += tokenizer(
+                        f"\n[Round {len(round_roles)}]\n{role}:",
+                        add_special_tokens=False,
+                    ).input_ids
+                    round_roles.append({role})
+                else:
+                    round_roles[-1].add(role)
+                    input_ids += tokenizer(
+                        # For first role, no new line
+                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:", add_special_tokens=False
+                    ).input_ids
+                cur_index = 0
+                for qimg, img_idx in qimgs:
+                    if img_idx > cur_index:
+                        input_ids += tokenizer(
+                            qtext[cur_index:img_idx], add_special_tokens=False
+                        ).input_ids
+                        cur_index = img_idx
+                    # image slot, embedding will be replaced by image embeddings
+                    image_slots.append(len(input_ids))
+                    input_ids += [tokenizer.unk_token_id] * nvtokens
+                    images.append(qimg)
+                input_ids += tokenizer(
+                    qtext[cur_index:], add_special_tokens=False
+                ).input_ids
+            if len(round_roles) == 1:
+                # only 1 round
+                if len(round_roles[0]) == 1 and user_role in round_roles[0]:
+                    # only user role
+                    input_ids += tokenizer("").input_ids
+                else:
+                    input_ids += tokenizer(f"\n{bot_role}:").input_ids
+            else:
+                # add tag for round 0
+                input_ids = (
+                    tokenizer(f"[Round 0]\n", add_special_tokens=False).input_ids
+                    + input_ids
+                )
+                input_ids += tokenizer(f"\n{bot_role}:").input_ids

             if len(input_ids) >= max_length:
-                # truncate
-                if (
-                    image_slots[-1] > max_length
-                    and image_slots[-1] - nvtokens < max_length
-                ):
-                    # A non-intact image slot is not allowed
-                    input_ids = input_ids[-(image_slots[-1] - nvtokens) :]
-                else:
-                    input_ids = input_ids[-max_length:]
-                    if image_slots[-1] > max_length:
-                        image_slots.pop()
-                        all_images.pop()
-
+                image_slots_after_truncate = []
+                images_after_truncate = []
+                truncate_index = len(input_ids) - max_length
+                for image_slot, image in zip(image_slots, images):
+                    # truncate from left
+                    if len(input_ids) - image_slot < max_length:
+                        image_slots_after_truncate.append(image_slot)
+                        images_after_truncate.append(image)
+                    elif len(input_ids) - (image_slot + nvtokens) < max_length:
+                        # in-contact image slot is not allowed
+                        truncate_index = max(truncate_index, image_slot + nvtokens)
+                for i, image_slot in enumerate(image_slots_after_truncate):
+                    image_slots_after_truncate[i] = image_slot - truncate_index
+                input_ids = input_ids[truncate_index:]
+                image_slots = image_slots_after_truncate
+                images = images_after_truncate
+
+            # print(tokenizer.convert_ids_to_tokens(input_ids))
+
+            all_images.extend(images)
             all_image_slots.append(image_slots)
             all_input_ids.append(input_ids)

@@ -316,9 +313,12 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
             input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
         input_ids = input_ids.to(device)
         inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
-        for i, (image_slots, vtokens) in enumerate(zip(all_image_slots, all_vtokens)):
-            for slot, vimg in zip(image_slots, vtokens):
-                inputs_embeds[i][-slot : -slot + nvtokens, :] = vimg
+        if all_vtokens is not None:
+            for i, (image_slots, vtokens) in enumerate(
+                zip(all_image_slots, all_vtokens)
+            ):
+                for slot, vimg in zip(image_slots, vtokens):
+                    inputs_embeds[i][slot : slot + nvtokens, :] = vimg

         return input_ids, inputs_embeds

@@ -326,22 +326,25 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def batch_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        queries: List[Union[str, Tuple[str, torch.Tensor]]],
-        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
         max_length: int = 2048,
         num_beams=1,
         do_sample=True,
         top_p=0.7,
         temperature=0.95,
-        logits_processor=None,
+        user_role: str = "问",
+        bot_role: str = "答",
         **kwargs,
     ):
         input_ids, inputs_embeds = self.prepare_inputs_for_chat(
-            tokenizer, queries, histories, max_length
+            tokenizer=tokenizer,
+            batch_messages=batch_messages,
+            max_length=max_length,
+            user_role=user_role,
+            bot_role=bot_role,
         )

-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
+        logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {
             "max_length": max_length,
@@ -367,17 +370,22 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def stream_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        query: Union[str, Tuple[str, torch.Tensor]],
-        history: List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]],
+        messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
         num_beams=5,
-        max_length=128,
+        max_length=512,
         top_p=0.9,
         do_sample=True,
         temperature=1,
+        user_role: str = "问",
+        bot_role: str = "答",
         **kwargs,
     ):
         input_ids, inputs_embeds = self.prepare_inputs_for_chat(
-            tokenizer, [query], [history], max_length
+            tokenizer=tokenizer,
+            batch_messages=[messages],
+            max_length=max_length,
+            user_role=user_role,
+            bot_role=bot_role,
         )

         logits_processor = LogitsProcessorList()
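
For reference, a minimal usage sketch of the new chat interface (an illustration inferred from the signatures above, not code shipped in this commit): each message is a (role, text, images) tuple, and each image is a (pixel_values, char_index) pair marking where in the text its query-token slot is spliced in. The names `model`, `tokenizer`, and `pixel_values` below, and the image tensor shape, are assumptions.

# Hypothetical usage sketch for the new messages format; not part of this commit.
import torch

# Stand-in for a preprocessed image, e.g. a BLIP image processor's pixel_values (assumed shape).
pixel_values = torch.randn(3, 224, 224)

messages = [
    ("问", "你好", []),                              # earlier user turn ("hello"), text only
    ("答", "你好,请问有什么可以帮您?", []),           # earlier bot turn ("hi, how can I help?")
    ("问", "这张图里有什么?", [(pixel_values, 0)]),    # "what's in this image?", image slot at char offset 0
]

# Single conversation, streaming (new stream_chat signature):
#   for response in model.stream_chat(tokenizer, messages, max_length=512):
#       ...
# Several conversations at once (new batch_chat signature):
#   responses = model.batch_chat(tokenizer, [messages], max_length=2048)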
modeling_chatglm.py CHANGED
@@ -970,6 +970,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)

         for i, layer in enumerate(self.layers):

@@ -1095,10 +1097,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 [position_ids, new_position_id], dim=-1
             )

-        # set to None as prepare_inputs_for_generation use past for input embeds
-        if "inputs_embeds" in model_kwargs:
-            model_kwargs["inputs_embeds"] = None
-
         return model_kwargs

     def prepare_inputs_for_generation(
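
A small sketch of what the added else-branch in ChatGLMModel guards against (an assumption about the intended setup, such as a model whose layers are placed on different devices, rather than anything stated in this commit): a caller-supplied attention_mask can live on a different device than the hidden states a layer consumes, so it is moved explicitly.

# Hypothetical illustration; everything stays on CPU here so the snippet runs anywhere.
import torch

hidden_states = torch.zeros(4, 1, 8)                      # activations on the layer's device
attention_mask = torch.zeros(1, 1).bool()                 # mask built on the input device
attention_mask = attention_mask.to(hidden_states.device)  # mirrors the added else-branch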