shwu committed on
Commit 51e5ad2
Parent: 35e0244

feat: better modeling_chatglm

Files changed (4):
  1. .gitattributes +34 -34
  2. README.md +61 -61
  3. modeling_blip2chatglm.py +181 -345
  4. modeling_chatglm.py +22 -20
.gitattributes CHANGED
@@ -1,34 +1,34 @@
(The 34 removed lines are identical to the 34 lines re-added below; presumably a whitespace/line-ending-only change.)
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,61 +1,61 @@
(The 61 removed lines are identical to the 61 lines re-added below; presumably a whitespace/line-ending-only change.)
+ ---
+ language:
+ - zh
+ - en
+ tags:
+ - chatglm
+ - blip2
+ ---
+
+ # Model Card for blip2zh-chatglm-6b
+
+ ## Model Details
+
+ ### Model Description
+
+ blip2zh-chatglm-6b is a Chinese multimodal chat model trained with BLIP-2, with basic image-understanding ability.
+ Because BLIP-2 training does not fine-tune the language model, its behavior in text-only conversations stays consistent with the original ChatGLM.
+
+ Note: the model has only gone through the two-stage BLIP-2 image-text alignment pre-training; it has not been trained on concrete downstream tasks such as VQA or instruction tuning, so it can still produce unexpected content.
+
+ - **blip2 base model**: [bert-base-chinese](https://huggingface.co/bert-base-chinese)
+ - **Vision encoder**: eva-clip-vit-g
+ - **Language model**: [chatglm-6b](https://github.com/THUDM/ChatGLM-6B) at [commit](https://huggingface.co/THUDM/chatglm-6b/commit/9324de70a93207c9a310cf99d5d6261791489691)
+
+ ### Model Sources
+
+ - [**Training Code**](https://github.com/XiPotatonium/LAVIS): BLIP-2 training code, based on [LAVIS](https://github.com/salesforce/LAVIS)
+ - [**webui**](https://github.com/XiPotatonium/chatbot-webui): a web UI built with Gradio
+ - [**api**](https://github.com/XiPotatonium/chatbot-api): an API service built with FastAPI that can be deployed locally and also supports several other locally deployable language models.
+
+ ## Uses
+
+ The model weights include the image encoder, BLIP-2, and chatglm-6b.
+
+ For loading the model and running inference, see the [api](https://github.com/XiPotatonium/chatbot-api/blob/main/src/model/blip2chatglm/__init__.py) implementation.
+
+ Some [examples](https://github.com/XiPotatonium/chatbot-api/blob/main/examples.ipynb)
+
+ ## Limitations
+
+ Constrained by the available Chinese datasets, image understanding is still limited and the model may produce irrelevant or incorrect content.
+ Multi-turn dialogue training and instruction tuning have not been introduced yet, so multi-turn conversations may be disturbed by the context.
+ The model is also limited by the conversational quality of chatglm-6b itself.
+
+ ## Training Details
+
+ ### Training Data
+
+ * [laion-2b-chinese](https://huggingface.co/datasets/IDEA-CCNL/laion2B-multi-chinese-subset): we selected only the 670k image-text pairs with higher CLIP scores and sampled part of them for training.
+ * [coco-zh](https://github.com/li-xirong/coco-cn)
+ * [flickr8k-zh](http://lixirong.net/datasets/flickr8kcn)
+
+ ### Training Procedure
+
+ The two-stage training procedure of BLIP-2.
+
+ ## Demos
+
+ ![](imgs/demo1.png)
+ ![](imgs/demo2.png)
+ ![](imgs/demo3.png)
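The Uses section above defers to the chatbot-api repository for loading and inference. As a quick, non-authoritative orientation, the sketch below shows what loading this checkpoint and streaming a reply might look like; the repository id, the `trust_remote_code` auto-class wiring, and the image preprocessing are assumptions rather than facts from this commit, and the linked api repository remains the reference implementation.

```python
# Hedged sketch only: repo id, auto-class wiring and image preprocessing are
# assumptions, not taken from this commit.
import torch
from transformers import AutoModel, AutoTokenizer

repo = "path/to/blip2zh-chatglm-6b"  # placeholder: local path or hub id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).half().cuda().eval()
# If the auto-class mapping is not wired up, import
# Blip2ChatGLMForConditionalGeneration from modeling_blip2chatglm.py instead.

# Text-only turn: behaves like plain chatglm-6b.
history = []
response = None
for response in model.stream_chat(tokenizer, "你好", history):
    pass  # each iteration yields the reply generated so far
print(response)

# Image turn (sketch): the query becomes a (text, pixel_values) tuple, where
# pixel_values is a float image tensor produced by the BLIP-2/EVA-CLIP
# preprocessor used at training time (assumption), e.g. shape (1, 3, H, W):
# for response in model.stream_chat(tokenizer, ("描述这张图片", pixel_values), history):
#     pass
```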
modeling_blip2chatglm.py CHANGED
@@ -1,13 +1,16 @@
 import copy
 import os
 from typing import Callable, List, Optional, Tuple, Union
+import numpy as np
 import torch
 from torch.nn import CrossEntropyLoss
+from torch.nn.utils.rnn import pad_sequence
 import warnings
 from torch import Tensor, nn

 from transformers import (
     PreTrainedModel,
+    PreTrainedTokenizer,
     Blip2VisionModel,
     Blip2QFormerModel,
     Blip2Model,
@@ -137,12 +140,14 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
         if image_slot_offset is None:
             # image as prefix
             # update data to avoid inplace operation of leaf Variable
-            inputs_embeds.data[:, : self.config.num_query_tokens, :] = language_model_inputs
+            inputs_embeds.data[
+                :, : self.config.num_query_tokens, :
+            ] = language_model_inputs
         else:
             for i, offset in enumerate(image_slot_offset):
-                inputs_embeds.data[i, offset : offset + self.config.num_query_tokens, :] = (
-                    language_model_inputs[i]
-                )
+                inputs_embeds.data[
+                    i, offset : offset + self.config.num_query_tokens, :
+                ] = language_model_inputs[i]

         outputs = self.language_model(
             input_ids=input_ids,
@@ -181,118 +186,162 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
             language_model_outputs=outputs,
         )

-    @torch.no_grad()
-    def stream_chat(
+    def prepare_inputs_for_chat(
         self,
-        tokenizer,
-        query: Union[str, Tuple[str, torch.Tensor]],
-        history: List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]] = [],
-        num_beams=5,
-        max_length=128,
-        top_p=0.9,
-        do_sample=True,
-        temperature=1,
+        tokenizer: PreTrainedTokenizer,
+        queries: List[Union[str, Tuple[str, torch.Tensor]]],
+        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        max_length: int,
     ):
         device = self.device
-        # 1. Prepare token ids
-        images = []
-        image_slots = []
-
         nvtokens = self.config.num_query_tokens
-        if history:
-            input_ids = tokenizer(
-                f"[Round {len(history)}]\n问:", add_special_tokens=False
-            ).input_ids
-            slot_offset = len(input_ids)
-            if isinstance(query, tuple):
-                qtext, qimg = query
-                # image slot, embedding will be replaced by image embeddings
-                input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-            else:
-                qtext = query
-                qimg = None
-            input_ids += tokenizer(qtext + f"\n答:").input_ids
-            if qimg is not None:
-                images.append(qimg)
-                image_slots.append(len(input_ids) - slot_offset)  # count from backward
-
-            for ri, (q, r) in enumerate(reversed(history)):
-                if len(input_ids) >= max_length:
-                    break
-                i = len(history) - ri - 1
-                cur_input_ids: List[int] = tokenizer(
-                    f"[Round {i}]\n问:", add_special_tokens=False
+        # 1. Prepare token ids
+        all_images = []
+        all_image_slots = []
+        all_input_ids = []
+        for query, history in zip(queries, histories):
+            image_slots = []
+
+            if history:
+                input_ids = tokenizer(
+                    f"[Round {len(history)}]\n问:", add_special_tokens=False
                 ).input_ids
-                slot_offset = len(cur_input_ids)
-                if isinstance(q, tuple):
-                    qtext, qimg = q
+                slot_offset = len(input_ids)
+                if isinstance(query, tuple):
+                    qtext, qimg = query
                     # image slot, embedding will be replaced by image embeddings
-                    cur_input_ids.extend([tokenizer.unk_token_id] * nvtokens)
+                    input_ids.extend([tokenizer.unk_token_id] * nvtokens)
                 else:
-                    qtext = q
+                    qtext = query
                     qimg = None
-                cur_input_ids += tokenizer(
-                    qtext + f"\n答:{r}\n", add_special_tokens=False
-                ).input_ids
-                input_ids = cur_input_ids + input_ids
+                input_ids += tokenizer(qtext + f"\n答:").input_ids
                 if qimg is not None:
-                    images.append(qimg)
+                    all_images.append(qimg)
                     image_slots.append(
                         len(input_ids) - slot_offset
                     )  # count from backward
-        else:
-            input_ids = []
-            if isinstance(query, tuple):
-                qtext, qimg = query
-                # image slot, embedding will be replaced by image embeddings
-                input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-            else:
-                qtext = query
-                qimg = None
-            input_ids += tokenizer(qtext).input_ids
-            if qimg is not None:
-                images.append(qimg)
-                image_slots.append(len(input_ids))  # count from backward
-
-        if len(input_ids) >= max_length:
-            # truncate
-            if image_slots[-1] > max_length and image_slots[-1] - nvtokens < max_length:
-                # A non-intact image slot is not allowed
-                input_ids = input_ids[-(image_slots[-1] - nvtokens) :]
+
+                for ri, (q, r) in enumerate(reversed(history)):
+                    if len(input_ids) >= max_length:
+                        break
+                    i = len(history) - ri - 1
+                    cur_input_ids: List[int] = tokenizer(
+                        f"[Round {i}]\n问:", add_special_tokens=False
+                    ).input_ids
+                    slot_offset = len(cur_input_ids)
+                    if isinstance(q, tuple):
+                        qtext, qimg = q
+                        # image slot, embedding will be replaced by image embeddings
+                        cur_input_ids.extend([tokenizer.unk_token_id] * nvtokens)
+                    else:
+                        qtext = q
+                        qimg = None
+                    cur_input_ids += tokenizer(
+                        qtext + f"\n答:{r}\n", add_special_tokens=False
+                    ).input_ids
+                    input_ids = cur_input_ids + input_ids
+                    if qimg is not None:
+                        all_images.append(qimg)
+                        image_slots.append(
+                            len(input_ids) - slot_offset
+                        )  # count from backward
             else:
-                input_ids = input_ids[-max_length:]
-                if image_slots[-1] > max_length:
-                    image_slots.pop()
-                    images.pop()
+                input_ids = []
+                if isinstance(query, tuple):
+                    qtext, qimg = query
+                    # image slot, embedding will be replaced by image embeddings
+                    input_ids.extend([tokenizer.unk_token_id] * nvtokens)
+                else:
+                    qtext = query
+                    qimg = None
+                input_ids += tokenizer(qtext).input_ids
+                if qimg is not None:
+                    all_images.append(qimg)
+                    image_slots.append(len(input_ids))  # count from backward
+
+            if len(input_ids) >= max_length:
+                # truncate
+                if (
+                    image_slots[-1] > max_length
+                    and image_slots[-1] - nvtokens < max_length
+                ):
+                    # A non-intact image slot is not allowed
+                    input_ids = input_ids[-(image_slots[-1] - nvtokens) :]
+                else:
+                    input_ids = input_ids[-max_length:]
+                    if image_slots[-1] > max_length:
+                        image_slots.pop()
+                        all_images.pop()
+
+            all_image_slots.append(image_slots)
+            all_input_ids.append(input_ids)

         # 2. Prepare image embeddings
-        if len(images) != 0:
-            image = torch.cat(list(images), dim=0)
-            vision_outputs = self.vision_model.forward(image)
-            image_embeds = vision_outputs[0]
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
-                device
-            )
+        if len(all_images) != 0:
+            vision_outputs = self.vision_model.forward(torch.cat(all_images, dim=0))
+            all_image_embeds = vision_outputs[0]
+            indices_or_sections = [len(chunk) for chunk in all_image_slots]
+            indices_or_sections = np.cumsum(indices_or_sections)
+            all_vtokens = []
+            # TODO: qformer not batched
+            for image_embeds in torch.tensor_split(
+                all_image_embeds, tuple(indices_or_sections)
+            ):
+                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+                    device
+                )

-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-            query_outputs = self.qformer.forward(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-            )
-            query_output = query_outputs[0]
+                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+                query_outputs = self.qformer.forward(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                )
+                query_output = query_outputs[0]

-            vtokens = self.language_projection(query_output)
+                all_vtokens.append(self.language_projection(query_output))
         else:
-            vtokens = []
+            all_vtokens = None

         # 3. Place image embeddings into slots
-        input_ids = torch.as_tensor(input_ids, dtype=torch.long).to(device).unsqueeze(0)
+        input_ids = (
+            torch.ones(
+                (len(all_input_ids), max(len(ids) for ids in all_input_ids)),
+                dtype=torch.long,
+            )
+            * tokenizer.pad_token_id
+        )
+        for i, ids in enumerate(all_input_ids):
+            # pad left
+            input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
+        input_ids = input_ids.to(device)
         inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
-        for slot, vimg in zip(image_slots, vtokens):
-            inputs_embeds[0][-slot : -slot + nvtokens, :] = vimg
+        for i, (image_slots, vtokens) in enumerate(zip(all_image_slots, all_vtokens)):
+            for slot, vimg in zip(image_slots, vtokens):
+                inputs_embeds[i][-slot : -slot + nvtokens, :] = vimg

-        logits_processor = LogitsProcessorList()
+        return input_ids, inputs_embeds
+
+    @torch.no_grad()
+    def batch_chat(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        queries: List[Union[str, Tuple[str, torch.Tensor]]],
+        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        max_length: int = 2048,
+        num_beams=1,
+        do_sample=True,
+        top_p=0.7,
+        temperature=0.95,
+        logits_processor=None,
+        **kwargs,
+    ):
+        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
+            tokenizer, queries, histories, max_length
+        )
+
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {
             "max_length": max_length,
@@ -301,265 +350,52 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
             "top_p": top_p,
             "temperature": temperature,
             "logits_processor": logits_processor,
+            **kwargs,
         }

-        for outputs in self.stream_generate(
+        outputs = self.language_model.generate(
             input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
-        ):
-            outputs = outputs.tolist()[0][len(input_ids[0]) :]
-            response = tokenizer.decode(outputs)
-            response = self.language_model.process_response(response)
-            yield response
+        )
+        responses = []
+        for i, output in enumerate(outputs.tolist()):
+            output = output[len(input_ids[i]) :]
+            response = tokenizer.decode(output)
+            responses.append(self.language_model.process_response(response))
+        return responses

     @torch.no_grad()
-    def stream_generate(
+    def stream_chat(
         self,
-        input_ids,
-        inputs_embeds,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[
-            Callable[[int, torch.Tensor], List[int]]
-        ] = None,
+        tokenizer: PreTrainedTokenizer,
+        query: Union[str, Tuple[str, torch.Tensor]],
+        history: List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]],
+        num_beams=5,
+        max_length=128,
+        top_p=0.9,
+        do_sample=True,
+        temperature=1,
         **kwargs,
     ):
-        """slightly modified from chatglm implementation to support inputs_embeds
-
-        Args:
-            input_ids (_type_): _description_
-            inputs_embeds (_type_): _description_
-            generation_config (Optional[GenerationConfig], optional): _description_. Defaults to None.
-            logits_processor (Optional[LogitsProcessorList], optional): _description_. Defaults to None.
-            stopping_criteria (Optional[StoppingCriteriaList], optional): _description_. Defaults to None.
-            prefix_allowed_tokens_fn (Optional[ Callable[[int, torch.Tensor], List[int]] ], optional): _description_. Defaults to None.
-
-        Yields:
-            _type_: _description_
-        """
-        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-
-        if generation_config is None:
-            generation_config = self.language_model.generation_config
-        generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)
-        bos_token_id, eos_token_id = (
-            generation_config.bos_token_id,
-            generation_config.eos_token_id,
+        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
+            tokenizer, [query], [history], max_length
         )

-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-
-        has_default_max_length = (
-            kwargs.get("max_length") is None
-            and generation_config.max_length is not None
-        )
-        if has_default_max_length and generation_config.max_new_tokens is None:
-            warnings.warn(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
-                UserWarning,
-            )
-        elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = (
-                generation_config.max_new_tokens + input_ids_seq_length
-            )
-            if not has_default_max_length:
-                logger.warn(
-                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                    "Please refer to the documentation for more information. "
-                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
-                    UserWarning,
-                )
-
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = (
-                "decoder_input_ids"
-                if self.language_model.config.is_encoder_decoder
-                else "input_ids"
-            )
-            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                " increasing `max_new_tokens`."
-            )
-
-        # 2. Set generation parameters if not already defined
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
-
-        logits_processor = self.language_model._get_logits_processor(
-            generation_config=generation_config,
-            input_ids_seq_length=input_ids_seq_length,
-            encoder_input_ids=input_ids,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            logits_processor=logits_processor,
-        )
-
-        stopping_criteria = self.language_model._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria
-        )
-        logits_warper = self.language_model._get_logits_warper(generation_config)
-
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        scores = None
-        while True:
-            model_inputs = self.prepare_inputs_for_generation(
-                input_ids, inputs_embeds=inputs_embeds, **model_kwargs
-            )
-            # forward pass to get next token
-            outputs = self.language_model(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=False,
-                output_hidden_states=False,
-            )
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            if generation_config.do_sample:
-                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-            else:
-                next_tokens = torch.argmax(probs, dim=-1)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            inputs_embeds = torch.cat(
-                [
-                    inputs_embeds,
-                    self.language_model.get_input_embeddings()(next_tokens)[:, None, :],
-                ],
-                dim=1,
-            )
-            model_kwargs = self.language_model._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.language_model.config.is_encoder_decoder,
-            )
-            unfinished_sequences = unfinished_sequences.mul(
-                (sum(next_tokens != i for i in eos_token_id)).long()
-            )
-
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
-                break
-            yield input_ids
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        past: Optional[torch.Tensor] = None,
-        past_key_values: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> dict:
-        """slightly modified from chatglm implementation to support inputs_embeds
-
-        Args:
-            input_ids (torch.LongTensor): _description_
-            inputs_embeds (Optional[torch.Tensor], optional): _description_. Defaults to None.
-            past (Optional[torch.Tensor], optional): _description_. Defaults to None.
-            past_key_values (Optional[torch.Tensor], optional): _description_. Defaults to None.
-            attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None.
-            position_ids (Optional[torch.Tensor], optional): _description_. Defaults to None.
-
-        Returns:
-            dict: _description_
-        """
-        batch_size, seq_length = input_ids.shape
-        MASK, gMASK = self.language_model.config.mask_token_id, self.language_model.config.gmask_token_id
-        seqs = input_ids.tolist()
-        mask_positions, use_gmasks = [], []
-        for seq in seqs:
-            mask_token = gMASK if gMASK in seq else MASK
-            use_gmask = mask_token == gMASK
-            mask_positions.append(seq.index(mask_token))
-            use_gmasks.append(use_gmask)
-
-        # only last token for input_ids if past is not None
-        if past is not None or past_key_values is not None:
-            last_token = input_ids[:, -1].unsqueeze(-1)
-            if attention_mask is not None and attention_mask.dtype == torch.bool:
-                attention_mask = attention_mask[:, :, -1:]
-            else:
-                attention_mask = None
-            if position_ids is not None:
-                position_ids = position_ids[..., -1:]
-            else:
-                context_lengths = [seq.index(self.language_model.config.bos_token_id) for seq in seqs]
-                if self.language_model.position_encoding_2d:
-                    position_ids = torch.tensor(
-                        [
-                            [mask_position, seq_length - context_length]
-                            for mask_position, context_length in zip(
-                                mask_positions, context_lengths
-                            )
-                        ],
-                        dtype=torch.long,
-                        device=input_ids.device,
-                    ).unsqueeze(-1)
-                else:
-                    position_ids = torch.tensor(
-                        [mask_position for mask_position in mask_positions],
-                        dtype=torch.long,
-                        device=input_ids.device,
-                    ).unsqueeze(-1)
-
-            if past is None:
-                past = past_key_values
-            return {
-                "input_ids": last_token,
-                "past_key_values": past,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
-        else:
-            if attention_mask is not None and attention_mask.dtype != torch.bool:
-                logger.warning_once(
-                    f"The dtype of attention mask ({attention_mask.dtype}) is not bool"
-                )
-                attention_mask = None
-            if attention_mask is None:
-                attention_mask = self.language_model.get_masks(input_ids, device=input_ids.device)
-            if position_ids is None:
-                position_ids = self.language_model.get_position_ids(
-                    input_ids,
-                    device=input_ids.device,
-                    mask_positions=mask_positions,
-                    use_gmasks=use_gmasks,
-                )
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {
+            "max_length": max_length,
+            "num_beams": num_beams,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }

-        if inputs_embeds is not None:
-            assert input_ids.size(1) == inputs_embeds.size(
-                1
-            ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
-            return {
-                "inputs_embeds": inputs_embeds,
-                "past_key_values": past,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
-        else:
-            return {
-                "input_ids": input_ids,
-                "past_key_values": past,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
+        for outputs in self.language_model.stream_generate(
+            input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
+        ):
+            outputs = outputs.tolist()[0][len(input_ids[0]) :]
+            response = tokenizer.decode(outputs)
+            response = self.language_model.process_response(response)
+            yield response
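With the diff above, `batch_chat` runs several conversations in one left-padded batch, while `stream_chat` keeps the single-query streaming interface and now delegates token-level generation to the language model. A hedged usage sketch, continuing from the loading sketch after the README diff (the image preprocessing is again an assumption):

```python
# Hedged sketch, assuming `model` and `tokenizer` are loaded as in the earlier
# sketch and `pixel_values` comes from the BLIP-2 image preprocessor.
queries = [
    "用一句话介绍你自己",                  # text-only query
    ("这张图片里有什么?", pixel_values),   # (text, image tensor) query
]
histories = [[], []]                        # one (query, response) history per query
responses = model.batch_chat(
    tokenizer,
    queries,
    histories,
    max_length=2048,
    do_sample=True,
    top_p=0.7,
    temperature=0.95,
)
for r in responses:
    print(r)
```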
modeling_chatglm.py CHANGED
@@ -913,15 +913,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 )
                 use_cache = False

-        # if input_ids is not None and inputs_embeds is not None:
-        #     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        # elif input_ids is not None:
-        #     batch_size, seq_length = input_ids.shape[:2]
-        # elif inputs_embeds is not None:
-        #     batch_size, seq_length. _ = inputs_embeds.shape[:2]
-        # else:
-        #     raise ValueError("You have to specify either input_ids or inputs_embeds")
-
         if input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
@@ -980,11 +971,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()

-        # NOTE: this is a hack to make the code work with the LAVIS training
-        # else:
-        #     pass
-        #     attention_mask = attention_mask.to(input_ids.device)
-
         for i, layer in enumerate(self.layers):

             if output_hidden_states:
@@ -1109,11 +1095,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 [position_ids, new_position_id], dim=-1
             )

+        # set to None as prepare_inputs_for_generation use past for input embeds
+        if "inputs_embeds" in model_kwargs:
+            model_kwargs["inputs_embeds"] = None
+
         return model_kwargs

     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
         past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -1174,12 +1165,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 use_gmasks=use_gmasks
             )

-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask
-        }
+        if inputs_embeds is not None:
+            assert input_ids.size(1) == inputs_embeds.size(
+                1
+            ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
+            return {
+                "inputs_embeds": inputs_embeds,
+                "past_key_values": past,
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+            }
+        else:
+            return {
+                "input_ids": input_ids,
+                "past_key_values": past,
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+            }

     def forward(
         self,
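The two ChatGLM-side changes above work as a pair: `prepare_inputs_for_generation` feeds `inputs_embeds` (with the image tokens already spliced in) on the first forward pass, and `_update_model_kwargs_for_generation` clears `inputs_embeds` so that later steps fall back to the cached `past` plus the newest token id. The standalone toy function below restates that contract; it is an illustrative simplification, not the real method.

```python
# Illustrative simplification of the contract, not the ChatGLM implementation.
from typing import Optional
import torch

def prepare_inputs(input_ids: torch.LongTensor,
                   inputs_embeds: Optional[torch.Tensor] = None,
                   past=None) -> dict:
    if past is not None:
        # later steps: only the newest token id plus the cache are needed
        return {"input_ids": input_ids[:, -1:], "past_key_values": past}
    if inputs_embeds is not None:
        # first step: embeddings win, but they must align with input_ids,
        # mirroring the assert added in the diff above
        assert input_ids.size(1) == inputs_embeds.size(1)
        return {"inputs_embeds": inputs_embeds, "past_key_values": None}
    return {"input_ids": input_ids, "past_key_values": None}

ids = torch.tensor([[5, 6, 7]])
emb = torch.randn(1, 3, 16)
print(prepare_inputs(ids, emb).keys())                 # first step -> inputs_embeds
print(prepare_inputs(ids, emb, past=object()).keys())  # later steps -> input_ids
```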