import copy
import os
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
import warnings
from torch import Tensor, nn
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    Blip2VisionModel,
    Blip2QFormerModel,
    Blip2Model,
    Blip2PreTrainedModel,
    Blip2ForConditionalGeneration,
    GenerationConfig,
)
from transformers.models.blip_2.modeling_blip_2 import (
    Blip2ForConditionalGenerationModelOutput,
)
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList

from .modeling_chatglm import (
    ChatGLMForConditionalGeneration,
    InvalidScoreLogitsProcessor,
)
from .configuration_blip2chatglm import Blip2ChatGLMConfig

logger = logging.get_logger(__name__)


class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
    config_class = Blip2ChatGLMConfig

    def __init__(self, config: Blip2ChatGLMConfig):
        # NOTE: we only initialize Blip2PreTrainedModel here; calling
        # super().__init__() directly would fail because ChatGLM cannot be
        # found by AutoModel.
        Blip2PreTrainedModel.__init__(self, config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size
        )
        self.language_model = ChatGLMForConditionalGeneration(config.text_config)

        # Initialize weights and apply final processing
        # self.post_init()

    def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
        if vision_encoder_dtype == "fp32":
            self.vision_model = self.vision_model.float()
        elif vision_encoder_dtype == "fp16":
            self.vision_model = self.vision_model.half()
        else:
            raise NotImplementedError(
                f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
            )

        if lm_dtype == "fp32":
            self.language_model = self.language_model.float()
        elif lm_dtype == "fp16":
            self.language_model = self.language_model.half()
        elif lm_dtype == "int4":
            self.language_model = self.language_model.half().quantize(4)
        elif lm_dtype == "int8":
            self.language_model = self.language_model.half().quantize(8)
        else:
            raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        image_slot_offset: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Run the vision encoder, the Q-Former, and the ChatGLM language model.

        The projected query outputs are written into the image slots of
        ``input_ids``, which must be pre-filled with ``tokenizer.unk_token_id``.

        Args:
            pixel_values (torch.FloatTensor): pixel values of the input images.
            input_ids (torch.FloatTensor): input token ids; ``input_ids[:, :num_query_tokens]``
                (or the slots given by ``image_slot_offset``) should be filled with
                ``tokenizer.unk_token_id``.
            image_slot_offset (Optional[torch.LongTensor], optional): per-sample start
                index of the image slot. If not set, all vtokens are placed as prefix
                (image_slot_offset = torch.zeros(bsz)). Defaults to None.
            attention_mask (Optional[torch.LongTensor], optional): attention mask for
                ``input_ids``. Defaults to None.
            output_attentions (Optional[bool], optional): whether to return attention
                weights. Defaults to None.
            output_hidden_states (Optional[bool], optional): whether to return hidden
                states. Defaults to None.
            labels (Optional[torch.LongTensor], optional): labels for the language
                modeling loss. Defaults to None.
            return_dict (Optional[bool], optional): whether to return a
                ``Blip2ForConditionalGenerationModelOutput`` instead of a tuple.
                Defaults to None.
        Returns:
            Union[Tuple, Blip2ForConditionalGenerationModelOutput]: the language
            modeling loss (when ``labels`` is given), the logits, and the vision,
            Q-Former, and language model outputs.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer,
        # using the image embeddings for cross-attention
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if image_slot_offset is None:
            # image as prefix
            # update .data to avoid an in-place operation on a leaf Variable
            inputs_embeds.data[
                :, : self.config.num_query_tokens, :
            ] = language_model_inputs
        else:
            for i, offset in enumerate(image_slot_offset):
                inputs_embeds.data[
                    i, offset : offset + self.config.num_query_tokens, :
                ] = language_model_inputs[i]

        outputs = self.language_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account
        # the sequence length of the query embeds
        if labels is not None:
            logits = logits[:, -labels.size(1) :, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")
            loss = loss_fct(
                shift_logits.view(-1, self.config.text_config.vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int,
        user_role: str = "问",
        bot_role: str = "答",
    ):
        """Tokenize chat messages and build input embeddings with image slots.

        Each message in ``batch_messages`` is a tuple ``(role, text, images)``,
        where ``images`` is a list of ``(pixel_values, char_index)`` pairs marking
        where in ``text`` each image is inserted. Sequences longer than
        ``max_length`` are truncated from the left.
        """
        device = self.device
        nvtokens = self.config.num_query_tokens

        # 1. Prepare token ids
        all_images = []
        all_image_slots = []
        all_input_ids = []
        for messages in batch_messages:
            images = []
            image_slots = []
            input_ids = []
            round_roles = [set()]
            for role, qtext, qimgs in messages:
                if role in round_roles[-1]:
                    # a new round (not the first round)
                    input_ids += tokenizer(
                        f"\n[Round {len(round_roles)}]\n{role}:",
                        add_special_tokens=False,
                    ).input_ids
                    round_roles.append({role})
                else:
                    round_roles[-1].add(role)
                    input_ids += tokenizer(
                        # For the first role, no new line
                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:",
                        add_special_tokens=False,
                    ).input_ids
                cur_index = 0
                for qimg, img_idx in qimgs:
                    if img_idx > cur_index:
                        input_ids += tokenizer(
                            qtext[cur_index:img_idx], add_special_tokens=False
                        ).input_ids
                        cur_index = img_idx
                    # image slot, embedding will be replaced by image embeddings
                    image_slots.append(len(input_ids))
                    input_ids += [tokenizer.unk_token_id] * nvtokens
                    images.append(qimg)
                input_ids += tokenizer(
                    qtext[cur_index:], add_special_tokens=False
                ).input_ids

            if len(round_roles) == 1:
                # only 1 round
                if len(round_roles[0]) == 1 and user_role in round_roles[0]:
                    # only the user role
                    input_ids += tokenizer("").input_ids
                else:
                    input_ids += tokenizer(f"\n{bot_role}:").input_ids
            else:
                # add tag for round 0
                input_ids = (
                    tokenizer(f"[Round 0]\n", add_special_tokens=False).input_ids
                    + input_ids
                )
                input_ids += tokenizer(f"\n{bot_role}:").input_ids

            if len(input_ids) >= max_length:
                image_slots_after_truncate = []
                images_after_truncate = []
                truncate_index = len(input_ids) - max_length
                for image_slot, image in zip(image_slots, images):
                    # truncate from the left
                    if len(input_ids) - image_slot < max_length:
                        # the slot lies entirely within the kept window
                        image_slots_after_truncate.append(image_slot)
                        images_after_truncate.append(image)
                    elif len(input_ids) - (image_slot + nvtokens) < max_length:
                        # a partially kept image slot is not allowed;
                        # move the cut past the whole slot
                        truncate_index = max(truncate_index, image_slot + nvtokens)
                for i, image_slot in enumerate(image_slots_after_truncate):
                    image_slots_after_truncate[i] = image_slot - truncate_index
                input_ids = input_ids[truncate_index:]
                image_slots = image_slots_after_truncate
                images = images_after_truncate
            # print(tokenizer.convert_ids_to_tokens(input_ids))

            all_images.extend(images)
            all_image_slots.append(image_slots)
            all_input_ids.append(input_ids)

        # 2. Prepare image embeddings
        if len(all_images) != 0:
            vision_outputs = self.vision_model.forward(torch.cat(all_images, dim=0))
            all_image_embeds = vision_outputs[0]
            indices_or_sections = [len(chunk) for chunk in all_image_slots]
            indices_or_sections = np.cumsum(indices_or_sections)
            all_vtokens = []
            # TODO: qformer not batched
            for image_embeds in torch.tensor_split(
                all_image_embeds, tuple(indices_or_sections)
            ):
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                    device
                )
                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_outputs = self.qformer.forward(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                )
                query_output = query_outputs[0]
                all_vtokens.append(self.language_projection(query_output))
        else:
            all_vtokens = None
        # 3. Place image embeddings into slots
        input_ids = (
            torch.ones(
                (len(all_input_ids), max(len(ids) for ids in all_input_ids)),
                dtype=torch.long,
            )
            * tokenizer.pad_token_id
        )
        for i, ids in enumerate(all_input_ids):
            # pad left
            input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
        input_ids = input_ids.to(device)
        inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
        if all_vtokens is not None:
            for i, (ids, image_slots, vtokens) in enumerate(
                zip(all_input_ids, all_image_slots, all_vtokens)
            ):
                # slots were recorded on the unpadded sequence;
                # shift them by the amount of left padding
                pad_offset = input_ids.shape[1] - len(ids)
                for slot, vimg in zip(image_slots, vtokens):
                    inputs_embeds[i][
                        pad_offset + slot : pad_offset + slot + nvtokens, :
                    ] = vimg

        return input_ids, inputs_embeds

    @torch.no_grad()
    def batch_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int = 2048,
        num_beams=1,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=batch_messages,
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        outputs = self.language_model.generate(
            input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
        )
        responses = []
        for i, output in enumerate(outputs.tolist()):
            output = output[len(input_ids[i]) :]
            response = tokenizer.decode(output)
            responses.append(self.language_model.process_response(response))
        return responses

    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
        num_beams=5,
        max_length=512,
        top_p=0.9,
        do_sample=True,
        temperature=1,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=[messages],
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        for outputs in self.language_model.stream_generate(
            input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
        ):
            outputs = outputs.tolist()[0][len(input_ids[0]) :]
            response = tokenizer.decode(outputs)
            response = self.language_model.process_response(response)
            yield response
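

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model implementation).
# It shows the input format that `batch_chat` expects: each message is a tuple
# (role, text, [(pixel_values, char_index), ...]), where `pixel_values` is an
# already-preprocessed image tensor of shape (1, 3, H, W) and `char_index` is
# the position in `text` at which the image slot is inserted. The checkpoint
# path, the 224x224 image size, and the CUDA device below are assumptions and
# may need to be adapted to the actual checkpoint and image processor.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_path = "path/to/blip2chatglm-checkpoint"  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = Blip2ChatGLMForConditionalGeneration.from_pretrained(model_path)
    model.setup_dtype(vision_encoder_dtype="fp16", lm_dtype="fp16")
    model = model.to("cuda").eval()

    # Placeholder image tensor; in practice this comes from the image processor.
    pixel_values = torch.randn(1, 3, 224, 224, dtype=torch.float16, device="cuda")

    # One conversation with a single user turn whose text starts with the image.
    messages = [[("问", "What is in this picture?", [(pixel_values, 0)])]]
    responses = model.batch_chat(tokenizer, messages, max_length=2048)
    print(responses[0])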