# coding=utf-8
import os
import math
import logging
from contextlib import suppress

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from einops import rearrange
from peft import LoraConfig, get_peft_model
from bigmodelvis import Visualization
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.generation import GenerationConfig

from .clip_encoder_hd import CLIPVisionTowerHD
from .conversation import get_conv_template
from .processors_conv import preprocess_qwen


def get_autocast(precision, cache_enabled=True):
    if precision in ("amp_bfloat16", "amp_bf16", "bf16"):
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16, cache_enabled=cache_enabled)
    elif precision == "fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=cache_enabled)
    elif precision == "fp32":
        return suppress
    else:
        raise ValueError("not supported precision: {}".format(precision))


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class InfMLLM_Unified_HD_Chat(PreTrainedModel):
    def __init__(self, config, debug=False):
        super().__init__(config)

        ## Initialize LM model
        self.lm_tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, use_fast=False, trust_remote_code=True)
        self.media_token_img = "<|image|>"
        self.media_token_id_img = self.lm_tokenizer(
            self.media_token_img, return_tensors="pt", add_special_tokens=False
        ).input_ids.item()
        self.lm_model = Qwen2ForCausalLM(config.lm_config)
        self.lm_tokenizer.model_max_length = config.max_txt_len
        self.template_name = config.conv_style
        self.preprocess_function = preprocess_qwen

        self.separate = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.newline = nn.Parameter(torch.zeros([1, 1, 1, 4096]))

        ## Initialize image encoder
        self.encoder_img = CLIPVisionTowerHD(config.vision_config, vision_select_layer=-2)
        self.encoder_img_ln = lambda x: x
        self.adapter_img = nn.Sequential(
            nn.Linear(self.encoder_img.num_features * 4, self.lm_model.config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_model.config.hidden_size, self.lm_model.config.hidden_size),
        )

        ## Others
        self.config = config
        self.precision = config.precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None

    def forward_encoder_img(self, image):
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            assert isinstance(image, list)
            image_embeds, image_split = self.encoder_img(image, self.separate, self.newline)
            image_embeds = self.encoder_img_ln(image_embeds)  # [bsz, L, D]
            image_embeds = self.adapter_img(image_embeds)
        return image_embeds, image_split

    def _concat_embeds(self, prompt_embeds, prompt_ids, prompt_masks, labels=None, padding='left'):
        emb_lens = [len(emb) for emb in prompt_embeds]
        if len(set(emb_lens)) == 1:
            if labels is not None:
                return (torch.stack(prompt_embeds, dim=0), torch.stack(prompt_ids, dim=0),
                        torch.stack(prompt_masks, dim=0), torch.stack(labels, dim=0))
            return torch.stack(prompt_embeds, dim=0), torch.stack(prompt_ids, dim=0), torch.stack(prompt_masks, dim=0)

        pad_emb = self.lm_model.get_input_embeddings()(
            torch.tensor(self.lm_tokenizer.pad_token_id, device=prompt_embeds[0].device)
        )
        prompt_embeds_new = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
        prompt_ids_new = torch.ones([len(emb_lens), max(emb_lens)]).to(prompt_ids[0]) * self.lm_tokenizer.pad_token_id
        prompt_masks_new = torch.zeros([len(emb_lens), max(emb_lens)]).to(prompt_masks[0])
        if labels is not None:
            labels_new = -100 * torch.ones([len(emb_lens), max(emb_lens)]).to(prompt_ids[0])

        for i, L in enumerate(emb_lens):
            if padding == 'left':
                prompt_embeds_new[i, -L:] = prompt_embeds[i]
                prompt_ids_new[i, -L:] = prompt_ids[i]
                prompt_masks_new[i, -L:] = prompt_masks[i]
                if labels is not None:
                    labels_new[i, -L:] = labels[i]
            elif padding == 'right':
                prompt_embeds_new[i, :L] = prompt_embeds[i]
                prompt_ids_new[i, :L] = prompt_ids[i]
                prompt_masks_new[i, :L] = prompt_masks[i]
                if labels is not None:
                    labels_new[i, :L] = labels[i]
            else:
                raise ValueError()

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _insert_media_feat(self, prompt_embeds, prompt_ids, prompt_masks, is_languages, embeds_media,
                           media_token_id, index_list=None, labels=None, len_media=None):
        ## insert embeds_media into prompt
        prompt_embeds_new = []
        prompt_masks_new = []
        prompt_ids_new = []
        labels_new = []
        device = embeds_media[0].device

        if index_list is not None:
            assert len(index_list) == len(embeds_media)
        assert len(embeds_media) <= len(prompt_embeds)

        for b in range(len(prompt_embeds)):
            if (index_list is not None) and (b not in index_list):
                prompt_embeds_new.append(prompt_embeds[b])
                prompt_ids_new.append(prompt_ids[b])
                prompt_masks_new.append(prompt_masks[b])
                if labels is not None:
                    labels_new.append(labels[b])
            else:
                _idx = prompt_ids[b].tolist().index(media_token_id)
                if index_list is not None:
                    b_media = index_list.index(b)
                else:
                    b_media = b
                if len_media is not None:
                    cur_embeds_media = embeds_media[b_media, :len_media[b_media]]
                else:
                    cur_embeds_media = embeds_media[b_media]

                prompt_embeds_new.append(torch.cat([prompt_embeds[b][:_idx + 1],
                                                    cur_embeds_media,
                                                    prompt_embeds[b][_idx + 1:]], dim=0))
                prompt_ids_new.append(torch.cat([prompt_ids[b][:_idx + 1],
                                                 torch.ones(len(cur_embeds_media), dtype=torch.long).to(device).fill_(-100),
                                                 prompt_ids[b][_idx + 1:]], dim=0))
                if labels is not None:
                    labels_new.append(torch.cat([labels[b][:_idx + 1],
                                                 torch.ones(len(cur_embeds_media), dtype=torch.long).to(device).fill_(-100),
                                                 labels[b][_idx + 1:]], dim=0))
                # if is pure-language sample, mask out image-embeddings
                prompt_masks_new.append(torch.cat([prompt_masks[b][:_idx + 1],
                                                   torch.zeros(len(cur_embeds_media), dtype=torch.long).to(device)
                                                   if is_languages[b]
                                                   else torch.ones(len(cur_embeds_media), dtype=torch.long).to(device),
                                                   prompt_masks[b][_idx + 1:]], dim=0))

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    @torch.no_grad()
    def generate(
        self,
        samples,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        temperature=0.0,
        return_prompts=False,
    ):
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            conversations = samples['conversations']
            is_languages = [False] * len(conversations)
            image_img = samples.get('images', None)
            # Guard against text-only batches where no images are provided.
            index_img = list(range(len(image_img))) if image_img is not None else []
            device = None
            embeds_img, len_img = None, None

            special_prefix = ["" for _ in range(len(conversations))]
            if (self.config.encoder_img is not None) and (image_img is not None) and len(index_img) > 0:
                for i in index_img:
                    special_prefix[i] = self.media_token_img + special_prefix[i]

                new_image_img = []
                for index in index_img:
                    new_image_img.append(image_img[index])
                embeds_img, len_img = self.forward_encoder_img(new_image_img)
                device = embeds_img.device

            conv = get_conv_template(self.template_name)
            roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

            prompts = []
            for i, source in enumerate(conversations):
                if roles[source[0]['from']] != conv.roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]

                per_prefix = special_prefix[i]
                conv.messages = []
                for j, sentence in enumerate(source):
                    role = roles[sentence['from']]
                    assert role == conv.roles[j % 2], f'{i}'
                    # llava-1.5 adds "<image>" at the beginning of the question; remove it here.
                    sentence['value'] = sentence['value'].replace("<image>", "").strip()
                    if j == 0:
                        sentence['value'] = per_prefix + sentence['value']
                    conv.append_message(role, sentence['value'])
                prompts.append(conv.get_prompt())

            self.lm_tokenizer.padding_side = "left"
            if self.lm_tokenizer.bos_token is not None:
                prompt_text = [self.lm_tokenizer.bos_token + t for t in prompts]
            else:
                prompt_text = prompts
            prompt_tokens = self.lm_tokenizer(
                prompt_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=False,
            ).to(device)
            prompt_embeds = self.lm_model.get_input_embeddings()(prompt_tokens.input_ids)
            prompt_masks = prompt_tokens.attention_mask  # [bsz, n2]
            prompt_ids = prompt_tokens.input_ids
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            if embeds_img is not None:
                prompt_embeds, prompt_ids, prompt_masks = self._insert_media_feat(
                    prompt_embeds=prompt_embeds,
                    prompt_ids=prompt_ids,
                    prompt_masks=prompt_masks,
                    is_languages=is_languages,
                    embeds_media=embeds_img,
                    media_token_id=self.media_token_id_img,
                    index_list=index_img,
                    len_media=len_img,
                )

            # pad and concat embeds
            prompt_embeds, prompt_ids, prompt_masks = self._concat_embeds(prompt_embeds, prompt_ids, prompt_masks, padding="left")
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            kwargs = {}
            kwargs['max_new_tokens'] = max_length

            outputs = self.lm_model.generate(
                # input_ids=input_ids,
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_masks,
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                # max_length=max_length,
                min_length=min_length,
                **kwargs
            )
            output_text = self.lm_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
            output_text = self._lemmatize(output_text)

        if return_prompts:
            return output_text, prompts
        return output_text

    def _lemmatize(self, answers):
        def apply(answer):
            doc = self.lemmatizer(answer)
            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)
            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)
        return self._lemmatizer
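

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint path,
# config layout, and image tensor shape below are hypothetical; the snippet
# only illustrates the `samples` format that `generate` expects: a list of
# chat-style `conversations` plus an optional `images` list consumed by
# CLIPVisionTowerHD via forward_encoder_img.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ckpt = "path/to/infmllm_unified_hd_chat"  # hypothetical checkpoint directory
    model = InfMLLM_Unified_HD_Chat.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    model = model.eval().cuda()

    samples = {
        "conversations": [
            [{"from": "human", "value": "Describe the image in one sentence."}],
        ],
        # One entry per conversation, already preprocessed into whatever format
        # CLIPVisionTowerHD expects (assumed here to be a tensor of image tiles).
        "images": [torch.randn(5, 3, 336, 336, dtype=torch.bfloat16).cuda()],
    }

    answers = model.generate(samples, num_beams=1, max_length=64, temperature=0.0)
    print(answers[0])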