Manli committed on
Commit
68f33a2
1 Parent(s): 7325a90

Merge modeling files into a single one to avoid relative import

Files changed (3)
  1. modeling_xgenmm.py +2030 -40
  2. utils.py +0 -383
  3. vlm.py +0 -1381
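
With all of the modeling code now merged into modeling_xgenmm.py, the checkpoint can be loaded via trust_remote_code without resolving imports across the removed utils.py and vlm.py. A minimal loading sketch, assuming a Hub repo id such as Salesforce/xgen-mm-phi3-mini-instruct-r-v1 and a config auto_map entry pointing AutoModelForVision2Seq at XGenMMModelForConditionalGeneration (adjust both to the actual checkpoint):

import torch
from transformers import AutoModelForVision2Seq, AutoTokenizer

model_id = "Salesforce/xgen-mm-phi3-mini-instruct-r-v1"  # assumed repo id; replace with the real checkpoint

# trust_remote_code downloads and executes the single merged modeling_xgenmm.py from the Hub
model = AutoModelForVision2Seq.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# register the VLM's special tokens (e.g. the media token) with the tokenizer,
# using the update_special_tokens helper defined in this file
tokenizer = model.update_special_tokens(tokenizer)
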
modeling_xgenmm.py CHANGED
@@ -1,29 +1,2010 @@
1
- from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
2
  import torch
3
- import open_clip
4
  from typing import List, Optional, Tuple, Union
5
- from utils import check_embedding_fns
6
- from vlm import PerceiverResampler, XGenMMPerceiver
7
- from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
8
 
9
  class XGenMMVisionEncoder(PreTrainedModel):
10
  main_input_name = "pixel_values"
11
  config_class = XGenMMVisionEncoderConfig
12
-
13
  def __init__(self, config: XGenMMVisionEncoderConfig):
14
  super().__init__(config)
15
- if config.model_name != 'google/siglip-so400m-patch14-384':
16
- raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
17
  self.model = AutoModel.from_pretrained(config.model_name)
18
-
19
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
  # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
21
  return self.model.encode_image(pixel_values)
22
-
23
 
24
- # vision tokenizer
25
  class XGenMMVisionTokenizer(PreTrainedModel):
26
  config_class = XGenMMVisionTokenizerConfig
27
  def __init__(self, config: XGenMMVisionTokenizerConfig):
28
  super().__init__(config)
29
  self.model = PerceiverResampler(
@@ -31,50 +2012,58 @@ class XGenMMVisionTokenizer(PreTrainedModel):
31
  dim_inner=config.lang_embedding_dim,
32
  num_latents=config.num_vis_tokens,
33
  )
34
-
35
- def forward(self,
36
- vision_features: torch.Tensor,
37
- vision_attn_masks: torch.Tensor):
38
  return self.model(vision_features, vision_attn_masks)
39
-
40
  # XGenMM model
41
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
42
  config_class = XGenMMConfig
43
-
44
  def __init__(self, config: XGenMMConfig):
45
  super().__init__(config)
46
-
47
  # vision encoder initialization
48
- vision_encoder = AutoModel.from_pretrained(config.vision_encoder_config.model_name).vision_model
49
-
50
- # language model initialization
51
  language_model = AutoModelForCausalLM.from_config(config.text_config)
52
  check_embedding_fns(language_model)
53
  # Update _tied_weights_keys using the base model used.
54
  if language_model._tied_weights_keys is not None:
55
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
56
-
57
  # vision tokenizer initialization
58
- if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
59
  overwrite = language_model.get_input_embeddings().weight.shape[1]
60
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
61
- print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
62
-
63
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
64
 
65
  self.vlm = XGenMMPerceiver(
66
  vision_encoder=vision_encoder,
67
  vision_tokenizer=vision_tokenizer,
68
  lang_model=language_model,
69
- initial_tokenizer_len = config.text_config.initial_tokenizer_len,
70
- pad_token_id = config.text_config.pad_token_id,
71
- image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
72
- anyres_patch_sampling = config.vision_encoder_config.anyres_patch_sampling,
73
- anyres_grids = config.vision_encoder_config.anyres_grids
74
  )
75
  # Initialize weights and apply final processing
76
  self.post_init()
77
-
78
  @torch.no_grad()
79
  def generate(
80
  self,
@@ -82,14 +2071,15 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
82
  input_ids: Optional[torch.LongTensor] = None,
83
  attention_mask: Optional[torch.LongTensor] = None,
84
  **generate_kwargs,
85
- ) -> torch.LongTensor:
86
  self.vlm = self.vlm.eval()
87
  return self.vlm.generate(
88
- vision_x = pixel_values,
89
- lang_x = input_ids,
90
- attention_mask = attention_mask,
91
- **generate_kwargs)
92
-
93
  def update_special_tokens(self, tokenizer):
94
  tokenizer.add_special_tokens(
95
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
@@ -97,8 +2087,8 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
97
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
98
  self.vlm.set_special_token_ids(
99
  {
100
- v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
101
  }
102
  )
103
  return tokenizer
104
-
1
+ import ast
2
+ import math
3
+ from einops import rearrange, repeat
4
+ from einops_exts import rearrange_many
5
+ from einops import rearrange
6
+ from PIL import Image
7
  import torch
8
+ from torch import einsum, nn
9
+
10
+
11
  from typing import List, Optional, Tuple, Union
12
+ import torch.nn.functional as F
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+ from dataclasses import dataclass
15
+ from transformers import CLIPVisionModel
16
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
17
+ from transformers import PretrainedConfig, logging, CONFIG_MAPPING
18
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
19
+
20
+
21
+ def hasattr_recursive(obj, att):
22
+ """
23
+ Check if obj has nested attribute
24
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
+ """
26
+ if att == "":
27
+ return True
28
+ i = att.find(".")
29
+ if i < 0:
30
+ return hasattr(obj, att)
31
+ else:
32
+ try:
33
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
+ except:
35
+ return False
36
+
37
+
38
+ def getattr_recursive(obj, att):
39
+ """
40
+ Return nested attribute of obj
41
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
42
+ """
43
+ if att == "":
44
+ return obj
45
+ i = att.find(".")
46
+ if i < 0:
47
+ return getattr(obj, att)
48
+ else:
49
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
50
+
51
+
52
+ def setattr_recursive(obj, att, val):
53
+ """
54
+ Set nested attribute of obj
55
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
56
+ """
57
+ if "." in att:
58
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
59
+ setattr(obj, att.split(".")[-1], val)
60
+
61
+
62
+ def check_embedding_fns(lang_model):
63
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
64
+ if not has_fn(lang_model, "get_input_embeddings"):
65
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
66
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
67
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
68
+ lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
69
+ else:
70
+ raise ValueError(
71
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
72
+ )
73
+
74
+ if not has_fn(lang_model, "set_input_embeddings"):
75
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
76
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
77
+ lang_model, "transformer.wte", x
78
+ )
79
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
80
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
81
+ lang_model, "model.decoder.embed_tokens", x
82
+ )
83
+ else:
84
+ raise ValueError(
85
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
86
+ )
87
+
88
+ if not has_fn(lang_model, "get_output_embeddings"):
89
+ if hasattr_recursive(lang_model, "lm_head"):
90
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
91
+ else:
92
+ raise ValueError(
93
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
94
+ )
95
+
96
+ if not has_fn(lang_model, "set_output_embeddings"):
97
+ if hasattr_recursive(lang_model, "lm_head"):
98
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
99
+ lang_model, "lm_head", x
100
+ )
101
+ else:
102
+ raise ValueError(
103
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
104
+ )
105
+
106
+
107
+ def has_fn(model, fn_name):
108
+ """Check if model has a function fn_name"""
109
+ return callable(getattr(model, fn_name, None))
110
+
111
+
112
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
113
+ """
114
+ Stack a list of tensors with padding on one side
115
+ Args:
116
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
117
+ padding_value (int, optional): Value to pad with. Defaults to 0.
118
+ padding_side (str, optional): Side to pad on. Defaults to "right".
119
+ Returns:
120
+ torch.Tensor: Stacked tensors
121
+ """
122
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
123
+ padded_tensors = []
124
+ for tensor in list_of_tensors:
125
+ num_tokens = tensor.size(0)
126
+ if len(tensor.size()) == 1:
127
+ padding = torch.full(
128
+ (max_tokens - num_tokens,),
129
+ padding_value,
130
+ dtype=tensor.dtype,
131
+ device=tensor.device,
132
+ )
133
+ else:
134
+ padding = torch.full(
135
+ (max_tokens - num_tokens, tensor.size(1)),
136
+ padding_value,
137
+ dtype=tensor.dtype,
138
+ device=tensor.device,
139
+ )
140
+ padded_tensor = (
141
+ torch.cat((tensor, padding), dim=0)
142
+ if padding_side == "right"
143
+ else torch.cat((padding, tensor), dim=0)
144
+ )
145
+ padded_tensors.append(padded_tensor)
146
+ return torch.stack(padded_tensors)
147
+
148
+
149
+ def unpad_image(tensor, original_size, keep_original_shape=False):
150
+ """
151
+ Unpads a PyTorch tensor of a padded and resized image.
152
+
153
+ Args:
154
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
155
+ original_size (tuple): The original size of the image (width, height).
156
+
157
+ Returns:
158
+ torch.Tensor: The unpadded image tensor.
159
+ """
160
+ original_width, original_height = original_size
161
+ current_height, current_width = tensor.shape[1:]
162
+
163
+ original_aspect_ratio = original_width / original_height
164
+ current_aspect_ratio = current_width / current_height
165
+
166
+ if original_aspect_ratio > current_aspect_ratio:
167
+ scale_factor = current_width / original_width
168
+ new_height = int(original_height * scale_factor)
169
+ padding = (current_height - new_height) // 2
170
+ if keep_original_shape:
171
+ attention_mask = torch.ones(
172
+ (current_height, current_width), device=tensor.device
173
+ )
174
+ attention_mask[:padding, :] = 0
175
+ attention_mask[current_height - padding :, :] = 0
176
+ return tensor, attention_mask
177
+ else:
178
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
179
+ return unpadded_tensor, None
180
+ else:
181
+ scale_factor = current_height / original_height
182
+ new_width = int(original_width * scale_factor)
183
+ padding = (current_width - new_width) // 2
184
+ if keep_original_shape:
185
+ attention_mask = torch.ones(
186
+ (current_height, current_width), device=tensor.device
187
+ )
188
+ attention_mask[:, :padding] = 0
189
+ attention_mask[:, current_width - padding :] = 0
190
+ return tensor, attention_mask
191
+ else:
192
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
193
+ return unpadded_tensor, None
194
+
195
+
196
+ def select_best_resolution(original_size, possible_resolutions):
197
+ """
198
+ Selects the best resolution from a list of possible resolutions based on the original size.
199
+
200
+ Args:
201
+ original_size (tuple): The original size of the image in the format (width, height).
202
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
203
+
204
+ Returns:
205
+ tuple: The best fit resolution in the format (width, height).
206
+ """
207
+ original_width, original_height = original_size
208
+ best_fit = None
209
+ max_effective_resolution = 0
210
+ min_wasted_resolution = float("inf")
211
+
212
+ for width, height in possible_resolutions:
213
+ scale = min(width / original_width, height / original_height)
214
+ downscaled_width, downscaled_height = int(original_width * scale), int(
215
+ original_height * scale
216
+ )
217
+ effective_resolution = min(
218
+ downscaled_width * downscaled_height, original_width * original_height
219
+ )
220
+ wasted_resolution = (width * height) - effective_resolution
221
+
222
+ if effective_resolution > max_effective_resolution or (
223
+ effective_resolution == max_effective_resolution
224
+ and wasted_resolution < min_wasted_resolution
225
+ ):
226
+ max_effective_resolution = effective_resolution
227
+ min_wasted_resolution = wasted_resolution
228
+ best_fit = (width, height)
229
+
230
+ return best_fit
231
+
232
+
233
+ def resize_and_pad_image(image, target_resolution):
234
+ """
235
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
236
+
237
+ Args:
238
+ image (PIL.Image.Image): The input image.
239
+ target_resolution (tuple): The target resolution (width, height) of the image.
240
+
241
+ Returns:
242
+ PIL.Image.Image: The resized and padded image.
243
+ """
244
+ original_width, original_height = image.size
245
+ target_width, target_height = target_resolution
246
+
247
+ scale_w = target_width / original_width
248
+ scale_h = target_height / original_height
249
+
250
+ if scale_w < scale_h:
251
+ new_width = target_width
252
+ new_height = min(math.ceil(original_height * scale_w), target_height)
253
+ else:
254
+ new_height = target_height
255
+ new_width = min(math.ceil(original_width * scale_h), target_width)
256
+
257
+ # Resize the image
258
+ resized_image = image.resize((new_width, new_height))
259
+
260
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
261
+ paste_x = (target_width - new_width) // 2
262
+ paste_y = (target_height - new_height) // 2
263
+ new_image.paste(resized_image, (paste_x, paste_y))
264
+
265
+ return new_image
266
+
267
+
268
+ def divide_to_patches(image, patch_size):
269
+ """
270
+ Divides an image into patches of a specified size.
271
+
272
+ Args:
273
+ image (PIL.Image.Image): The input image.
274
+ patch_size (int): The size of each patch.
275
+
276
+ Returns:
277
+ list: A list of PIL.Image.Image objects representing the patches.
278
+ """
279
+ patches = []
280
+ width, height = image.size
281
+ for i in range(0, height, patch_size):
282
+ for j in range(0, width, patch_size):
283
+ box = (j, i, j + patch_size, i + patch_size)
284
+ patch = image.crop(box)
285
+ patches.append(patch)
286
+
287
+ return patches
288
+
289
+
290
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
291
+ """
292
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
293
+
294
+ Args:
295
+ image_size (tuple): The size of the input image in the format (width, height).
296
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
297
+ patch_size (int): The size of each image patch.
298
+
299
+ Returns:
300
+ tuple: The shape of the image patch grid in the format (width, height).
301
+ """
302
+ if type(grid_pinpoints) is list:
303
+ possible_resolutions = grid_pinpoints
304
+ else:
305
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
306
+ width, height = select_best_resolution(image_size, possible_resolutions)
307
+ return width // patch_size, height // patch_size
308
+
309
+
310
+ def process_anyres_image(image, processor, grid_pinpoints):
311
+ """
312
+ Process an image with variable resolutions.
313
+
314
+ Args:
315
+ image (PIL.Image.Image): The input image to be processed.
316
+ processor: The image processor object.
317
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
318
+
319
+ Returns:
320
+ torch.Tensor: A tensor containing the processed image patches.
321
+ """
322
+ # FIXME: determine grid_pinpoints from image sizes.
323
+ if type(grid_pinpoints) is list:
324
+ possible_resolutions = grid_pinpoints
325
+ else:
326
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
327
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
328
+ image_padded = resize_and_pad_image(image, best_resolution)
329
+
330
+ processor_size = processor.transforms[0].size
331
+ patches = divide_to_patches(image_padded, processor_size[0])
332
+
333
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
334
+
335
+ image_patches = [image_original_resize] + patches
336
+ image_patches = [processor(image_patch) for image_patch in image_patches]
337
+ return torch.stack(image_patches, dim=0)
338
+
339
+
340
+ def expand2square(pil_img, background_color):
341
+ width, height = pil_img.size
342
+ if width == height:
343
+ return pil_img
344
+ elif width > height:
345
+ result = Image.new(pil_img.mode, (width, width), background_color)
346
+ result.paste(pil_img, (0, (width - height) // 2))
347
+ return result
348
+ else:
349
+ result = Image.new(pil_img.mode, (height, height), background_color)
350
+ result.paste(pil_img, ((height - width) // 2, 0))
351
+ return result
352
+
353
+
354
+ class VisionTokenizer(nn.Module):
355
+ def __init__(self, dim_media, num_tokens_per_media):
356
+ super().__init__()
357
+ self.dim_media = dim_media
358
+ self.num_tokens_per_media = num_tokens_per_media
359
+
360
+
361
+ class PerceiverAttention(nn.Module):
362
+ def __init__(self, *, dim, dim_head=64, heads=8):
363
+ super().__init__()
364
+ self.scale = dim_head**-0.5
365
+ self.heads = heads
366
+ inner_dim = dim_head * heads
367
+
368
+ self.norm_media = nn.LayerNorm(dim)
369
+ self.norm_latents = nn.LayerNorm(dim)
370
+
371
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
372
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
373
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
374
+
375
+ def forward(self, x, latents, vision_attn_masks=None):
376
+ """
377
+ Args:
378
+ x (torch.Tensor): image features
379
+ shape (b, T, n1, D)
380
+ latent (torch.Tensor): latent features
381
+ shape (b, T, n2, D)
382
+ """
383
+ x = self.norm_media(x)
384
+ latents = self.norm_latents(latents)
385
+
386
+ h = self.heads
387
+
388
+ q = self.to_q(latents)
389
+ kv_input = torch.cat(
390
+ (x, latents), dim=-2
391
+ ) # TODO: Change the shape of vision attention mask according to this.
392
+ if vision_attn_masks is not None:
393
+ vision_attn_masks = torch.cat(
394
+ (
395
+ vision_attn_masks,
396
+ torch.ones(
397
+ (latents.shape[0], latents.shape[-2]),
398
+ dtype=latents.dtype,
399
+ device=latents.device,
400
+ ),
401
+ ),
402
+ dim=-1,
403
+ )
404
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
405
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
406
+ q = q * self.scale
407
+
408
+ # attention
409
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
410
+ # Apply vision attention mask here.
411
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
412
+ if vision_attn_masks is not None:
413
+ attn_bias = torch.zeros(
414
+ (q.size(0), 1, 1, q.size(-2), k.size(-2)),
415
+ dtype=q.dtype,
416
+ device=q.device,
417
+ )
418
+ vision_attn_masks = repeat(
419
+ vision_attn_masks, "b n -> b 1 1 l n", l=q.size(-2)
420
+ )
421
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
422
+ sim += attn_bias
423
+
424
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
425
+ attn = sim.softmax(dim=-1)
426
+
427
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
428
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
429
+ return self.to_out(out)
430
+
431
+
432
+ def FeedForward(dim, mult=4):
433
+ inner_dim = int(dim * mult)
434
+ return nn.Sequential(
435
+ nn.LayerNorm(dim),
436
+ nn.Linear(dim, inner_dim, bias=False),
437
+ nn.GELU(),
438
+ nn.Linear(inner_dim, dim, bias=False),
439
+ )
440
+
441
+
442
+ def num_params(module, filter_to_trainable=False):
443
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
444
+ if filter_to_trainable:
445
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
446
+ else:
447
+ return sum(p.numel() for p in module.parameters())
448
+
449
+
450
+ class PerceiverResampler(VisionTokenizer):
451
+ def __init__(
452
+ self,
453
+ *,
454
+ dim,
455
+ dim_inner=None,
456
+ depth=6,
457
+ dim_head=96,
458
+ heads=16,
459
+ num_latents=128,
460
+ max_num_media=None,
461
+ max_num_frames=None,
462
+ ff_mult=4,
463
+ ):
464
+ """
465
+ Perceiver module which takes in image features and outputs image tokens.
466
+ Args:
467
+ dim (int): dimension of the incoming image features
468
+ dim_inner (int, optional): final dimension to project the incoming image features to;
469
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
470
+ depth (int, optional): number of layers. Defaults to 6.
471
+ dim_head (int, optional): dimension of each head. Defaults to 96.
472
+ heads (int, optional): number of heads. Defaults to 16.
473
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
474
+ also corresponds to number of tokens per sequence to output. Defaults to 128.
475
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
476
+ and keep positional embeddings for. If None, no positional embeddings are used.
477
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
478
+ and keep positional embeddings for. If None, no positional embeddings are used.
479
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
480
+ """
481
+ if dim_inner is not None:
482
+ projection = nn.Linear(dim, dim_inner)
483
+ else:
484
+ projection = None
485
+ dim_inner = dim
486
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
487
+ self.projection = projection
488
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
489
+
490
+ # positional embeddings
491
+ self.frame_embs = (
492
+ nn.Parameter(torch.randn(max_num_frames, dim))
493
+ if exists(max_num_frames)
494
+ else None
495
+ )
496
+ self.media_time_embs = (
497
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
498
+ if exists(max_num_media)
499
+ else None
500
+ )
501
+
502
+ self.layers = nn.ModuleList([])
503
+ for _ in range(depth):
504
+ self.layers.append(
505
+ nn.ModuleList(
506
+ [
507
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
508
+ FeedForward(dim=dim, mult=ff_mult),
509
+ ]
510
+ )
511
+ )
512
+
513
+ self.norm = nn.LayerNorm(dim)
514
+
515
+ def forward(self, x, vision_attn_masks):
516
+ """
517
+ Args:
518
+ x (torch.Tensor): image features
519
+ shape (b, T, F, v, D)
520
+ vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
521
+ shape (b, v)
522
+ Returns:
523
+ shape (b, T, n, D) where n is self.num_latents
524
+ """
525
+ b, T, F, v = x.shape[:4]
526
+
527
+ # frame and media time embeddings
528
+ if exists(self.frame_embs):
529
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
530
+ x = x + frame_embs
531
+ x = rearrange(
532
+ x, "b T F v d -> b T (F v) d"
533
+ ) # flatten the frame and spatial dimensions
534
+ if exists(self.media_time_embs):
535
+ x = x + self.media_time_embs[:T]
536
+
537
+ # blocks
538
+ latents = self.latents
539
+ latents = repeat(latents, "n d -> b T n d", b=b, T=T)
540
+ for attn, ff in self.layers:
541
+ latents = attn(x, latents, vision_attn_masks) + latents
542
+ latents = ff(latents) + latents
543
+
544
+ if exists(self.projection):
545
+ return self.projection(self.norm(latents))
546
+ else:
547
+ return self.norm(latents)
548
+
549
+
550
+ class DecoupledEmbedding(nn.Embedding):
551
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
552
+ """
553
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
554
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
555
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
556
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ max_original_id: int,
562
+ num_additional_embeddings: int = 0,
563
+ _weight: torch.Tensor = None,
564
+ num_original_embeddings: int = None,
565
+ embedding_dim: int = None,
566
+ partially_freeze=True,
567
+ device=None,
568
+ dtype=None,
569
+ pad_token_id=None,
570
+ ) -> None:
571
+ """
572
+ Args:
573
+ max_original_id (`int`):
574
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
575
+ This is usually len(tokenizer) - 1 before additional tokens are added.
576
+ Note that this may not equal self.weight.shape[0]
577
+ num_additional_embeddings (`int`):
578
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
579
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
580
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
581
+ num_original_embeddings (`int`):
582
+ self.weight.shape[0]
583
+ embedding_dim (`int`):
584
+ The size of each embedding vector
585
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
586
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
587
+ padding_idx (`int`, *optional*):
588
+ The padding index (needs to be less than num_embeddings)
589
+
590
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
591
+ `max_norm` or `norm_type`. We are not supporting these.
592
+ """
593
+ # validate args
594
+ if pad_token_id is not None and pad_token_id > max_original_id:
595
+ raise ValueError(
596
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
597
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
598
+ )
599
+ if _weight is not None:
600
+ assert (num_original_embeddings is None) or (
601
+ _weight.shape[0] == num_original_embeddings
602
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
603
+ assert (embedding_dim is None) or (
604
+ _weight.shape[1] == embedding_dim
605
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
606
+ num_original_embeddings = _weight.shape[0]
607
+ embedding_dim = _weight.shape[1]
608
+ else:
609
+ assert (
610
+ num_original_embeddings is not None
611
+ ), "num_original_embeddings must be provided if _weight is not provided"
612
+ assert (
613
+ embedding_dim is not None
614
+ ), "embedding_dim must be provided if _weight is not provided"
615
+
616
+ super().__init__(
617
+ num_embeddings=num_original_embeddings,
618
+ embedding_dim=embedding_dim,
619
+ device=device,
620
+ dtype=dtype,
621
+ padding_idx=pad_token_id,
622
+ _weight=_weight,
623
+ )
624
+ self.max_original_id = max_original_id
625
+ self.padding_idx = pad_token_id
626
+ self.num_additional_embeddings = num_additional_embeddings
627
+ if self.num_additional_embeddings > 0:
628
+ self.additional_embedding = nn.Embedding(
629
+ num_embeddings=self.num_additional_embeddings,
630
+ embedding_dim=embedding_dim,
631
+ device=device,
632
+ dtype=dtype,
633
+ )
634
+ self.set_requires_grad(
635
+ require_regular_grad=not partially_freeze, require_additional_grad=True
636
+ )
637
+
638
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
639
+ """
640
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
641
+ """
642
+ self.weight.requires_grad_(require_regular_grad)
643
+ self.additional_embedding.requires_grad_(require_additional_grad)
644
+
645
+ def forward(self, input_ids):
646
+ """
647
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
648
+ self.additional_embedding.weight that is being trained.
649
+
650
+ in order to make a lookup of the input ids, we:
651
+ 1. find out the indices of the entries belonging to the 2nd embedding
652
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
653
+ embedding starts from 0 and not num_embeddings
654
+ 3. perform the 2nd embedding lookup
655
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
656
+ 5. perform the 1st embedding lookup
657
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
658
+
659
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
660
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
661
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
662
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
663
+ measure.
664
+
665
+ """
666
+ if self.num_additional_embeddings == 0:
667
+ return F.embedding(input_ids, self.weight)
668
+
669
+ # Clone so that we don't modify the original input_ids later on
670
+ input_ids = input_ids.clone()
671
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
672
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
673
+ additional_embeddings = self.additional_embedding(
674
+ input_ids_additional_vocab - self.max_original_id - 1
675
+ )
676
+
677
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
678
+ input_ids[additional_vocab_indices] = 0
679
+ full_vector = F.embedding(input_ids, self.weight)
680
+
681
+ # overwrite the records with high indices
682
+ full_vector[additional_vocab_indices] = additional_embeddings
683
+
684
+ return full_vector
685
+
686
+ def extra_repr(self) -> str:
687
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
688
+ self.max_original_id + 1,
689
+ self.num_additional_embeddings,
690
+ self.embedding_dim,
691
+ (not self.weight.requires_grad),
692
+ )
693
+
694
+
695
+ class DecoupledLinear(nn.Linear):
696
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
697
+ """
698
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
699
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
700
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
701
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ max_original_id: int,
707
+ additional_out_features: int = 0,
708
+ _weight: torch.Tensor = None,
709
+ _bias: torch.Tensor = None,
710
+ in_features: int = None,
711
+ original_out_features: int = None,
712
+ bias: bool = True,
713
+ partially_freeze: bool = True,
714
+ device=None,
715
+ dtype=None,
716
+ ) -> None:
717
+ """
718
+ Args:
719
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
720
+ This is usually len(tokenizer) - 1 before additional tokens are added.
721
+ Note that this may not equal original_out_features - 1
722
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
723
+ If provided, this sets the `in_features` and `original_out_features` parameters.
724
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
725
+ in_features: int. Input hidden size.
726
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
727
+ additional_out_features: int. Number of additional trainable dimensions.
728
+ bias: bool. Whether to include a bias term.
729
+ partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
730
+ """
731
+ # argument validation
732
+ if _weight is not None:
733
+ assert (_weight.shape[0] == original_out_features) or (
734
+ original_out_features is None
735
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
736
+ assert (_weight.shape[1] == in_features) or (
737
+ in_features is None
738
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
739
+ in_features = _weight.shape[1]
740
+ original_out_features = _weight.shape[0]
741
+ else:
742
+ assert (
743
+ in_features is not None
744
+ ), "in_features must be provided if _weight is not provided"
745
+ assert (
746
+ original_out_features is not None
747
+ ), "original_out_features must be provided if _weight is not provided"
748
+
749
+ if _bias is not None:
750
+ assert bias is True, "bias must be True if _bias is provided"
751
+
752
+ # initialize original linear
753
+ super().__init__(in_features, original_out_features, bias, device, dtype)
754
+
755
+ # set weight and bias manually
756
+ if _weight is not None:
757
+ self.weight = nn.Parameter(_weight)
758
+ if _bias is not None:
759
+ self.bias = nn.Parameter(_bias)
760
+
761
+ self.in_features = in_features
762
+ self.original_out_features = original_out_features
763
+ self.max_original_id = max_original_id
764
+
765
+ # initialize additional linear
766
+ self.additional_out_features = additional_out_features
767
+ self.has_bias = bias
768
+ if additional_out_features > 0:
769
+ self.additional_fc = nn.Linear(
770
+ in_features=in_features,
771
+ out_features=additional_out_features,
772
+ bias=self.has_bias,
773
+ device=device,
774
+ dtype=dtype,
775
+ )
776
+ self.set_requires_grad(
777
+ require_regular_grad=not partially_freeze, require_additional_grad=True
778
+ )
779
+
780
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
781
+ """
782
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
783
+ """
784
+ self.weight.requires_grad_(require_regular_grad)
785
+ if self.has_bias:
786
+ self.bias.requires_grad_(require_regular_grad)
787
+ self.additional_fc.requires_grad_(require_additional_grad)
788
+
789
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
790
+ output = F.linear(input, self.weight, self.bias)
791
+ output = output[..., : self.max_original_id + 1]
792
+
793
+ if self.additional_out_features > 0:
794
+ additional_features = F.linear(
795
+ input, self.additional_fc.weight, self.additional_fc.bias
796
+ )
797
+ output = torch.cat((output, additional_features), -1)
798
+ return output
799
+
800
+ def extra_repr(self) -> str:
801
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
802
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
803
+ self.in_features,
804
+ self.max_original_id + 1,
805
+ self.additional_out_features,
806
+ self.bias is not None,
807
+ (not self.weight.requires_grad or not self.bias.requires_grad),
808
+ )
809
+
810
+
811
+ class VLM(nn.Module):
812
+ """
813
+ Generic vision-language model (VLM) class.
814
+ A VLM consists of four components:
815
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
816
+ input: (B, T_img, F, C, H, W)
817
+ output: (B, T_img, F, v, d)
818
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
819
+ input: (B, T_img, F, v, d)
820
+ output: (B, T_img, n, d)
821
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
822
+ 4. A language model
823
+ """
824
+
825
+ def __init__(
826
+ self,
827
+ vision_encoder: nn.Module,
828
+ vision_tokenizer: nn.Module,
829
+ lang_model: nn.Module,
830
+ initial_tokenizer_len: int,
831
+ pad_token_id: int,
832
+ gradient_checkpointing: bool = False,
833
+ ):
834
+ """
835
+ Args:
836
+ vision_encoder (nn.Module): e.g. CLIP
837
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
838
+ lang_model (nn.Module): e.g. MPT
839
+ initial_tokenizer_len (int): size of the original tokenizer vocab
840
+ pad_token_id (int): id of the pad token
841
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
842
+ """
843
+ super().__init__()
844
+
845
+ # save dimension information
846
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
847
+ if hasattr(lang_model.config, "d_model"):
848
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
849
+ else:
850
+ self.lang_hidden_dim = lang_model.config.hidden_size
851
+ self.vis_embedding_dim = vision_tokenizer.dim_media
852
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
853
+
854
+ # core components
855
+ self.vision_encoder = vision_encoder
856
+ self.vision_tokenizer = vision_tokenizer
857
+ self.lang_model = lang_model
858
+
859
+ # lm embeddings
860
+ self.pad_token_id = pad_token_id
861
+ self.initial_tokenizer_len = initial_tokenizer_len
862
+ input_embeds = DecoupledEmbedding(
863
+ max_original_id=initial_tokenizer_len - 1,
864
+ num_additional_embeddings=len(self.special_tokens),
865
+ _weight=self.lang_model.get_input_embeddings().weight,
866
+ pad_token_id=self.pad_token_id,
867
+ )
868
+ if hasattr(input_embeds, "additional_embedding"):
869
+ input_embeds.additional_embedding.weight.data.normal_(
870
+ mean=0.0,
871
+ std=(
872
+ self.lang_model.config.initializer_range
873
+ if hasattr(self.lang_model.config, "initializer_range")
874
+ else 0.02
875
+ ),
876
+ )
877
+ self.lang_model.set_input_embeddings(input_embeds)
878
+
879
+ out_embeds = DecoupledLinear(
880
+ max_original_id=initial_tokenizer_len - 1,
881
+ additional_out_features=len(self.special_tokens),
882
+ _weight=self.lang_model.get_output_embeddings().weight,
883
+ _bias=(
884
+ self.lang_model.get_output_embeddings().bias
885
+ if hasattr(self.lang_model.get_output_embeddings(), "bias")
886
+ else None
887
+ ),
888
+ )
889
+ if hasattr(out_embeds, "additional_fc"):
890
+ out_embeds.additional_fc.weight.data.normal_(
891
+ mean=0.0,
892
+ std=(
893
+ self.lang_model.config.initializer_range
894
+ if hasattr(self.lang_model.config, "initializer_range")
895
+ else 0.02
896
+ ),
897
+ )
898
+ self.lang_model.set_output_embeddings(out_embeds)
899
+
900
+ # gradient checkpointing
901
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
902
+
903
+ def forward(
904
+ self,
905
+ vision_x: Optional[torch.Tensor],
906
+ lang_x: torch.Tensor,
907
+ attention_mask: Optional[torch.Tensor] = None,
908
+ labels: Optional[torch.Tensor] = None,
909
+ past_key_values: Optional[
910
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
911
+ ] = None,
912
+ past_media_locations: Optional[torch.Tensor] = None,
913
+ past_vision_tokens: Optional[torch.Tensor] = None,
914
+ use_cache: Optional[bool] = False,
915
+ **kwargs,
916
+ ):
917
+ """
918
+ Args:
919
+ vision_x: Vision input
920
+ shape (B, T_img, F, C, H, W) with F=1
921
+ only F = 1 is supported (single-frame videos)
922
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
923
+ only the first number of media tokens in lang_x are used
924
+ lang_x: Language input ids, with media tokens denoting where
925
+ visual media should be inserted.
926
+ shape (B, T_txt)
927
+ attention_mask: Attention mask. Defaults to None.
928
+ labels: Labels. Defaults to None.
929
+ shape (B, T_txt)
930
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
931
+ list of length = number of decoder layers in the LM
932
+ exact implementation depends on LM, see Hugging Face docs
933
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
934
+ shape (B, T_txt)
935
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
936
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
937
+ If True, includes key_values, media_locations, and vision_tokens in the output.
938
+ """
939
+ assert not (past_vision_tokens is None) ^ (
940
+ past_media_locations is None
941
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
942
+
943
+ # convert pixels to vision tokens
944
+ if vision_x is not None:
945
+ vision_features = self._encode_vision_x(vision_x=vision_x)
946
+ vision_tokens = self.vision_tokenizer(vision_features)
947
+ else:
948
+ vision_tokens = None
949
+
950
+ # fuse the vision and language tokens
951
+ new_inputs = self._prepare_inputs_for_forward(
952
+ vision_tokens=vision_tokens,
953
+ lang_x=lang_x,
954
+ attention_mask=attention_mask,
955
+ labels=labels,
956
+ past_key_values=past_key_values,
957
+ past_media_locations=past_media_locations,
958
+ padding_side="right",
959
+ past_vision_tokens=past_vision_tokens,
960
+ )
961
+ output = self.lang_model(
962
+ **new_inputs,
963
+ use_cache=use_cache,
964
+ past_key_values=past_key_values,
965
+ **kwargs,
966
+ )
967
+
968
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
969
+ # or to add the past_vision_tokens and past_media_locations to the output
970
+ output = self._postprocess_outputs_from_forward(
971
+ output=output,
972
+ lang_x=lang_x,
973
+ vision_tokens=vision_tokens,
974
+ use_cache=use_cache,
975
+ past_vision_tokens=past_vision_tokens,
976
+ past_media_locations=past_media_locations,
977
+ )
978
+
979
+ # postforward hooks
980
+ self._post_forward_hook()
981
+ return output
982
+
983
+ def _encode_vision_x_anyres(self, samples, device):
984
+ assert self.anyres_grids is not None
985
+ image_raw = samples[
986
+ "image"
987
+ ] # list of patch list in of shape [1, N_patch, C, H, W]
988
+ image_sizes = samples["image_size"]
989
+
990
+ # Image_raw can be a list of list of patches, when a `samples` has multiple images.
991
+ if isinstance(image_raw[0], list):
992
+ images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
993
+ image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
994
+ else:
995
+ # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
996
+ # concatenate list of patches into one big patch for any-res encoding.
997
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
998
+ image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
999
+ image = image.to(device)
1000
+
1001
+ with torch.no_grad():
1002
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1003
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
1004
+ elif self.vision_encoder.__class__.__name__ in [
1005
+ "CLIPVisionModel",
1006
+ "SiglipVisionTransformer",
1007
+ ]:
1008
+ image_embeds = self.vision_encoder(image).last_hidden_state
1009
+ else:
1010
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
1011
+
1012
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(
1013
+ self.vision_encoder, SiglipVisionTransformer
1014
+ ):
1015
+ base_img_size = self.vision_encoder.config.image_size
1016
+ else:
1017
+ base_img_size = self.vision_encoder.image_size[0]
1018
+
1019
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1020
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
1021
+ elif self.vision_encoder.__class__.__name__ in [
1022
+ "CLIPVisionModel",
1023
+ "SiglipVisionTransformer",
1024
+ ]:
1025
+ grid_size_base = (
1026
+ self.vision_encoder.config.image_size
1027
+ // self.vision_encoder.config.patch_size
1028
+ )
1029
+ grid_size = (grid_size_base, grid_size_base)
1030
+ else:
1031
+ grid_size = self.vision_encoder.grid_size
1032
+ height, width = grid_size
1033
+
1034
+ if not image_embeds.shape[1] == height * width:
1035
+ assert (
1036
+ image_embeds.shape[1] == height * width + 1
1037
+ ) # For vision encoders that have a [CLS] token.
1038
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
1039
+ n_vis_token_per_patch = image_embeds.shape[1]
1040
+
1041
+ # Split encoded patches and merge patch features
1042
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
1043
+ split_sizes = [image.shape[0] for image in images]
1044
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
1045
+ # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
1046
+ new_image_embeds = []
1047
+ patch_attn_masks = []
1048
+ max_n_img_token = -1
1049
+ for idx, patch_embeds in enumerate(image_embeds):
1050
+ if patch_embeds.shape[0] > 1:
1051
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
1052
+ base_patch_embeds = patch_embeds[
1053
+ 0
1054
+ ] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
1055
+ patch_embeds = patch_embeds[1:]
1056
+
1057
+ assert height * width == base_patch_embeds.shape[0]
1058
+
1059
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
1060
+ image_sizes[idx], self.anyres_grids, base_img_size
1061
+ ) # Hardcoded grid_pinpoints.
1062
+ patch_embeds = patch_embeds.view(
1063
+ num_patch_height, num_patch_width, height, width, -1
1064
+ )
1065
+
1066
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
1067
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
1068
+ patch_embeds, patch_attn_mask = unpad_image(
1069
+ patch_embeds, image_sizes[idx], self.anyres_patch_sampling
1070
+ )
1071
+ if hasattr(self, "image_newline"):
1072
+ patch_embeds = torch.cat(
1073
+ (
1074
+ patch_embeds,
1075
+ self.image_newline[:, None, None].expand(
1076
+ *patch_embeds.shape[:-1], 1
1077
+ ),
1078
+ ),
1079
+ dim=-1,
1080
+ )
1081
+ if self.anyres_patch_sampling:
1082
+ patch_embeds = patch_embeds.view(
1083
+ -1, num_patch_height, num_patch_width, height * width
1084
+ )
1085
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
1086
+ assert patch_attn_mask is not None
1087
+ patch_attn_mask = patch_attn_mask.view(
1088
+ num_patch_height, num_patch_width, height * width
1089
+ )
1090
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
1091
+ patch_embeds = torch.cat(
1092
+ (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0
1093
+ )
1094
+ patch_attn_mask = torch.cat(
1095
+ (
1096
+ torch.ones(
1097
+ n_vis_token_per_patch, device=patch_embeds.device
1098
+ ).unsqueeze(0),
1099
+ patch_attn_mask,
1100
+ ),
1101
+ dim=0,
1102
+ )
1103
+ else:
1104
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
1105
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
1106
+ else:
1107
+ patch_embeds = (
1108
+ patch_embeds[0].unsqueeze(0)
1109
+ if self.anyres_patch_sampling
1110
+ else patch_embeds[0]
1111
+ )
1112
+ patch_attn_mask = (
1113
+ torch.ones(
1114
+ n_vis_token_per_patch, device=patch_embeds.device
1115
+ ).unsqueeze(0)
1116
+ if self.anyres_patch_sampling
1117
+ else None
1118
+ )
1119
+ if hasattr(self, "image_newline"):
1120
+ patch_embeds = torch.cat(
1121
+ (patch_embeds, self.image_newline[None]), dim=0
1122
+ )
1123
+ if not self.anyres_patch_sampling:
1124
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
1125
+
1126
+ new_image_embeds.append(patch_embeds)
1127
+ patch_attn_masks.append(patch_attn_mask)
1128
+
1129
+ if self.anyres_patch_sampling:
1130
+ # Return individual patches for independent token downsampling.
1131
+ return new_image_embeds, patch_attn_masks
1132
+
1133
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
1134
+ image_embeds = []
1135
+ image_atts = []
1136
+ for image_embed in new_image_embeds:
1137
+ n_img_token = image_embed.shape[0]
1138
+ img_attn = torch.ones(
1139
+ (max_n_img_token), dtype=torch.long, device=image_embed.device
1140
+ )
1141
+ if n_img_token < max_n_img_token:
1142
+ padded_embed = torch.zeros(
1143
+ (max_n_img_token, image_embed.shape[-1]),
1144
+ dtype=image_embed.dtype,
1145
+ device=image_embed.device,
1146
+ )
1147
+ padded_embed[:n_img_token, :] = image_embed
1148
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
1149
+ else:
1150
+ padded_embed = image_embed
1151
+ image_embeds.append(padded_embed)
1152
+ image_atts.append(img_attn)
1153
+ image_embeds = torch.stack(
1154
+ image_embeds, dim=0
1155
+ ) # Shape [B, N_tok_longest, C_dim]
1156
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
1157
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
1158
+ image_embeds = image_embeds[:, None, None, :, :]
1159
+ # image_atts = image_atts[:, None, None, :, :]
1160
+
1161
+ return image_embeds, image_atts
1162
+
1163
+ def _encode_vision_x(self, vision_x: torch.Tensor):
1164
+ """
1165
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
1166
+ Args:
1167
+ vision_x: Vision input
1168
+ shape (B, T_img, F, C, H, W)
1169
+ Images in the same chunk are collated along T_img, and frames are collated along F
1170
+ Currently only F=1 is supported (single-frame videos)
1171
+
1172
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
1173
+ """
1174
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
1175
+ b, T, F = vision_x.shape[:3]
1176
+
1177
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
1178
+ with torch.no_grad():
1179
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1180
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
1181
+ elif self.vision_encoder.__class__.__name__ in [
1182
+ "CLIPVisionModel",
1183
+ "SiglipVisionTransformer",
1184
+ ]:
1185
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
1186
+ else:
1187
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
1188
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
1189
+ return vision_x
1190
+
1191
+ def _concat_vision_cache(
1192
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
1193
+ ):
1194
+ """
1195
+ Helper function to include the past vision tokens and past media locations in the output.
1196
+ """
1197
+ if use_cache:
1198
+ if past_media_locations is not None and past_vision_tokens is not None:
1199
+ if vision_tokens is not None:
1200
+ updated_vision_tokens = torch.cat(
1201
+ [
1202
+ past_vision_tokens,
1203
+ vision_tokens,
1204
+ ],
1205
+ dim=1,
1206
+ )
1207
+ else:
1208
+ updated_vision_tokens = past_vision_tokens
1209
+ updated_media_locations = torch.cat(
1210
+ [
1211
+ past_media_locations,
1212
+ lang_x == self.media_token_id,
1213
+ ],
1214
+ dim=1,
1215
+ )
1216
+ else:
1217
+ updated_vision_tokens = vision_tokens
1218
+ updated_media_locations = lang_x == self.media_token_id
1219
+
1220
+ else:
1221
+ updated_vision_tokens = None
1222
+ updated_media_locations = None
1223
+
1224
+ return updated_vision_tokens, updated_media_locations
1225
+
1226
+ def generate(
1227
+ self,
1228
+ vision_x: torch.Tensor,
1229
+ lang_x: torch.Tensor,
1230
+ attention_mask: torch.Tensor = None,
1231
+ past_key_values: Optional[
1232
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1233
+ ] = None,
1234
+ past_media_locations: Optional[torch.Tensor] = None,
1235
+ past_vision_tokens: Optional[torch.Tensor] = None,
1236
+ **kwargs,
1237
+ ):
1238
+ """
1239
+ Generate text conditioned on vision and language inputs.
1240
+ Args:
1241
+ vision_x (torch.Tensor): Vision input
1242
+ shape (B, T_img, F, C, H, W)
1243
+ see documentation for forward
1244
+ lang_x (torch.Tensor): Language input
1245
+ shape (B, T_txt)
1246
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1247
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1248
+ Returns:
1249
+ torch.Tensor: lang_x with generated tokens appended to it
1250
+ """
1251
+ num_beams = kwargs.pop("num_beams", 1)
1252
+
1253
+ # convert pixels to vision tokens
1254
+ if vision_x is not None:
1255
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1256
+ vision_tokens = self.vision_tokenizer(vision_features)
1257
+ else:
1258
+ vision_tokens = None
1259
+
1260
+ # fuse the vision and language tokens
1261
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1262
+ # the total batch size is B * num_beams
1263
+ new_inputs = self._prepare_inputs_for_forward(
1264
+ vision_tokens=vision_tokens,
1265
+ lang_x=lang_x,
1266
+ attention_mask=attention_mask,
1267
+ past_key_values=past_key_values,
1268
+ past_media_locations=past_media_locations,
1269
+ past_vision_tokens=past_vision_tokens,
1270
+ padding_side="left",
1271
+ num_beams=num_beams,
1272
+ )
1273
+ output = self.lang_model.generate(
1274
+ **new_inputs,
1275
+ past_key_values=past_key_values,
1276
+ num_beams=num_beams,
1277
+ use_cache=True,
1278
+ **kwargs,
1279
+ )
1280
+ self._post_forward_hook()
1281
+ return output
1282
+
1283
+ @property
1284
+ def num_trainable_params(self):
1285
+ """Print the number of trainable parameters"""
1286
+ return num_params(self, filter_to_trainable=True)
1287
+
1288
+ def set_trainable(self):
1289
+ """
1290
+ Freeze appropriate parameters in the model.
1291
+ """
1292
+ raise NotImplementedError
1293
+
1294
+ def group_params_by_weight_decay(self):
1295
+ """
1296
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
1297
+ """
1298
+ params_with_wd, params_without_wd = [], []
1299
+ for n, p in self.named_parameters():
1300
+ if p.requires_grad:
1301
+ if self._should_apply_weight_decay(n):
1302
+ params_with_wd.append(p)
1303
+ else:
1304
+ params_without_wd.append(p)
1305
+ return params_with_wd, params_without_wd
1306
+
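A hedged example of feeding `group_params_by_weight_decay` into an optimizer; `model` is assumed to be a concrete VLM subclass that implements `_should_apply_weight_decay`, and the learning rate and decay value are placeholders:

import torch

params_with_wd, params_without_wd = model.group_params_by_weight_decay()
groups = [
    {"params": params_with_wd, "weight_decay": 0.01},
    {"params": params_without_wd, "weight_decay": 0.0},
]
# Drop empty groups (e.g. XGenMMPerceiver applies weight decay to everything).
optimizer = torch.optim.AdamW([g for g in groups if len(g["params"]) > 0], lr=1e-5)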
1307
+ def _should_apply_weight_decay(self, parameter_name):
1308
+ """
1309
+ Return whether weight decay should be applied to a parameter.
1310
+ """
1311
+ raise NotImplementedError
1312
+
1313
+ @property
1314
+ def special_tokens(self):
1315
+ """
1316
+ Returns a dict mapping from the attribute name of a special token to its string format,
1317
+ e.g. "media_token": "<image>"
1318
+ """
1319
+ assert (
1320
+ "media_token" in self._special_tokens
1321
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
1322
+ return self._special_tokens
1323
+
1324
+ @property
1325
+ def special_token_ids(self):
1326
+ """
1327
+ Returns a list of the special token ids
1328
+ """
1329
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
1330
+
1331
+ def set_special_token_ids(self, string_to_ids):
1332
+ """
1333
+ Args:
1334
+ string_to_ids (dict): mapping from token string to id
1335
+ """
1336
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
1337
+ for att_name, token_str in self.special_tokens.items():
1338
+ token_id = string_to_ids[token_str]
1339
+ setattr(self, f"{att_name}_id", token_id)
1340
+ setattr(self.lang_model, f"{att_name}_id", token_id)
1341
+
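A minimal sketch of wiring the special tokens, assuming `tokenizer` is a Hugging Face tokenizer and `model` a concrete VLM subclass (this mirrors `update_special_tokens` further down in this file):

tokenizer.add_special_tokens(
    {"additional_special_tokens": list(model.special_tokens.values())}
)
model.set_special_token_ids(
    {tok: tokenizer.convert_tokens_to_ids(tok) for tok in model.special_tokens.values()}
)
# model.media_token_id now holds the id of "<image>".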
1342
+ def init_gradient_checkpointing(self):
1343
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
1344
+ checkpoint_wrapper,
1345
+ CheckpointWrapper,
1346
+ CheckpointImpl,
1347
+ apply_activation_checkpointing,
1348
+ )
1349
+ from functools import partial
1350
+
1351
+ non_reentrant_wrapper = partial(
1352
+ checkpoint_wrapper,
1353
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
1354
+ )
1355
+ apply_activation_checkpointing(
1356
+ self,
1357
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
1358
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
1359
+ and not isinstance(m, CheckpointWrapper),
1360
+ )
1361
+
1362
+
1363
+ @dataclass
1364
+ class VLMOutputWithPast(CausalLMOutputWithPast):
1365
+ """
1366
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
1367
+ past_media_locations: Optional[torch.Tensor] = None,
1368
+ past_vision_tokens: Optional[torch.Tensor] = None,
1369
+ """
1370
+
1371
+ past_media_locations: Optional[torch.Tensor] = None
1372
+ past_vision_tokens: Optional[torch.Tensor] = None
1373
+
1374
+
1375
+ def exists(val):
1376
+ return val is not None
1377
+
1378
+
1379
+ def FeedForward(dim, mult=4):
1380
+ inner_dim = int(dim * mult)
1381
+ return nn.Sequential(
1382
+ nn.LayerNorm(dim),
1383
+ nn.Linear(dim, inner_dim, bias=False),
1384
+ nn.GELU(),
1385
+ nn.Linear(inner_dim, dim, bias=False),
1386
+ )
1387
+
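A quick shape check for the `FeedForward` block above (the hidden width is 4x the input dim by default and the input shape is preserved):

import torch

ff = FeedForward(dim=1152)          # LayerNorm -> Linear(1152->4608) -> GELU -> Linear(4608->1152)
x = torch.randn(2, 1, 128, 1152)    # (b, T, n_latents, dim)
print(ff(x).shape)                  # torch.Size([2, 1, 128, 1152])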
1388
+
1389
+ class VLMWithLanguageStream(VLM):
1390
+ """
1391
+ VLM that fuses modalities by inserting vision tokens directly into the language stream.
1392
+ """
1393
+
1394
+ def __init__(
1395
+ self,
1396
+ vision_encoder: nn.Module,
1397
+ vision_tokenizer: nn.Module,
1398
+ lang_model: nn.Module,
1399
+ initial_tokenizer_len: int,
1400
+ pad_token_id: int,
1401
+ decoder_layers_attr_name: str = None,
1402
+ gradient_checkpointing: bool = False,
1403
+ ):
1404
+ super().__init__(
1405
+ vision_encoder=vision_encoder,
1406
+ vision_tokenizer=vision_tokenizer,
1407
+ lang_model=lang_model,
1408
+ initial_tokenizer_len=initial_tokenizer_len,
1409
+ pad_token_id=pad_token_id,
1410
+ gradient_checkpointing=gradient_checkpointing,
1411
+ )
1412
+ self.decoder_layers_attr_name = decoder_layers_attr_name
1413
+ if decoder_layers_attr_name is not None:
1414
+ for block in getattr_recursive(
1415
+ self.lang_model, self.decoder_layers_attr_name
1416
+ ):
1417
+ block._use_gradient_checkpointing = gradient_checkpointing
1418
+
1419
+ def _prepare_inputs_for_forward(
1420
+ self,
1421
+ vision_tokens: torch.Tensor,
1422
+ lang_x: torch.Tensor,
1423
+ attention_mask: torch.Tensor,
1424
+ labels: torch.Tensor = None,
1425
+ past_key_values=None,
1426
+ vision_attention_mask: Optional[torch.Tensor] = None,
1427
+ past_media_locations: torch.Tensor = None,
1428
+ past_vision_tokens: torch.Tensor = None,
1429
+ padding_side: str = "left",
1430
+ num_beams: int = 1,
1431
+ ):
1432
+ """
1433
+ Insert the vision tokens directly into the language stream.
1434
+ This requires us to modify the input_ids, attention_mask, and labels.
1435
+ """
1436
+ if past_key_values is not None:
1437
+ past_len = past_key_values[0][0].shape[2]
1438
+ assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1439
+ "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1440
+ + "Check that you've expanded the attention mask to account for past image tokens."
1441
+ )
1442
+
1443
+ if vision_tokens is None:
1444
+ return {
1445
+ "input_ids": lang_x,
1446
+ "attention_mask": attention_mask,
1447
+ "labels": labels,
1448
+ }
1449
+
1450
+ # get the language embeddings
1451
+ lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1452
+
1453
+ # build up the multimodal embeddings
1454
+ B = lang_x.shape[0]
1455
+ has_labels = labels is not None
1456
+ multimodal_embeds = []
1457
+ multimodal_attention_mask = []
1458
+ multimodal_labels = [] if has_labels else None
1459
+ for i in range(B):
1460
+ # get index of <image> tokens in lang_x[i]
1461
+ image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1462
+
1463
+ if len(image_token_idxs) == 0:
1464
+ multimodal_embeds.append(lang_embeds[i].clone())
1465
+ multimodal_attention_mask.append(attention_mask[i].clone())
1466
+ if has_labels:
1467
+ multimodal_labels.append(labels[i].clone())
1468
+ continue
1469
+
1470
+ # loop through the image_token_idxs and insert the vision tokens
1471
+ new_embed = lang_embeds[i].clone()
1472
+ new_attention_mask = (
1473
+ attention_mask[i].clone() if attention_mask is not None else None
1474
+ )
1475
+ if has_labels:
1476
+ new_label = labels[i].clone()
1477
+
1478
+ for img_num, img_idx in enumerate(image_token_idxs):
1479
+ # Get vision token attention mask for padded llava-style any resolution image tokens.
1480
+ if self.image_aspect_ratio == "anyres":
1481
+ num_vis_tokens = vision_tokens[i][img_num].shape[0]
1482
+ if vision_attention_mask is not None:
1483
+ vis_attention_mask = vision_attention_mask[i]
1484
+ else:
1485
+ vis_attention_mask = torch.ones(
1486
+ num_vis_tokens, dtype=torch.long
1487
+ ).to(attention_mask.device)
1488
+ else:
1489
+ assert (
1490
+ vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
1491
+ ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
1492
+ vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
1493
+ # By default, vision tokens are not padded.
1494
+ num_vis_tokens = self.num_tokens_per_vis
1495
+ vis_attention_mask = torch.ones(
1496
+ num_vis_tokens, dtype=torch.long
1497
+ ).to(attention_mask.device)
1498
+
1499
+ new_embed = torch.cat(
1500
+ (
1501
+ new_embed[:img_idx],
1502
+ vision_tokens[i][img_num],
1503
+ new_embed[img_idx + 1 :],
1504
+ ),
1505
+ dim=0,
1506
+ )
1507
+ new_attention_mask = torch.cat(
1508
+ (
1509
+ new_attention_mask[:img_idx],
1510
+ vis_attention_mask,
1511
+ new_attention_mask[img_idx + 1 :],
1512
+ ),
1513
+ dim=0,
1514
+ )
1515
+ if has_labels:
1516
+ new_label = torch.cat(
1517
+ (
1518
+ new_label[:img_idx],
1519
+ torch.ones(num_vis_tokens, dtype=torch.long).to(
1520
+ labels.device
1521
+ )
1522
+ * -100,
1523
+ new_label[img_idx + 1 :],
1524
+ ),
1525
+ dim=0,
1526
+ )
1527
+ multimodal_embeds.append(new_embed)
1528
+ multimodal_attention_mask.append(new_attention_mask)
1529
+ if has_labels:
1530
+ multimodal_labels.append(new_label)
1531
+
1532
+ # stack
1533
+ multimodal_embeds = stack_with_padding(
1534
+ multimodal_embeds,
1535
+ padding_value=self.pad_token_id,
1536
+ padding_side=padding_side,
1537
+ )
1538
+ multimodal_attention_mask = stack_with_padding(
1539
+ multimodal_attention_mask,
1540
+ padding_value=0,
1541
+ padding_side=padding_side,
1542
+ )
1543
+ if has_labels:
1544
+ multimodal_labels = stack_with_padding(
1545
+ multimodal_labels,
1546
+ padding_value=-100,
1547
+ padding_side=padding_side,
1548
+ )
1549
+
1550
+ return {
1551
+ "inputs_embeds": multimodal_embeds,
1552
+ "attention_mask": multimodal_attention_mask,
1553
+ "labels": multimodal_labels,
1554
+ }
1555
+
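To make the bookkeeping above concrete: each `<image>` token is replaced in place by that image's block of vision tokens, so the fused sequence grows by (num_vis_tokens - 1) per image. A small illustrative calculation:

T_txt, num_vis_tokens, n_images = 32, 128, 2        # 32 text positions, 2 of them are <image>
fused_len = T_txt + n_images * (num_vis_tokens - 1)
print(fused_len)                                     # 286 = 30 text tokens + 2 * 128 vision tokens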
1556
+ def _postprocess_outputs_from_forward(
1557
+ self,
1558
+ output: CausalLMOutputWithPast,
1559
+ lang_x: torch.Tensor,
1560
+ vision_tokens: torch.Tensor,
1561
+ past_vision_tokens: torch.Tensor,
1562
+ past_media_locations: torch.Tensor,
1563
+ use_cache: bool = False,
1564
+ ):
1565
+ # Include the past vision tokens and past media locations in the output
1566
+ updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1567
+ lang_x=lang_x,
1568
+ vision_tokens=vision_tokens,
1569
+ past_vision_tokens=past_vision_tokens,
1570
+ past_media_locations=past_media_locations,
1571
+ use_cache=use_cache,
1572
+ )
1573
+
1574
+ # return logits that are the same shape as the original input_ids
1575
+ logits = output.logits
1576
+ batch_logits = []
1577
+ B, T_txt = lang_x.shape
1578
+ for i in range(B):
1579
+ sequence_logits = []
1580
+ logits_j = 0
1581
+ for j in range(T_txt):
1582
+ if lang_x[i, j] != self.media_token_id:
1583
+ sequence_logits.append(logits[i, logits_j])
1584
+ logits_j += 1
1585
+ else:
1586
+ # append the logit for the first image token, then skip over the rest
1587
+ # note: the model actually learns to predict <im_patch>, not <image>
1588
+ sequence_logits.append(logits[i, logits_j])
1589
+ logits_j += self.num_tokens_per_vis
1590
+ sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
1591
+ batch_logits.append(sequence_logits)
1592
+
1593
+ batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1594
+ # The final logits shape should be the same as the original input_ids shape
1595
+ assert batch_logits.shape[:2] == (B, T_txt)
1596
+
1597
+ # assemble the output
1598
+ output = VLMOutputWithPast(
1599
+ loss=output.loss,
1600
+ logits=batch_logits,
1601
+ past_key_values=output.past_key_values,
1602
+ hidden_states=output.hidden_states,
1603
+ attentions=output.attentions,
1604
+ past_media_locations=updated_media_locations,
1605
+ past_vision_tokens=updated_vision_tokens,
1606
+ )
1607
+
1608
+ return output
1609
+
1610
+ def _post_forward_hook(self):
1611
+ pass
1612
+
1613
+ @property
1614
+ def num_params_per_module(self):
1615
+ """Print the number of parameters per module in the model"""
1616
+ return "\n".join(
1617
+ [
1618
+ f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1619
+ f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1620
+ f"Language model: {num_params(self.lang_model):,} parameters",
1621
+ ]
1622
+ )
1623
+
1624
+ @property
1625
+ def num_trainable_params_per_module(self):
1626
+ """Print the number of trainable parameters per module in the model"""
1627
+ return "\n".join(
1628
+ [
1629
+ f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1630
+ f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1631
+ f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1632
+ ]
1633
+ )
1634
+
1635
+
1636
+ class XGenMMPerceiver(VLMWithLanguageStream):
1637
+ def __init__(
1638
+ self,
1639
+ vision_encoder: nn.Module,
1640
+ vision_tokenizer: nn.Module,
1641
+ lang_model: nn.Module,
1642
+ initial_tokenizer_len: int,
1643
+ pad_token_id: int,
1644
+ decoder_layers_attr_name: str = None,
1645
+ gradient_checkpointing: bool = False,
1646
+ image_aspect_ratio: str = "anyres",
1647
+ anyres_patch_sampling: bool = True,
1648
+ anyres_grids: list[int] = None,
1649
+ ):
1650
+ """
1651
+ Args:
1652
+ vision_encoder (nn.Module): HF CLIPModel
1653
+ lang_model (nn.Module): HF causal language model
1654
+ vision_tokenizer (nn.Module): perceiver resampler that maps the encoder's visual features to vision tokens
1655
+ initial_tokenizer_len (int): size of the tokenizer vocab
1656
+ pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1657
+ will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1658
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1659
+ gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1660
+ """
1661
+ self._special_tokens = {
1662
+ "media_token": "<image>",
1663
+ "image_placeholder_token": "<image placeholder>",
1664
+ "end_of_trunk_token": "<|endofchunk|>",
1665
+ }
1666
+ lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1667
+ super().__init__(
1668
+ vision_encoder=vision_encoder,
1669
+ vision_tokenizer=vision_tokenizer,
1670
+ lang_model=lang_model,
1671
+ initial_tokenizer_len=initial_tokenizer_len,
1672
+ gradient_checkpointing=gradient_checkpointing,
1673
+ decoder_layers_attr_name=decoder_layers_attr_name,
1674
+ pad_token_id=pad_token_id,
1675
+ )
1676
+ self.image_aspect_ratio = image_aspect_ratio
1677
+ self.anyres_patch_sampling = anyres_patch_sampling
1678
+ self.anyres_grids = anyres_grids
1679
+
1680
+ def set_trainable(self):
1681
+ """
1682
+ Unfreeze everything except the vision_encoder
1683
+ """
1684
+ self.requires_grad_(True)
1685
+ self.vision_encoder.requires_grad_(False)
1686
+
1687
+ def _should_apply_weight_decay(self, parameter_name):
1688
+ """
1689
+ Kosmos applies 0.01 weight decay to everything
1690
+ """
1691
+ return True
1692
+
1693
+ def generate(
1694
+ self,
1695
+ vision_x: torch.Tensor,
1696
+ lang_x: torch.Tensor,
1697
+ image_size: Optional[Tuple] = None,
1698
+ attention_mask: torch.Tensor = None,
1699
+ past_key_values: Optional[
1700
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1701
+ ] = None,
1702
+ past_media_locations: Optional[torch.Tensor] = None,
1703
+ past_vision_tokens: Optional[torch.Tensor] = None,
1704
+ **kwargs,
1705
+ ):
1706
+ """
1707
+ Generate text conditioned on vision and language inputs.
1708
+ Args:
1709
+ vision_x (torch.Tensor): Vision input
1710
+ shape (B, T_img, F, C, H, W)
1711
+ see documentation for forward
1712
+ lang_x (torch.Tensor): Language input
1713
+ shape (B, T_txt)
1714
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1715
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1716
+ Returns:
1717
+ torch.Tensor: lang_x with generated tokens appended to it
1718
+ """
1719
+ num_beams = kwargs.pop("num_beams", 1)
1720
+
1721
+ # convert pixels to vision tokens
1722
+ vision_attention_mask = None
1723
+ if vision_x is not None:
1724
+ if self.image_aspect_ratio == "anyres":
1725
+ input_dict = dict(image=vision_x, image_size=image_size)
1726
+ vision_features, vision_attn_masks = self._encode_vision_x_anyres(
1727
+ input_dict, lang_x.device
1728
+ )
1729
+ else:
1730
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1731
+ vision_attn_masks = None
1732
+ # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
1733
+ # Same for attention masks: [b, Np, v] -> [b*Np, v]
1734
+ if self.anyres_patch_sampling:
1735
+ split_sizes = [feature.shape[0] for feature in vision_features]
1736
+ # Nested splits for multi-image samples.
1737
+ if isinstance(vision_x[0], list):
1738
+ nt_images = [len(images) for images in vision_x]
1739
+ split_split_sizes = []
1740
+ img_id = 0
1741
+ for nt in nt_images:
1742
+ split_split_sizes.append(split_sizes[img_id : img_id + nt])
1743
+ img_id += nt
1744
+ else:
1745
+ nt_images = [1] * len(vision_x)
1746
+ split_split_sizes = split_sizes
1747
+ vision_features = torch.cat(vision_features, dim=0)
1748
+ vision_features = vision_features[
1749
+ :, None, None, :, :
1750
+ ] # Expand dimensions.
1751
+ vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1752
+ vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
1753
+
1754
+ # Post-processing: Split the batches into groups of patches and concatenate them together.
1755
+ if self.anyres_patch_sampling:
1756
+ assert isinstance(vision_x, list)
1757
+ if isinstance(vision_x[0], list):
1758
+ vision_token_groups = torch.split(
1759
+ vision_tokens,
1760
+ list(sum(nt_img) for nt_img in split_split_sizes),
1761
+ dim=0,
1762
+ )
1763
+ vision_tokens = []
1764
+
1765
+ for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
1766
+ patch_vis_token_groups = torch.split(
1767
+ patch_vis_tokens, split_split_sizes[sample_id], dim=0
1768
+ ) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
1769
+ flatten_vision_tokens = []
1770
+ for image_vis_token in patch_vis_token_groups:
1771
+ image_vis_token = image_vis_token.flatten(
1772
+ 0, 2
1773
+ ) # [Np, 1, v, d] -> [Np*v, d]
1774
+ flatten_vision_tokens.append(image_vis_token)
1775
+ vision_tokens_i = flatten_vision_tokens
1776
+ vision_tokens.append(vision_tokens_i)
1777
+ else:
1778
+ vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1779
+ vision_tokens = []
1780
+ for patch_vis_tokens in vision_token_groups:
1781
+ patch_vis_tokens = patch_vis_tokens.flatten(
1782
+ 0, 2
1783
+ ) # [Np, 1, v, d] -> [Np*v, d]
1784
+ vision_tokens.append(
1785
+ patch_vis_tokens.unsqueeze(0)
1786
+ ) # Add the nt dimension.
1787
+ else:
1788
+ vision_tokens = None
1789
+
1790
+ # fuse the vision and language tokens
1791
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1792
+ # the total batch size is B * num_beams
1793
+ new_inputs = self._prepare_inputs_for_forward(
1794
+ vision_tokens=vision_tokens,
1795
+ lang_x=lang_x,
1796
+ attention_mask=attention_mask,
1797
+ vision_attention_mask=vision_attention_mask,
1798
+ past_key_values=past_key_values,
1799
+ past_media_locations=past_media_locations,
1800
+ past_vision_tokens=past_vision_tokens,
1801
+ padding_side="left",
1802
+ num_beams=num_beams,
1803
+ )
1804
+ if past_key_values is not None:
1805
+ output = self.lang_model.generate(
1806
+ **new_inputs,
1807
+ past_key_values=past_key_values,
1808
+ num_beams=num_beams,
1809
+ use_cache=True,
1810
+ **kwargs,
1811
+ )
1812
+ else:
1813
+ output = self.lang_model.generate(
1814
+ **new_inputs,
1815
+ num_beams=num_beams,
1816
+ use_cache=True,
1817
+ **kwargs,
1818
+ )
1819
+ self._post_forward_hook()
1820
+ return output
1821
+
1822
+
1823
+ class XGenMMVisionEncoderConfig(PretrainedConfig):
1824
+ model_type = "xgenmm_vision_encoder"
1825
+
1826
+ def __init__(
1827
+ self,
1828
+ model_name: str = "google/siglip-so400m-patch14-384",
1829
+ anyres_grids: list[int] = [
1830
+ [384, 768],
1831
+ [768, 384],
1832
+ [768, 768],
1833
+ [1152, 384],
1834
+ [384, 1152],
1835
+ ],
1836
+ **kwargs,
1837
+ ):
1838
+ self.model_name = model_name
1839
+ self.anyres_grids = anyres_grids
1840
+ super().__init__(**kwargs)
1841
+
1842
+
1843
+ class XGenMMVisionTokenizerConfig(PretrainedConfig):
1844
+ model_type = "xgenmm_vision_tokenizer"
1845
+
1846
+ def __init__(
1847
+ self,
1848
+ vis_feature_dim: int = 1152,
1849
+ lang_embedding_dim: int = 3072,
1850
+ num_vis_tokens: int = 128,
1851
+ image_aspect_ratio: str = "anyres",
1852
+ **kwargs,
1853
+ ):
1854
+ self.vis_feature_dim = vis_feature_dim
1855
+ self.lang_embedding_dim = lang_embedding_dim
1856
+ self.num_vis_tokens = num_vis_tokens
1857
+ self.image_aspect_ratio = image_aspect_ratio
1858
+ super().__init__(**kwargs)
1859
+
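For orientation, the defaults of `XGenMMVisionTokenizerConfig` map 1152-d SigLIP features to 128 latent tokens in a 3072-d language embedding space; a hedged construction example (the values shown are simply the defaults):

vt_config = XGenMMVisionTokenizerConfig(
    vis_feature_dim=1152,      # SigLIP-so400m feature width
    lang_embedding_dim=3072,   # Phi-3-mini hidden size
    num_vis_tokens=128,        # latent tokens produced per image by the PerceiverResampler
    image_aspect_ratio="anyres",
)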
1860
+
1861
+ class XGenMMConfig(PretrainedConfig):
1862
+ model_type = "xgenmm"
1863
+
1864
+ def __init__(
1865
+ self,
1866
+ vision_encoder_config: dict = None,
1867
+ vision_tokenizer_config: dict = None,
1868
+ text_config: dict = None,
1869
+ **kwargs,
1870
+ ):
1871
+
1872
+ if vision_encoder_config is None:
1873
+ vision_encoder_config = {
1874
+ "image_aspect_ratio": "anyres",
1875
+ "anyres_patch_sampling": True,
1876
+ }
1877
+ logger.info(
1878
+ "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
1879
+ )
1880
+
1881
+ if vision_tokenizer_config is None:
1882
+ vision_tokenizer_config = {}
1883
+ logger.info(
1884
+ "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
1885
+ )
1886
+
1887
+ if text_config is None:
1888
+ text_config = {
1889
+ "initial_tokenizer_len": 32012,
1890
+ "pad_token_id": 32011,
1891
+ "bos_token_id": 1,
1892
+ "eos_token_id": 32000,
1893
+ "vocab_size": 32064,
1894
+ "hidden_size": 3072,
1895
+ "intermediate_size": 8192,
1896
+ "num_hidden_layers": 32,
1897
+ "num_attention_heads": 32,
1898
+ "num_key_value_heads": 32,
1899
+ "resid_pdrop": 0.0,
1900
+ "embd_pdrop": 0.0,
1901
+ "attention_dropout": 0.0,
1902
+ "hidden_act": "silu",
1903
+ "max_position_embeddings": 4096,
1904
+ "original_max_position_embeddings": 4096,
1905
+ "initializer_range": 0.02,
1906
+ "rms_norm_eps": 1e-05,
1907
+ "use_cache": True,
1908
+ "rope_theta": 10000.0,
1909
+ "rope_scaling": None,
1910
+ "sliding_window": 2047,
1911
+ "return_dict": True,
1912
+ "output_hidden_states": False,
1913
+ "output_attentions": False,
1914
+ "torchscript": False,
1915
+ "torch_dtype": "bfloat16",
1916
+ "use_bfloat16": False,
1917
+ "tf_legacy_loss": False,
1918
+ "pruned_heads": {},
1919
+ "tie_word_embeddings": False,
1920
+ "chunk_size_feed_forward": 0,
1921
+ "is_encoder_decoder": False,
1922
+ "is_decoder": False,
1923
+ "cross_attention_hidden_size": None,
1924
+ "add_cross_attention": False,
1925
+ "tie_encoder_decoder": False,
1926
+ "max_length": 20,
1927
+ "min_length": 0,
1928
+ "do_sample": False,
1929
+ "early_stopping": False,
1930
+ "num_beams": 1,
1931
+ "num_beam_groups": 1,
1932
+ "diversity_penalty": 0.0,
1933
+ "temperature": 1.0,
1934
+ "top_k": 50,
1935
+ "top_p": 1.0,
1936
+ "typical_p": 1.0,
1937
+ "repetition_penalty": 1.0,
1938
+ "length_penalty": 1.0,
1939
+ "no_repeat_ngram_size": 0,
1940
+ "encoder_no_repeat_ngram_size": 0,
1941
+ "bad_words_ids": None,
1942
+ "num_return_sequences": 1,
1943
+ "output_scores": False,
1944
+ "return_dict_in_generate": False,
1945
+ "forced_bos_token_id": None,
1946
+ "forced_eos_token_id": None,
1947
+ "remove_invalid_values": False,
1948
+ "exponential_decay_length_penalty": None,
1949
+ "suppress_tokens": None,
1950
+ "begin_suppress_tokens": None,
1951
+ "finetuning_task": None,
1952
+ "id2label": {0: "LABEL_0", 1: "LABEL_1"},
1953
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
1954
+ "tokenizer_class": None,
1955
+ "prefix": None,
1956
+ "bos_token_id": 1,
1957
+ "pad_token_id": 32000,
1958
+ "eos_token_id": 32000,
1959
+ "sep_token_id": None,
1960
+ "decoder_start_token_id": None,
1961
+ "task_specific_params": None,
1962
+ "problem_type": None,
1963
+ "model_type": "phi3",
1964
+ }
1965
+ logger.info(
1966
+ "text_config is None. Initializing the text config with default values (`Phi3Config`)."
1967
+ )
1968
+
1969
+ self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
1970
+
1971
+ self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(
1972
+ **vision_tokenizer_config
1973
+ )
1974
+
1975
+ text_model_type = (
1976
+ text_config["model_type"] if "model_type" in text_config else "phi3"
1977
+ )
1978
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
1979
+
1980
+ for key in ["initial_tokenizer_len", "pad_token_id"]:
1981
+ if key not in self.text_config.to_dict():
1982
+ raise ValueError(f"The key `{key}` is missing in the text_config.")
1983
+
1984
+ super().__init__(**kwargs)
1985
+
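A minimal sketch of building the top-level config with the built-in defaults (SigLIP vision encoder, Phi-3 text config):

config = XGenMMConfig()
print(config.vision_encoder_config.model_name)        # google/siglip-so400m-patch14-384
print(config.vision_tokenizer_config.num_vis_tokens)  # 128
print(config.text_config.model_type)                  # phi3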
1986
 
1987
  class XGenMMVisionEncoder(PreTrainedModel):
1988
  main_input_name = "pixel_values"
1989
  config_class = XGenMMVisionEncoderConfig
1990
+
1991
  def __init__(self, config: XGenMMVisionEncoderConfig):
1992
  super().__init__(config)
1993
+ if config.model_name != "google/siglip-so400m-patch14-384":
1994
+ raise ValueError(
1995
+ f"Unsupported model {config.model_name}. New vision models will be added soon."
1996
+ )
1997
  self.model = AutoModel.from_pretrained(config.model_name)
1998
+
1999
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
2000
  # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
2001
  return self.model.encode_image(pixel_values)
 
2002
 
2003
+
2004
+ # vision tokenizer
2005
  class XGenMMVisionTokenizer(PreTrainedModel):
2006
  config_class = XGenMMVisionTokenizerConfig
2007
+
2008
  def __init__(self, config: XGenMMVisionTokenizerConfig):
2009
  super().__init__(config)
2010
  self.model = PerceiverResampler(
 
2012
  dim_inner=config.lang_embedding_dim,
2013
  num_latents=config.num_vis_tokens,
2014
  )
2015
+
2016
+ def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
 
 
2017
  return self.model(vision_features, vision_attn_masks)
2018
+
2019
+
2020
  # XGenMM model
2021
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
2022
  config_class = XGenMMConfig
2023
+
2024
  def __init__(self, config: XGenMMConfig):
2025
  super().__init__(config)
2026
+
2027
  # vision encoder initialization
2028
+ vision_encoder = AutoModel.from_pretrained(
2029
+ config.vision_encoder_config.model_name
2030
+ ).vision_model
2031
+
2032
+ # language model initialization
2033
  language_model = AutoModelForCausalLM.from_config(config.text_config)
2034
  check_embedding_fns(language_model)
2035
  # Update _tied_weights_keys using the base model used.
2036
  if language_model._tied_weights_keys is not None:
2037
+ self._tied_weights_keys = [
2038
+ f"language_model.{k}" for k in language_model._tied_weights_keys
2039
+ ]
2040
+
2041
  # vision tokenizer initialization
2042
+ if (
2043
+ config.vision_tokenizer_config.lang_embedding_dim
2044
+ != language_model.get_input_embeddings().weight.shape[1]
2045
+ ):
2046
  overwrite = language_model.get_input_embeddings().weight.shape[1]
2047
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
2048
+ print(
2049
+ f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}."
2050
+ )
2051
+
2052
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
2053
 
2054
  self.vlm = XGenMMPerceiver(
2055
  vision_encoder=vision_encoder,
2056
  vision_tokenizer=vision_tokenizer,
2057
  lang_model=language_model,
2058
+ initial_tokenizer_len=config.text_config.initial_tokenizer_len,
2059
+ pad_token_id=config.text_config.pad_token_id,
2060
+ image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
2061
+ anyres_patch_sampling=config.vision_encoder_config.anyres_patch_sampling,
2062
+ anyres_grids=config.vision_encoder_config.anyres_grids,
2063
  )
2064
  # Initialize weights and apply final processing
2065
  self.post_init()
2066
+
2067
  @torch.no_grad()
2068
  def generate(
2069
  self,
 
2071
  input_ids: Optional[torch.LongTensor] = None,
2072
  attention_mask: Optional[torch.LongTensor] = None,
2073
  **generate_kwargs,
2074
+ ) -> torch.LongTensor:
2075
  self.vlm = self.vlm.eval()
2076
  return self.vlm.generate(
2077
+ vision_x=pixel_values,
2078
+ lang_x=input_ids,
2079
+ attention_mask=attention_mask,
2080
+ **generate_kwargs,
2081
+ )
2082
+
2083
  def update_special_tokens(self, tokenizer):
2084
  tokenizer.add_special_tokens(
2085
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
 
2087
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
2088
  self.vlm.set_special_token_ids(
2089
  {
2090
+ v: tokenizer.convert_tokens_to_ids(v)
2091
+ for v in self.vlm.special_tokens.values()
2092
  }
2093
  )
2094
  return tokenizer
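A hedged end-to-end sketch of the merged model. The class names come from this file; the tokenizer checkpoint, random pixel patches, and image size are placeholders, and in practice the preprocessing and prompt format from the model card should be used instead:

import torch
from transformers import AutoTokenizer

model = XGenMMModelForConditionalGeneration(XGenMMConfig())
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # assumed base tokenizer
tokenizer = model.update_special_tokens(tokenizer)

lang_inputs = tokenizer("<image> Describe this image.", return_tensors="pt")
pixel_values = [torch.randn(5, 3, 384, 384)]   # one sample, five anyres patches (random, illustrative)
output = model.generate(
    pixel_values=pixel_values,
    input_ids=lang_inputs.input_ids,
    attention_mask=lang_inputs.attention_mask,
    image_size=[(1024, 768)],                  # original (width, height) of the source image
    max_new_tokens=32,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))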
 
utils.py DELETED
@@ -1,383 +0,0 @@
1
- import torch
2
- import ast
3
- import math
4
- from PIL import Image
5
- from packaging.version import Version
6
-
7
- def has_fn(model, fn_name):
8
- """Check if model has a function fn_name"""
9
- return callable(getattr(model, fn_name, None))
10
-
11
- def exists(val):
12
- return val is not None
13
-
14
- def num_params(module, filter_to_trainable=False):
15
- """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
- if filter_to_trainable:
17
- return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
- else:
19
- return sum(p.numel() for p in module.parameters())
20
-
21
- def hasattr_recursive(obj, att):
22
- """
23
- Check if obj has nested attribute
24
- Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
- """
26
- if att == "":
27
- return True
28
- i = att.find(".")
29
- if i < 0:
30
- return hasattr(obj, att)
31
- else:
32
- try:
33
- return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
- except:
35
- return False
36
-
37
- def getattr_recursive(obj, att):
38
- """
39
- Return nested attribute of obj
40
- Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
- """
42
- if att == "":
43
- return obj
44
- i = att.find(".")
45
- if i < 0:
46
- return getattr(obj, att)
47
- else:
48
- return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
-
50
-
51
- def setattr_recursive(obj, att, val):
52
- """
53
- Set nested attribute of obj
54
- Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
- """
56
- if "." in att:
57
- obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
- setattr(obj, att.split(".")[-1], val)
59
-
60
-
61
- def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
- """
63
- Stack a list of tensors with padding on one side
64
- Args:
65
- list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
- padding_value (int, optional): Value to pad with. Defaults to 0.
67
- padding_side (str, optional): Side to pad on. Defaults to "right".
68
- Returns:
69
- torch.Tensor: Stacked tensors
70
- """
71
- max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
- padded_tensors = []
73
- for tensor in list_of_tensors:
74
- num_tokens = tensor.size(0)
75
- if len(tensor.size()) == 1:
76
- padding = torch.full(
77
- (max_tokens - num_tokens,),
78
- padding_value,
79
- dtype=tensor.dtype,
80
- device=tensor.device,
81
- )
82
- else:
83
- padding = torch.full(
84
- (max_tokens - num_tokens, tensor.size(1)),
85
- padding_value,
86
- dtype=tensor.dtype,
87
- device=tensor.device,
88
- )
89
- padded_tensor = (
90
- torch.cat((tensor, padding), dim=0)
91
- if padding_side == "right"
92
- else torch.cat((padding, tensor), dim=0)
93
- )
94
- padded_tensors.append(padded_tensor)
95
- return torch.stack(padded_tensors)
96
-
97
-
98
- def check_embedding_fns(lang_model):
99
- """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
- if not has_fn(lang_model, "get_input_embeddings"):
101
- if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
- lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
- elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
- lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens
105
- else:
106
- raise ValueError(
107
- "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
- )
109
-
110
- if not has_fn(lang_model, "set_input_embeddings"):
111
- if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
- lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
- lang_model, "transformer.wte", x
114
- )
115
- elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
- lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
- lang_model, "model.decoder.embed_tokens", x
118
- )
119
- else:
120
- raise ValueError(
121
- "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
- )
123
-
124
- if not has_fn(lang_model, "get_output_embeddings"):
125
- if hasattr_recursive(lang_model, "lm_head"):
126
- lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
- else:
128
- raise ValueError(
129
- "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
- )
131
-
132
- if not has_fn(lang_model, "set_output_embeddings"):
133
- if hasattr_recursive(lang_model, "lm_head"):
134
- lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
- lang_model, "lm_head", x
136
- )
137
- else:
138
- raise ValueError(
139
- "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
- )
141
-
142
-
143
- def has_fn(model, fn_name):
144
- """Check if model has a function fn_name"""
145
- return callable(getattr(model, fn_name, None))
146
-
147
-
148
- # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
- #
150
- # Licensed under the Apache License, Version 2.0 (the "License");
151
- # you may not use this file except in compliance with the License.
152
- # You may obtain a copy of the License at
153
- #
154
- # http://www.apache.org/licenses/LICENSE-2.0
155
- #
156
- # Unless required by applicable law or agreed to in writing, software
157
- # distributed under the License is distributed on an "AS IS" BASIS,
158
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
- # See the License for the specific language governing permissions and
160
- # limitations under the License.
161
-
162
- def unpad_image(tensor, original_size, keep_original_shape=False):
163
- """
164
- Unpads a PyTorch tensor of a padded and resized image.
165
-
166
- Args:
167
- tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
- original_size (tuple): The original size of the image (height, width).
169
-
170
- Returns:
171
- torch.Tensor: The unpadded image tensor.
172
- """
173
- original_width, original_height = original_size
174
- current_height, current_width = tensor.shape[1:]
175
-
176
- original_aspect_ratio = original_width / original_height
177
- current_aspect_ratio = current_width / current_height
178
-
179
- if original_aspect_ratio > current_aspect_ratio:
180
- scale_factor = current_width / original_width
181
- new_height = int(original_height * scale_factor)
182
- padding = (current_height - new_height) // 2
183
- if keep_original_shape:
184
- attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
- attention_mask[:padding, :] = 0
186
- attention_mask[current_height - padding:, :] = 0
187
- return tensor, attention_mask
188
- else:
189
- unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
- return unpadded_tensor, None
191
- else:
192
- scale_factor = current_height / original_height
193
- new_width = int(original_width * scale_factor)
194
- padding = (current_width - new_width) // 2
195
- if keep_original_shape:
196
- attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
- attention_mask[:, :padding] = 0
198
- attention_mask[:, current_width - padding:] = 0
199
- return tensor, attention_mask
200
- else:
201
- unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
- return unpadded_tensor, None
203
-
204
-
205
- def select_best_resolution(original_size, possible_resolutions):
206
- """
207
- Selects the best resolution from a list of possible resolutions based on the original size.
208
-
209
- Args:
210
- original_size (tuple): The original size of the image in the format (width, height).
211
- possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
-
213
- Returns:
214
- tuple: The best fit resolution in the format (width, height).
215
- """
216
- original_width, original_height = original_size
217
- best_fit = None
218
- max_effective_resolution = 0
219
- min_wasted_resolution = float('inf')
220
-
221
- for width, height in possible_resolutions:
222
- scale = min(width / original_width, height / original_height)
223
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
- wasted_resolution = (width * height) - effective_resolution
226
-
227
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
- max_effective_resolution = effective_resolution
229
- min_wasted_resolution = wasted_resolution
230
- best_fit = (width, height)
231
-
232
- return best_fit
233
-
234
-
235
- def resize_and_pad_image(image, target_resolution):
236
- """
237
- Resize and pad an image to a target resolution while maintaining aspect ratio.
238
-
239
- Args:
240
- image (PIL.Image.Image): The input image.
241
- target_resolution (tuple): The target resolution (width, height) of the image.
242
-
243
- Returns:
244
- PIL.Image.Image: The resized and padded image.
245
- """
246
- original_width, original_height = image.size
247
- target_width, target_height = target_resolution
248
-
249
- scale_w = target_width / original_width
250
- scale_h = target_height / original_height
251
-
252
- if scale_w < scale_h:
253
- new_width = target_width
254
- new_height = min(math.ceil(original_height * scale_w), target_height)
255
- else:
256
- new_height = target_height
257
- new_width = min(math.ceil(original_width * scale_h), target_width)
258
-
259
- # Resize the image
260
- resized_image = image.resize((new_width, new_height))
261
-
262
- new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
- paste_x = (target_width - new_width) // 2
264
- paste_y = (target_height - new_height) // 2
265
- new_image.paste(resized_image, (paste_x, paste_y))
266
-
267
- return new_image
268
-
269
-
270
- def divide_to_patches(image, patch_size):
271
- """
272
- Divides an image into patches of a specified size.
273
-
274
- Args:
275
- image (PIL.Image.Image): The input image.
276
- patch_size (int): The size of each patch.
277
-
278
- Returns:
279
- list: A list of PIL.Image.Image objects representing the patches.
280
- """
281
- patches = []
282
- width, height = image.size
283
- for i in range(0, height, patch_size):
284
- for j in range(0, width, patch_size):
285
- box = (j, i, j + patch_size, i + patch_size)
286
- patch = image.crop(box)
287
- patches.append(patch)
288
-
289
- return patches
290
-
291
-
292
- def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
- """
294
- Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
-
296
- Args:
297
- image_size (tuple): The size of the input image in the format (width, height).
298
- grid_pinpoints (str): A string representation of a list of possible resolutions.
299
- patch_size (int): The size of each image patch.
300
-
301
- Returns:
302
- tuple: The shape of the image patch grid in the format (width, height).
303
- """
304
- if type(grid_pinpoints) is list:
305
- possible_resolutions = grid_pinpoints
306
- else:
307
- possible_resolutions = ast.literal_eval(grid_pinpoints)
308
- width, height = select_best_resolution(image_size, possible_resolutions)
309
- return width // patch_size, height // patch_size
310
-
311
-
312
- def process_anyres_image(image, processor, grid_pinpoints):
313
- """
314
- Process an image with variable resolutions.
315
-
316
- Args:
317
- image (PIL.Image.Image): The input image to be processed.
318
- processor: The image processor object.
319
- grid_pinpoints (str): A string representation of a list of possible resolutions.
320
-
321
- Returns:
322
- torch.Tensor: A tensor containing the processed image patches.
323
- """
324
- # FIXME: determine grid_pinpoints from image sizes.
325
- if type(grid_pinpoints) is list:
326
- possible_resolutions = grid_pinpoints
327
- else:
328
- possible_resolutions = ast.literal_eval(grid_pinpoints)
329
- best_resolution = select_best_resolution(image.size, possible_resolutions)
330
- image_padded = resize_and_pad_image(image, best_resolution)
331
-
332
- processor_size = processor.transforms[0].size
333
- patches = divide_to_patches(image_padded, processor_size[0])
334
-
335
- image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
-
337
- image_patches = [image_original_resize] + patches
338
- image_patches = [processor(image_patch)
339
- for image_patch in image_patches]
340
- return torch.stack(image_patches, dim=0)
341
-
342
-
343
- def expand2square(pil_img, background_color):
344
- width, height = pil_img.size
345
- if width == height:
346
- return pil_img
347
- elif width > height:
348
- result = Image.new(pil_img.mode, (width, width), background_color)
349
- result.paste(pil_img, (0, (width - height) // 2))
350
- return result
351
- else:
352
- result = Image.new(pil_img.mode, (height, height), background_color)
353
- result.paste(pil_img, ((height - width) // 2, 0))
354
- return result
355
-
356
-
357
- def process_images(images, image_processor, model_cfg):
358
- image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
- new_images = []
360
- if image_aspect_ratio == 'pad':
361
- for image in images:
362
- image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
- image = image_processor(image)
364
- new_images.append(image)
365
- elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
- base_img_size = image_processor.transforms[0].size[0]
367
- for image in images:
368
- image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
- [base_img_size*2,base_img_size],
370
- [base_img_size*2,base_img_size*2],
371
- [base_img_size*3,base_img_size],
372
- [base_img_size,base_img_size*3]])
373
-
374
- # Debug any res inference by only using 672x672.
375
- # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
- new_images.append(image)
377
- else:
378
- return image_processor(images)
379
- if all(x.shape == new_images[0].shape for x in new_images):
380
- new_images = torch.stack(new_images, dim=0)
381
- return new_images
382
-
383
-
 
vlm.py DELETED
@@ -1,1381 +0,0 @@
1
-
2
- import torch
3
- from torch import einsum, nn
4
- from einops import rearrange, repeat
5
- from einops_exts import rearrange_many
6
- from einops import rearrange
7
- from typing import List, Optional, Tuple, Union
8
- import torch.nn.functional as F
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from dataclasses import dataclass
11
- from transformers import CLIPVisionModel
12
- from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
13
-
14
- import transformers
15
- from packaging.version import Version
16
-
17
- from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
18
-
19
-
20
- class VisionTokenizer(nn.Module):
21
- def __init__(self, dim_media, num_tokens_per_media):
22
- super().__init__()
23
- self.dim_media = dim_media
24
- self.num_tokens_per_media = num_tokens_per_media
25
-
26
- class PerceiverAttention(nn.Module):
27
- def __init__(self, *, dim, dim_head=64, heads=8):
28
- super().__init__()
29
- self.scale = dim_head**-0.5
30
- self.heads = heads
31
- inner_dim = dim_head * heads
32
-
33
- self.norm_media = nn.LayerNorm(dim)
34
- self.norm_latents = nn.LayerNorm(dim)
35
-
36
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
37
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
38
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
39
-
40
- def forward(self, x, latents, vision_attn_masks=None):
41
- """
42
- Args:
43
- x (torch.Tensor): image features
44
- shape (b, T, n1, D)
45
- latent (torch.Tensor): latent features
46
- shape (b, T, n2, D)
47
- """
48
- x = self.norm_media(x)
49
- latents = self.norm_latents(latents)
50
-
51
- h = self.heads
52
-
53
- q = self.to_q(latents)
54
- kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
55
- if vision_attn_masks is not None:
56
- vision_attn_masks = torch.cat((vision_attn_masks,
57
- torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
58
- dim=-1)
59
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
- q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
- q = q * self.scale
62
-
63
- # attention
64
- sim = einsum("... i d, ... j d -> ... i j", q, k)
65
- # Apply vision attention mask here.
66
- # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
67
- if vision_attn_masks is not None:
68
- attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
69
- vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
70
- attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
71
- sim += attn_bias
72
-
73
- sim = sim - sim.amax(dim=-1, keepdim=True).detach()
74
- attn = sim.softmax(dim=-1)
75
-
76
-
77
- out = einsum("... i j, ... j d -> ... i d", attn, v)
78
- out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
79
- return self.to_out(out)
80
-
81
-
82
- def FeedForward(dim, mult=4):
83
- inner_dim = int(dim * mult)
84
- return nn.Sequential(
85
- nn.LayerNorm(dim),
86
- nn.Linear(dim, inner_dim, bias=False),
87
- nn.GELU(),
88
- nn.Linear(inner_dim, dim, bias=False),
89
- )
90
-
91
-
92
- class PerceiverResampler(VisionTokenizer):
93
- def __init__(
94
- self,
95
- *,
96
- dim,
97
- dim_inner=None,
98
- depth=6,
99
- dim_head=96,
100
- heads=16,
101
- num_latents=128,
102
- max_num_media=None,
103
- max_num_frames=None,
104
- ff_mult=4,
105
- ):
106
- """
107
- Perceiver module which takes in image features and outputs image tokens.
108
- Args:
109
- dim (int): dimension of the incoming image features
110
- dim_inner (int, optional): final dimension to project the incoming image features to;
111
- also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
112
- depth (int, optional): number of layers. Defaults to 6.
113
- dim_head (int, optional): dimension of each head. Defaults to 64.
114
- heads (int, optional): number of heads. Defaults to 8.
115
- num_latents (int, optional): number of latent tokens to use in the Perceiver;
116
- also corresponds to number of tokens per sequence to output. Defaults to 64.
117
- max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
118
- and keep positional embeddings for. If None, no positional embeddings are used.
119
- max_num_frames (int, optional): maximum number of frames to input into the Perceiver
120
- and keep positional embeddings for. If None, no positional embeddings are used.
121
- ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
122
- """
123
- if dim_inner is not None:
124
- projection = nn.Linear(dim, dim_inner)
125
- else:
126
- projection = None
127
- dim_inner = dim
128
- super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
129
- self.projection = projection
130
- self.latents = nn.Parameter(torch.randn(num_latents, dim))
131
-
132
- # positional embeddings
133
- self.frame_embs = (
134
- nn.Parameter(torch.randn(max_num_frames, dim))
135
- if exists(max_num_frames)
136
- else None
137
- )
138
- self.media_time_embs = (
139
- nn.Parameter(torch.randn(max_num_media, 1, dim))
140
- if exists(max_num_media)
141
- else None
142
- )
143
-
144
- self.layers = nn.ModuleList([])
145
- for _ in range(depth):
146
- self.layers.append(
147
- nn.ModuleList(
148
- [
149
- PerceiverAttention(
150
- dim=dim, dim_head=dim_head, heads=heads
151
- ),
152
- FeedForward(dim=dim, mult=ff_mult),
153
- ]
154
- )
155
- )
156
-
157
- self.norm = nn.LayerNorm(dim)
158
-
159
- def forward(self, x, vision_attn_masks):
160
- """
161
- Args:
162
- x (torch.Tensor): image features
163
- shape (b, T, F, v, D)
164
- vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x)
165
- shape (b, v)
166
- Returns:
167
- shape (b, T, n, D) where n is self.num_latents
168
- """
169
- b, T, F, v = x.shape[:4]
170
-
171
- # frame and media time embeddings
172
- if exists(self.frame_embs):
173
- frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
174
- x = x + frame_embs
175
- x = rearrange(
176
- x, "b T F v d -> b T (F v) d"
177
- ) # flatten the frame and spatial dimensions
178
- if exists(self.media_time_embs):
179
- x = x + self.media_time_embs[:T]
180
-
181
- # blocks
182
- latents = self.latents
183
- latents = repeat(latents, "n d -> b T n d", b=b, T=T)
184
- for attn, ff in self.layers:
185
- latents = attn(x, latents, vision_attn_masks) + latents
186
- latents = ff(latents) + latents
187
-
188
- if exists(self.projection):
189
- return self.projection(self.norm(latents))
190
- else:
191
- return self.norm(latents)
192
-
193
-
194
- class DecoupledEmbedding(nn.Embedding):
195
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
196
- """
197
- Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
198
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
199
- then it will create `num_additional_embeddings` additional parameters that are always trained. If
200
- `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
201
- """
202
-
203
- def __init__(
204
- self,
205
- max_original_id: int,
206
- num_additional_embeddings: int = 0,
207
- _weight: torch.Tensor = None,
208
- num_original_embeddings: int = None,
209
- embedding_dim: int = None,
210
- partially_freeze=True,
211
- device=None,
212
- dtype=None,
213
- pad_token_id=None,
214
- ) -> None:
215
- """
216
- Args:
217
- max_original_id (`int`):
218
- The largest token id that should be embedded using the regular embedding (regular `weight`).
219
- This is usually len(tokenizer) - 1 before additional tokens are added.
220
- Note that this may not equal self.weight.shape[0]
221
- num_additional_embeddings (`int`):
222
- Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
223
- _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
224
- If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
225
- num_original_embeddings (`int`):
226
- self.weight.shape[0]
227
- embedding_dim (`int`):
228
- The size of each embedding vector
229
- partially_freeze: (`bool`, *optional*, defaults to `True`):
230
- If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
231
- padding_idx (`int`, *optional*):
232
- The padding index (needs to be less than num_embeddings)
233
-
234
- Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
235
- `max_norm` or `norm_type`. We are not supporting these.
236
- """
237
- # validate args
238
- if pad_token_id is not None and pad_token_id > max_original_id:
239
- raise ValueError(
240
- f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
241
- + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
242
- )
243
- if _weight is not None:
244
- assert (num_original_embeddings is None) or (
245
- _weight.shape[0] == num_original_embeddings
246
- ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
247
- assert (embedding_dim is None) or (
248
- _weight.shape[1] == embedding_dim
249
- ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
250
- num_original_embeddings = _weight.shape[0]
251
- embedding_dim = _weight.shape[1]
252
- else:
253
- assert (
254
- num_original_embeddings is not None
255
- ), "num_original_embeddings must be provided if _weight is not provided"
256
- assert (
257
- embedding_dim is not None
258
- ), "embedding_dim must be provided if _weight is not provided"
259
-
260
- super().__init__(
261
- num_embeddings=num_original_embeddings,
262
- embedding_dim=embedding_dim,
263
- device=device,
264
- dtype=dtype,
265
- padding_idx=pad_token_id,
266
- _weight=_weight,
267
- )
268
- self.max_original_id = max_original_id
269
- self.padding_idx = pad_token_id
270
- self.num_additional_embeddings = num_additional_embeddings
271
- if self.num_additional_embeddings > 0:
272
- self.additional_embedding = nn.Embedding(
273
- num_embeddings=self.num_additional_embeddings,
274
- embedding_dim=embedding_dim,
275
- device=device,
276
- dtype=dtype,
277
- )
278
- self.set_requires_grad(
279
- require_regular_grad=not partially_freeze, require_additional_grad=True
280
- )
281
-
282
- def set_requires_grad(self, require_regular_grad, require_additional_grad):
283
- """
284
- Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
285
- """
286
- self.weight.requires_grad_(require_regular_grad)
287
- self.additional_embedding.requires_grad_(require_additional_grad)
288
-
289
- def forward(self, input_ids):
290
- """
291
- we have 2 embeddings, with different indices - one pretrained self.weight and another
292
- self.additional_embedding.weight that is being trained.
293
-
294
- in order to make a lookup of the input ids, we:
295
- 1. find out the indices of the entries belonging to the 2nd embedding
296
- 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
297
- embedding starts from 0 and not num_embeddings
298
- 3. perform the 2nd embedding lookup
299
- 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
300
- 5. perform the 1st embedding lookup
301
- 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
302
-
303
- note: for the 1st embedding lookup we could have looked up only the low indices and skipped the padding, but
304
- then we would have to create a new tensor and populate it from 2 tensors spread out across various indices -
305
- i.e. not a simple concat. The complex case has not been benchmarked; given that seqlens are
306
- usually relatively short, it is probably not faster, or not by much - but it might be worth
307
- measuring.
308
-
309
- """
310
- if self.num_additional_embeddings == 0:
311
- return F.embedding(input_ids, self.weight)
312
-
313
- # Clone so that we don't modify the original input_ids later on
314
- input_ids = input_ids.clone()
315
- additional_vocab_indices = torch.where(input_ids > self.max_original_id)
316
- input_ids_additional_vocab = input_ids[additional_vocab_indices]
317
- additional_embeddings = self.additional_embedding(
318
- input_ids_additional_vocab - self.max_original_id - 1
319
- )
320
-
321
- # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
322
- input_ids[additional_vocab_indices] = 0
323
- full_vector = F.embedding(input_ids, self.weight)
324
-
325
- # overwrite the records with high indices
326
- full_vector[additional_vocab_indices] = additional_embeddings
327
-
328
- return full_vector
329
-
330
- def extra_repr(self) -> str:
331
- return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
332
- self.max_original_id + 1,
333
- self.num_additional_embeddings,
334
- self.embedding_dim,
335
- (not self.weight.requires_grad),
336
- )
337
-
338
-
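For reference, a minimal usage sketch of the decoupled embedding above, with toy sizes (the real model ties the regular weight to the pretrained input embeddings and uses the language model's vocab and hidden sizes); it assumes DecoupledEmbedding, as defined above, is importable:

import torch

# toy sizes for illustration only
emb = DecoupledEmbedding(
    max_original_id=31,            # ids 0..31 use the (frozen) regular weight
    num_additional_embeddings=2,   # ids 32 and 33 use the small trainable embedding
    num_original_embeddings=32,
    embedding_dim=8,
)
input_ids = torch.tensor([[1, 5, 32, 33]])
out = emb(input_ids)               # rows for ids 32/33 come from emb.additional_embedding
print(out.shape)                                                                # torch.Size([1, 4, 8])
print(emb.weight.requires_grad, emb.additional_embedding.weight.requires_grad)  # False True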
339
- class DecoupledLinear(nn.Linear):
340
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
341
- """
342
- Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
343
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
344
- then it will create `additional_out_features * in_features` additional parameters that are always trained. If
345
- `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
346
- """
347
-
348
- def __init__(
349
- self,
350
- max_original_id: int,
351
- additional_out_features: int = 0,
352
- _weight: torch.Tensor = None,
353
- _bias: torch.Tensor = None,
354
- in_features: int = None,
355
- original_out_features: int = None,
356
- bias: bool = True,
357
- partially_freeze: bool = True,
358
- device=None,
359
- dtype=None,
360
- ) -> None:
361
- """
362
- Args:
363
- max_original_id (`int`): The largest token id that should be extracted from the regular weight.
364
- This is usually len(tokenizer) - 1 before additional tokens are added.
365
- Note that this may not equal original_out_features - 1
366
- _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
367
- If provided, this sets the `in_features` and `original_out_features` parameters.
368
- _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
369
- in_features: int. Input hidden size.
370
- original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
371
- additional_out_features: int. Number of additional trainable dimensions.
372
- bias: bool. Whether to include a bias term.
373
- partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` will be frozen.
374
- """
375
- # argument validation
376
- if _weight is not None:
377
- assert (_weight.shape[0] == original_out_features) or (
378
- original_out_features is None
379
- ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
380
- assert (_weight.shape[1] == in_features) or (
381
- in_features is None
382
- ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
383
- in_features = _weight.shape[1]
384
- original_out_features = _weight.shape[0]
385
- else:
386
- assert (
387
- in_features is not None
388
- ), "in_features must be provided if _weight is not provided"
389
- assert (
390
- original_out_features is not None
391
- ), "original_out_features must be provided if _weight is not provided"
392
-
393
- if _bias is not None:
394
- assert bias is True, "bias must be True if _bias is provided"
395
-
396
- # initialize original linear
397
- super().__init__(
398
- in_features,
399
- original_out_features,
400
- bias,
401
- device,
402
- dtype)
403
-
404
- # set weight and bias manually
405
- if _weight is not None:
406
- self.weight = nn.Parameter(_weight)
407
- if _bias is not None:
408
- self.bias = nn.Parameter(_bias)
409
-
410
- self.in_features = in_features
411
- self.original_out_features = original_out_features
412
- self.max_original_id = max_original_id
413
-
414
- # initialize additional linear
415
- self.additional_out_features = additional_out_features
416
- self.has_bias = bias
417
- if additional_out_features > 0:
418
- self.additional_fc = nn.Linear(
419
- in_features=in_features,
420
- out_features=additional_out_features,
421
- bias=self.has_bias,
422
- device=device,
423
- dtype=dtype,
424
- )
425
- self.set_requires_grad(
426
- require_regular_grad=not partially_freeze, require_additional_grad=True
427
- )
428
-
429
- def set_requires_grad(self, require_regular_grad, require_additional_grad):
430
- """
431
- Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
432
- """
433
- self.weight.requires_grad_(require_regular_grad)
434
- if self.has_bias:
435
- self.bias.requires_grad_(require_regular_grad)
436
- self.additional_fc.requires_grad_(require_additional_grad)
437
-
438
- def forward(self, input: torch.Tensor) -> torch.Tensor:
439
- output = F.linear(input, self.weight, self.bias)
440
- output = output[..., : self.max_original_id + 1]
441
-
442
- if self.additional_out_features > 0:
443
- additional_features = F.linear(
444
- input, self.additional_fc.weight, self.additional_fc.bias
445
- )
446
- output = torch.cat((output, additional_features), -1)
447
- return output
448
-
449
- def extra_repr(self) -> str:
450
- """Overwriting `nn.Linear.extra_repr` to include new parameters."""
451
- return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
452
- self.in_features,
453
- self.max_original_id + 1,
454
- self.additional_out_features,
455
- self.bias is not None,
456
- (not self.weight.requires_grad or not self.bias.requires_grad),
457
- )
458
-
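Analogously, a minimal sketch of the decoupled output head with toy sizes, assuming DecoupledLinear (defined above) is importable; the regular logits are truncated to max_original_id + 1 columns and the trainable columns are appended:

import torch

lm_head = DecoupledLinear(
    max_original_id=31,          # keep logits for the original 32 ids
    additional_out_features=2,   # append 2 trainable logits for new special tokens
    in_features=8,
    original_out_features=32,
)
hidden = torch.randn(1, 4, 8)
print(lm_head(hidden).shape)     # torch.Size([1, 4, 34]) = 32 original + 2 additional
print(lm_head.weight.requires_grad, lm_head.additional_fc.weight.requires_grad)  # False True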
459
- class VLM(nn.Module):
460
- """
461
- Generic vision-language model (VLM) class.
462
- A VLM consists of four components:
463
- 1. A vision encoder that extracts features from pixels, e.g. CLIP
464
- input: (B, T_img, F, C, H, W)
465
- output: (B, T_img, F, v, d)
466
- 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
467
- input: (B, T_img, F, v, d)
468
- output: (B, T_img, n, d)
469
- 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
470
- 4. A language model
471
- """
472
-
473
- def __init__(
474
- self,
475
- vision_encoder: nn.Module,
476
- vision_tokenizer: nn.Module,
477
- lang_model: nn.Module,
478
- initial_tokenizer_len: int,
479
- pad_token_id: int,
480
- gradient_checkpointing: bool = False,
481
- ):
482
- """
483
- Args:
484
- vision_encoder (nn.Module): e.g. CLIP
485
- vision_tokenizer (nn.Module): e.g. PerceiverResampler
486
- lang_model (nn.Module): e.g. MPT
487
- initial_tokenizer_len (int): size of the original tokenizer vocab
488
- pad_token_id (int): id of the pad token
489
- gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
490
- """
491
- super().__init__()
492
-
493
- # save dimension information
494
- self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
495
- if hasattr(lang_model.config, "d_model"):
496
- self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
497
- else:
498
- self.lang_hidden_dim = lang_model.config.hidden_size
499
- self.vis_embedding_dim = vision_tokenizer.dim_media
500
- self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
501
-
502
- # core components
503
- self.vision_encoder = vision_encoder
504
- self.vision_tokenizer = vision_tokenizer
505
- self.lang_model = lang_model
506
-
507
- # lm embeddings
508
- self.pad_token_id = pad_token_id
509
- self.initial_tokenizer_len = initial_tokenizer_len
510
- input_embeds = DecoupledEmbedding(
511
- max_original_id=initial_tokenizer_len - 1,
512
- num_additional_embeddings=len(self.special_tokens),
513
- _weight=self.lang_model.get_input_embeddings().weight,
514
- pad_token_id=self.pad_token_id,
515
- )
516
- if hasattr(input_embeds, "additional_embedding"):
517
- input_embeds.additional_embedding.weight.data.normal_(
518
- mean=0.0,
519
- std=self.lang_model.config.initializer_range
520
- if hasattr(self.lang_model.config, "initializer_range")
521
- else 0.02,
522
- )
523
- self.lang_model.set_input_embeddings(input_embeds)
524
-
525
- out_embeds = DecoupledLinear(
526
- max_original_id=initial_tokenizer_len - 1,
527
- additional_out_features=len(self.special_tokens),
528
- _weight=self.lang_model.get_output_embeddings().weight,
529
- _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
530
- )
531
- if hasattr(out_embeds, "additional_fc"):
532
- out_embeds.additional_fc.weight.data.normal_(
533
- mean=0.0,
534
- std=self.lang_model.config.initializer_range
535
- if hasattr(self.lang_model.config, "initializer_range")
536
- else 0.02,
537
- )
538
- self.lang_model.set_output_embeddings(out_embeds)
539
-
540
- # gradient checkpointing
541
- self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
542
-
543
- def forward(
544
- self,
545
- vision_x: Optional[torch.Tensor],
546
- lang_x: torch.Tensor,
547
- attention_mask: Optional[torch.Tensor] = None,
548
- labels: Optional[torch.Tensor] = None,
549
- past_key_values: Optional[
550
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
551
- ] = None,
552
- past_media_locations: Optional[torch.Tensor] = None,
553
- past_vision_tokens: Optional[torch.Tensor] = None,
554
- use_cache: Optional[bool] = False,
555
- **kwargs,
556
- ):
557
- """
558
- Args:
559
- vision_x: Vision input
560
- shape (B, T_img, F, C, H, W) with F=1
561
- only F = 1 is supported (single-frame videos)
562
- if T_img > the number of media tokens in the corresponding input_ids (lang_x),
563
- only the first number of media tokens in lang_x are used
564
- lang_x: Language input ids, with media tokens denoting where
565
- visual media should be inserted.
566
- shape (B, T_txt)
567
- attention_mask: Attention mask. Defaults to None.
568
- labels: Labels. Defaults to None.
569
- shape (B, T_txt)
570
- past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
571
- list of length = number of decoder layers in the LM
572
- exact implementation depends on LM, see Hugging Face docs
573
- past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
574
- shape (B, T_txt)
575
- past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
576
- use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
577
- If True, includes key_values, media_locations, and vision_tokens in the output.
578
- """
579
- assert not (past_vision_tokens is None) ^ (
580
- past_media_locations is None
581
- ), "past_vision_tokens and past_media_locations must both be None or both be not None"
582
-
583
- # convert pixels to vision tokens
584
- if vision_x is not None:
585
- vision_features = self._encode_vision_x(vision_x=vision_x)
586
- vision_tokens = self.vision_tokenizer(vision_features)
587
- else:
588
- vision_tokens = None
589
-
590
- # fuse the vision and language tokens
591
- new_inputs = self._prepare_inputs_for_forward(
592
- vision_tokens=vision_tokens,
593
- lang_x=lang_x,
594
- attention_mask=attention_mask,
595
- labels=labels,
596
- past_key_values=past_key_values,
597
- past_media_locations=past_media_locations,
598
- padding_side="right",
599
- past_vision_tokens=past_vision_tokens,
600
- )
601
- output = self.lang_model(
602
- **new_inputs,
603
- use_cache=use_cache,
604
- past_key_values=past_key_values,
605
- **kwargs,
606
- )
607
-
608
- # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
609
- # or to add the past_vision_tokens and past_media_locations to the output
610
- output = self._postprocess_outputs_from_forward(
611
- output=output,
612
- lang_x=lang_x,
613
- vision_tokens=vision_tokens,
614
- use_cache=use_cache,
615
- past_vision_tokens=past_vision_tokens,
616
- past_media_locations=past_media_locations,
617
- )
618
-
619
- # postforward hooks
620
- self._post_forward_hook()
621
- return output
622
-
623
- def _encode_vision_x_anyres(self, samples, device):
624
- assert self.anyres_grids is not None
625
- image_raw = samples["image"] # list of patch lists, each of shape [1, N_patch, C, H, W]
626
- image_sizes = samples["image_size"]
627
-
628
- # Image_raw can be a list of list of patches, when a `samples` has multiple images.
629
- if isinstance(image_raw[0], list):
630
- images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
631
- image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
632
- else:
633
- # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
634
- # concatenate the per-image patches into one batch for any-res encoding.
635
- images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
636
- image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
637
- image = image.to(device)
638
-
639
- with torch.no_grad():
640
- if self.vision_encoder.__class__.__name__ == "TimmModel":
641
- image_embeds = self.vision_encoder.trunk.forward_features(image)
642
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
643
- image_embeds = self.vision_encoder(image).last_hidden_state
644
- else:
645
- image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
646
-
647
- if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
648
- base_img_size = self.vision_encoder.config.image_size
649
- else:
650
- base_img_size = self.vision_encoder.image_size[0]
651
-
652
- if self.vision_encoder.__class__.__name__ == "TimmModel":
653
- grid_size = self.vision_encoder.trunk.patch_embed.grid_size
654
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
655
- grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
656
- grid_size = (grid_size_base, grid_size_base)
657
- else:
658
- grid_size = self.vision_encoder.grid_size
659
- height, width = grid_size
660
-
661
- if not image_embeds.shape[1] == height * width:
662
- assert image_embeds.shape[1] == height * width + 1 # For vision encoders that have a [CLS] token.
663
- image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
664
- n_vis_token_per_patch = image_embeds.shape[1]
665
-
666
- # Split encoded patches and merge patch features
667
- # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
668
- split_sizes = [image.shape[0] for image in images]
669
- image_embeds = torch.split(image_embeds, split_sizes, dim=0)
670
- # 2. For each image (consisting of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
671
- new_image_embeds = []
672
- patch_attn_masks = []
673
- max_n_img_token = -1
674
- for idx, patch_embeds in enumerate(image_embeds):
675
- if patch_embeds.shape[0] > 1:
676
- # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
677
- base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
678
- patch_embeds = patch_embeds[1:]
679
-
680
- assert height * width == base_patch_embeds.shape[0]
681
-
682
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
683
- self.anyres_grids,
684
- base_img_size) # Hardcoded grid_pinpoints.
685
- patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
686
-
687
- patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
688
- patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
689
- patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
690
- if hasattr(self, 'image_newline'):
691
- patch_embeds = torch.cat((
692
- patch_embeds,
693
- self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
694
- ), dim=-1)
695
- if self.anyres_patch_sampling:
696
- patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
697
- patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
698
- assert patch_attn_mask is not None
699
- patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
700
- patch_attn_mask = patch_attn_mask.flatten(0, 1)
701
- patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
702
- patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
703
- else:
704
- patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
705
- patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
706
- else:
707
- patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
708
- patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
709
- if hasattr(self, 'image_newline'):
710
- patch_embeds = torch.cat((
711
- patch_embeds,
712
- self.image_newline[None]
713
- ), dim=0)
714
- if not self.anyres_patch_sampling:
715
- max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
716
-
717
- new_image_embeds.append(patch_embeds)
718
- patch_attn_masks.append(patch_attn_mask)
719
-
720
- if self.anyres_patch_sampling:
721
- # Return individual patches for independent token downsampling.
722
- return new_image_embeds, patch_attn_masks
723
-
724
- # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
725
- image_embeds = []
726
- image_atts = []
727
- for image_embed in new_image_embeds:
728
- n_img_token = image_embed.shape[0]
729
- img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
730
- if n_img_token < max_n_img_token:
731
- padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
732
- padded_embed[:n_img_token, :] = image_embed
733
- img_attn[n_img_token:] = 0 # Mask out the padded entries.
734
- else:
735
- padded_embed = image_embed
736
- image_embeds.append(padded_embed)
737
- image_atts.append(img_attn)
738
- image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
739
- image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
740
- # TODO: reshape image_embeds and image_atts to "b T F v d"
741
- image_embeds = image_embeds[:, None, None, :, :]
742
- # image_atts = image_atts[:, None, None, :, :]
743
-
744
- return image_embeds, image_atts
745
-
746
- def _encode_vision_x(self, vision_x: torch.Tensor):
747
- """
748
- Compute media tokens from the vision input by passing it through the vision encoder.
749
- Args:
750
- vision_x: Vision input
751
- shape (B, T_img, F, C, H, W)
752
- Images in the same chunk are collated along T_img, and frames are collated along F
753
- Currently only F=1 is supported (single-frame videos)
754
-
755
- rearrange code based on https://github.com/dhansmair/flamingo-mini
756
- """
757
- assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
758
- b, T, F = vision_x.shape[:3]
759
-
760
- vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
761
- with torch.no_grad():
762
- if self.vision_encoder.__class__.__name__ == "TimmModel":
763
- vision_x = self.vision_encoder.trunk.forward_features(vision_x)
764
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
765
- vision_x = self.vision_encoder(vision_x).last_hidden_state
766
- else:
767
- vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
768
- vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
769
- return vision_x
770
-
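The shape bookkeeping in _encode_vision_x amounts to flattening the batch/media/frame axes before the vision encoder and restoring them afterwards; a standalone shape check with dummy tensors (the token count 196 and feature dim 1024 are placeholders, not the model's actual values):

import torch
from einops import rearrange

vision_x = torch.randn(2, 3, 1, 3, 224, 224)                 # (B, T_img, F, C, H, W) with F=1
flat = rearrange(vision_x, "b T F c h w -> (b T F) c h w")   # (6, 3, 224, 224) is fed to the vision encoder
feats = torch.randn(6, 196, 1024)                            # stand-in for the encoder output, (b*T*F, v, d)
feats = rearrange(feats, "(b T F) v d -> b T F v d", b=2, T=3, F=1)
print(flat.shape, feats.shape)   # torch.Size([6, 3, 224, 224]) torch.Size([2, 3, 1, 196, 1024])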
771
- def _concat_vision_cache(
772
- self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
773
- ):
774
- """
775
- Helper function to include the past vision tokens and past media locations in the output.
776
- """
777
- if use_cache:
778
- if past_media_locations is not None and past_vision_tokens is not None:
779
- if vision_tokens is not None:
780
- updated_vision_tokens = torch.cat(
781
- [
782
- past_vision_tokens,
783
- vision_tokens,
784
- ],
785
- dim=1,
786
- )
787
- else:
788
- updated_vision_tokens = past_vision_tokens
789
- updated_media_locations = torch.cat(
790
- [
791
- past_media_locations,
792
- lang_x == self.media_token_id,
793
- ],
794
- dim=1,
795
- )
796
- else:
797
- updated_vision_tokens = vision_tokens
798
- updated_media_locations = lang_x == self.media_token_id
799
-
800
- else:
801
- updated_vision_tokens = None
802
- updated_media_locations = None
803
-
804
- return updated_vision_tokens, updated_media_locations
805
-
806
- def generate(
807
- self,
808
- vision_x: torch.Tensor,
809
- lang_x: torch.Tensor,
810
- attention_mask: torch.Tensor = None,
811
- past_key_values: Optional[
812
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
813
- ] = None,
814
- past_media_locations: Optional[torch.Tensor] = None,
815
- past_vision_tokens: Optional[torch.Tensor] = None,
816
- **kwargs,
817
- ):
818
- """
819
- Generate text conditioned on vision and language inputs.
820
- Args:
821
- vision_x (torch.Tensor): Vision input
822
- shape (B, T_img, F, C, H, W)
823
- see documentation for forward
824
- lang_x (torch.Tensor): Language input
825
- shape (B, T_txt)
826
- attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
827
- **kwargs: see generate documentation in Hugging Face CausalLM models.
828
- Returns:
829
- torch.Tensor: lang_x with generated tokens appended to it
830
- """
831
- num_beams = kwargs.pop("num_beams", 1)
832
-
833
- # convert pixels to vision tokens
834
- if vision_x is not None:
835
- vision_features = self._encode_vision_x(vision_x=vision_x)
836
- vision_tokens = self.vision_tokenizer(vision_features)
837
- else:
838
- vision_tokens = None
839
-
840
- # fuse the vision and language tokens
841
- # for xattn, vision_x and media_location are repeat_interleaved s.t.
842
- # the total batch size is B * num_beams
843
- new_inputs = self._prepare_inputs_for_forward(
844
- vision_tokens=vision_tokens,
845
- lang_x=lang_x,
846
- attention_mask=attention_mask,
847
- past_key_values=past_key_values,
848
- past_media_locations=past_media_locations,
849
- past_vision_tokens=past_vision_tokens,
850
- padding_side="left",
851
- num_beams=num_beams,
852
- )
853
- output = self.lang_model.generate(
854
- **new_inputs,
855
- past_key_values=past_key_values,
856
- num_beams=num_beams,
857
- use_cache=True,
858
- **kwargs,
859
- )
860
- self._post_forward_hook()
861
- return output
862
-
863
- @property
864
- def num_trainable_params(self):
865
- """Print the number of trainable parameters"""
866
- return num_params(self, filter_to_trainable=True)
867
-
868
- def set_trainable(self):
869
- """
870
- Freeze appropriate parameters in the model.
871
- """
872
- raise NotImplementedError
873
-
874
- def group_params_by_weight_decay(self):
875
- """
876
- Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
877
- """
878
- params_with_wd, params_without_wd = [], []
879
- for n, p in self.named_parameters():
880
- if p.requires_grad:
881
- if self._should_apply_weight_decay(n):
882
- params_with_wd.append(p)
883
- else:
884
- params_without_wd.append(p)
885
- return params_with_wd, params_without_wd
886
-
887
- def _should_apply_weight_decay(self, parameter_name):
888
- """
889
- Return whether weight decay should be applied to a parameter.
890
- """
891
- raise NotImplementedError
892
-
893
- @property
894
- def special_tokens(self):
895
- """
896
- Returns a dict mapping from the attribute name of a special token to its string format,
897
- e.g. "media_token": "<image>"
898
- """
899
- assert (
900
- "media_token" in self._special_tokens
901
- ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
902
- return self._special_tokens
903
-
904
- @property
905
- def special_token_ids(self):
906
- """
907
- Returns a list of the special token ids
908
- """
909
- return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
910
-
911
- def set_special_token_ids(self, string_to_ids):
912
- """
913
- Args:
914
- string_to_ids (dict): mapping from token string to id
915
- """
916
- assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
917
- for att_name, token_str in self.special_tokens.items():
918
- token_id = string_to_ids[token_str]
919
- setattr(self, f"{att_name}_id", token_id)
920
- setattr(self.lang_model, f"{att_name}_id", token_id)
921
-
922
- def init_gradient_checkpointing(self):
923
- from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
924
- checkpoint_wrapper,
925
- CheckpointWrapper,
926
- CheckpointImpl,
927
- apply_activation_checkpointing,
928
- )
929
- from functools import partial
930
-
931
- non_reentrant_wrapper = partial(
932
- checkpoint_wrapper,
933
- checkpoint_impl=CheckpointImpl.NO_REENTRANT,
934
- )
935
- apply_activation_checkpointing(
936
- self,
937
- checkpoint_wrapper_fn=non_reentrant_wrapper,
938
- check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
939
- and not isinstance(m, CheckpointWrapper),
940
- )
941
-
942
- @dataclass
943
- class VLMOutputWithPast(CausalLMOutputWithPast):
944
- """
945
- VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
946
- past_media_locations: Optional[torch.Tensor] = None,
947
- past_vision_tokens: Optional[torch.Tensor] = None,
948
- """
949
-
950
- past_media_locations: Optional[torch.Tensor] = None
951
- past_vision_tokens: Optional[torch.Tensor] = None
952
-
953
-
954
- def exists(val):
955
- return val is not None
956
-
957
-
958
- def FeedForward(dim, mult=4):
959
- inner_dim = int(dim * mult)
960
- return nn.Sequential(
961
- nn.LayerNorm(dim),
962
- nn.Linear(dim, inner_dim, bias=False),
963
- nn.GELU(),
964
- nn.Linear(inner_dim, dim, bias=False),
965
- )
966
-
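FeedForward is a pre-norm MLP that preserves the feature dimension; a quick shape check with toy sizes, assuming the function above is importable:

import torch

ff = FeedForward(dim=64, mult=4)   # LayerNorm -> Linear(64, 256) -> GELU -> Linear(256, 64)
x = torch.randn(2, 10, 64)
print(ff(x).shape)                 # torch.Size([2, 10, 64])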
967
- class VLMWithLanguageStream(VLM):
968
- """
969
- VLM that fuses modalities by inserting vision tokens directly into the language stream.
970
- """
971
-
972
- def __init__(
973
- self,
974
- vision_encoder: nn.Module,
975
- vision_tokenizer: nn.Module,
976
- lang_model: nn.Module,
977
- initial_tokenizer_len: int,
978
- pad_token_id: int,
979
- decoder_layers_attr_name: str = None,
980
- gradient_checkpointing: bool = False,
981
- ):
982
- super().__init__(
983
- vision_encoder=vision_encoder,
984
- vision_tokenizer=vision_tokenizer,
985
- lang_model=lang_model,
986
- initial_tokenizer_len=initial_tokenizer_len,
987
- pad_token_id=pad_token_id,
988
- gradient_checkpointing=gradient_checkpointing,
989
- )
990
- self.decoder_layers_attr_name = decoder_layers_attr_name
991
- if decoder_layers_attr_name is not None:
992
- for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
993
- block._use_gradient_checkpointing = gradient_checkpointing
994
-
995
- def _prepare_inputs_for_forward(
996
- self,
997
- vision_tokens: torch.Tensor,
998
- lang_x: torch.Tensor,
999
- attention_mask: torch.Tensor,
1000
- labels: torch.Tensor = None,
1001
- past_key_values=None,
1002
- vision_attention_mask: Optional[torch.Tensor] = None,
1003
- past_media_locations: torch.Tensor = None,
1004
- past_vision_tokens: torch.Tensor = None,
1005
- padding_side: str = "left",
1006
- num_beams: int = 1,
1007
- ):
1008
- """
1009
- Insert the vision tokens directly into the language stream.
1010
- This requires us to modify the input_ids, attention_mask, and labels.
1011
- """
1012
- if past_key_values is not None:
1013
- past_len = past_key_values[0][0].shape[2]
1014
- assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1015
- "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1016
- + "Check that you've expanded the attention mask to account for past image tokens."
1017
- )
1018
-
1019
- if vision_tokens is None:
1020
- return {
1021
- "input_ids": lang_x,
1022
- "attention_mask": attention_mask,
1023
- "labels": labels,
1024
- }
1025
-
1026
- # get the language embeddings
1027
- lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1028
-
1029
- # build up the multimodal embeddings
1030
- B = lang_x.shape[0]
1031
- has_labels = labels is not None
1032
- multimodal_embeds = []
1033
- multimodal_attention_mask = []
1034
- multimodal_labels = [] if has_labels else None
1035
- for i in range(B):
1036
- # get index of <image> tokens in lang_x[i]
1037
- image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1038
-
1039
- if len(image_token_idxs) == 0:
1040
- multimodal_embeds.append(lang_embeds[i].clone())
1041
- multimodal_attention_mask.append(attention_mask[i].clone())
1042
- if has_labels:
1043
- multimodal_labels.append(labels[i].clone())
1044
- continue
1045
-
1046
- # loop through the image_token_idxs and insert the vision tokens
1047
- new_embed = lang_embeds[i].clone()
1048
- new_attention_mask = (
1049
- attention_mask[i].clone() if attention_mask is not None else None
1050
- )
1051
- if has_labels:
1052
- new_label = labels[i].clone()
1053
-
1054
- for img_num, img_idx in enumerate(image_token_idxs):
1055
- # Get vision token attention mask for padded llava-style any resolution image tokens.
1056
- if self.image_aspect_ratio =='anyres':
1057
- num_vis_tokens = vision_tokens[i][img_num].shape[0]
1058
- if vision_attention_mask is not None:
1059
- vis_attention_mask = vision_attention_mask[i]
1060
- else:
1061
- vis_attention_mask = torch.ones(
1062
- num_vis_tokens, dtype=torch.long
1063
- ).to(attention_mask.device)
1064
- else:
1065
- assert (
1066
- vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
1067
- ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
1068
- vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
1069
- # By default, vision tokens are not padded.
1070
- num_vis_tokens = self.num_tokens_per_vis
1071
- vis_attention_mask = torch.ones(
1072
- num_vis_tokens, dtype=torch.long
1073
- ).to(attention_mask.device)
1074
-
1075
- new_embed = torch.cat(
1076
- (
1077
- new_embed[:img_idx],
1078
- vision_tokens[i][img_num],
1079
- new_embed[img_idx + 1 :],
1080
- ),
1081
- dim=0,
1082
- )
1083
- new_attention_mask = torch.cat(
1084
- (
1085
- new_attention_mask[:img_idx],
1086
- vis_attention_mask,
1087
- new_attention_mask[img_idx + 1 :],
1088
- ),
1089
- dim=0,
1090
- )
1091
- if has_labels:
1092
- new_label = torch.cat(
1093
- (
1094
- new_label[:img_idx],
1095
- torch.ones(num_vis_tokens, dtype=torch.long).to(
1096
- labels.device
1097
- )
1098
- * -100,
1099
- new_label[img_idx + 1 :],
1100
- ),
1101
- dim=0,
1102
- )
1103
- multimodal_embeds.append(new_embed)
1104
- multimodal_attention_mask.append(new_attention_mask)
1105
- if has_labels:
1106
- multimodal_labels.append(new_label)
1107
-
1108
- # stack
1109
- multimodal_embeds = stack_with_padding(
1110
- multimodal_embeds,
1111
- padding_value=self.pad_token_id,
1112
- padding_side=padding_side,
1113
- )
1114
- multimodal_attention_mask = stack_with_padding(
1115
- multimodal_attention_mask,
1116
- padding_value=0,
1117
- padding_side=padding_side,
1118
- )
1119
- if has_labels:
1120
- multimodal_labels = stack_with_padding(
1121
- multimodal_labels,
1122
- padding_value=-100,
1123
- padding_side=padding_side,
1124
- )
1125
-
1126
- return {
1127
- "inputs_embeds": multimodal_embeds,
1128
- "attention_mask": multimodal_attention_mask,
1129
- "labels": multimodal_labels,
1130
- }
1131
-
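At its core, _prepare_inputs_for_forward splices the vision tokens into the embedded language sequence at each <image> position and extends the attention mask to match; a model-free sketch of that splice for a single sequence (the token id, lengths and dims below are illustrative only):

import torch

media_token_id = 99                     # hypothetical id for the <image> token
lang_x = torch.tensor([5, 99, 7, 8])    # one sequence containing a single <image>
lang_embeds = torch.randn(4, 16)        # (T_txt, d) language embeddings
vision_tokens = torch.randn(128, 16)    # (num_tokens_per_vis, d) from the vision tokenizer

img_idx = int(torch.where(lang_x == media_token_id)[0][0])
fused_embeds = torch.cat(
    (lang_embeds[:img_idx], vision_tokens, lang_embeds[img_idx + 1:]), dim=0
)
fused_mask = torch.cat(
    (torch.ones(img_idx), torch.ones(128), torch.ones(lang_x.shape[0] - img_idx - 1)), dim=0
)
print(fused_embeds.shape, fused_mask.shape)   # torch.Size([131, 16]) torch.Size([131])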
1132
- def _postprocess_outputs_from_forward(
1133
- self,
1134
- output: CausalLMOutputWithPast,
1135
- lang_x: torch.Tensor,
1136
- vision_tokens: torch.Tensor,
1137
- past_vision_tokens: torch.Tensor,
1138
- past_media_locations: torch.Tensor,
1139
- use_cache: bool = False,
1140
- ):
1141
- # Include the past vision tokens and past media locations in the output
1142
- updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1143
- lang_x=lang_x,
1144
- vision_tokens=vision_tokens,
1145
- past_vision_tokens=past_vision_tokens,
1146
- past_media_locations=past_media_locations,
1147
- use_cache=use_cache,
1148
- )
1149
-
1150
- # return logits that are the same shape as the original input_ids
1151
- logits = output.logits
1152
- batch_logits = []
1153
- B, T_txt = lang_x.shape
1154
- for i in range(B):
1155
- sequence_logits = []
1156
- logits_j = 0
1157
- for j in range(T_txt):
1158
- if lang_x[i, j] != self.media_token_id:
1159
- sequence_logits.append(logits[i, logits_j])
1160
- logits_j += 1
1161
- else:
1162
- # append the logit for the first image token, then skip over the rest
1163
- # note: the model actually learns to predict <im_patch>, not <image>
1164
- sequence_logits.append(logits[i, logits_j])
1165
- logits_j += self.num_tokens_per_vis
1166
- sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
1167
- batch_logits.append(sequence_logits)
1168
-
1169
- batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1170
- # The final logits shape should be the same as the original input_ids shape
1171
- assert batch_logits.shape[:2] == (B, T_txt)
1172
-
1173
- # assemble the output
1174
- output = VLMOutputWithPast(
1175
- loss=output.loss,
1176
- logits=batch_logits,
1177
- past_key_values=output.past_key_values,
1178
- hidden_states=output.hidden_states,
1179
- attentions=output.attentions,
1180
- past_media_locations=updated_media_locations,
1181
- past_vision_tokens=updated_vision_tokens,
1182
- )
1183
-
1184
- return output
1185
-
1186
- def _post_forward_hook(self):
1187
- pass
1188
-
1189
-
1190
- @property
1191
- def num_params_per_module(self):
1192
- """Print the number of parameters per module in the model"""
1193
- return "\n".join(
1194
- [
1195
- f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1196
- f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1197
- f"Language model: {num_params(self.lang_model):,} parameters",
1198
- ]
1199
- )
1200
-
1201
- @property
1202
- def num_trainable_params_per_module(self):
1203
- """Print the number of trainable parameters per module in the model"""
1204
- return "\n".join(
1205
- [
1206
- f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1207
- f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1208
- f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1209
- ]
1210
- )
1211
-
1212
-
1213
- class XGenMMPerceiver(VLMWithLanguageStream):
1214
- def __init__(
1215
- self,
1216
- vision_encoder: nn.Module,
1217
- vision_tokenizer: nn.Module,
1218
- lang_model: nn.Module,
1219
- initial_tokenizer_len: int,
1220
- pad_token_id: int,
1221
- decoder_layers_attr_name: str = None,
1222
- gradient_checkpointing: bool = False,
1223
- image_aspect_ratio: str = 'anyres',
1224
- anyres_patch_sampling: bool = True,
1225
- anyres_grids: list[int] = None,
1226
- ):
1227
- """
1228
- Args:
1229
- vision_encoder (nn.Module): HF vision encoder, e.g. CLIPVisionModel or SiglipVisionTransformer
1230
- vision_tokenizer (nn.Module): module that maps vision features to visual tokens, e.g. PerceiverResampler
1231
- lang_model (nn.Module): HF causal language model
1232
- initial_tokenizer_len (int): size of the tokenizer vocab
1233
- pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1234
- will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1235
- decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1236
- gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1237
- """
1238
- self._special_tokens = {
1239
- "media_token": "<image>",
1240
- "image_placeholder_token": "<image placeholder>",
1241
- "end_of_trunk_token": "<|endofchunk|>",
1242
- }
1243
- lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1244
- super().__init__(
1245
- vision_encoder=vision_encoder,
1246
- vision_tokenizer=vision_tokenizer,
1247
- lang_model=lang_model,
1248
- initial_tokenizer_len=initial_tokenizer_len,
1249
- gradient_checkpointing=gradient_checkpointing,
1250
- decoder_layers_attr_name=decoder_layers_attr_name,
1251
- pad_token_id=pad_token_id,
1252
- )
1253
- self.image_aspect_ratio = image_aspect_ratio
1254
- self.anyres_patch_sampling = anyres_patch_sampling
1255
- self.anyres_grids = anyres_grids
1256
-
1257
- def set_trainable(self):
1258
- """
1259
- Unfreeze everything except the vision_encoder
1260
- """
1261
- self.requires_grad_(True)
1262
- self.vision_encoder.requires_grad_(False)
1263
-
1264
- def _should_apply_weight_decay(self, parameter_name):
1265
- """
1266
- Kosmos applies 0.01 weight decay to everything
1267
- """
1268
- return True
1269
-
1270
- def generate(
1271
- self,
1272
- vision_x: torch.Tensor,
1273
- lang_x: torch.Tensor,
1274
- image_size: Optional[Tuple] = None,
1275
- attention_mask: torch.Tensor = None,
1276
- past_key_values: Optional[
1277
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1278
- ] = None,
1279
- past_media_locations: Optional[torch.Tensor] = None,
1280
- past_vision_tokens: Optional[torch.Tensor] = None,
1281
- **kwargs,
1282
- ):
1283
- """
1284
- Generate text conditioned on vision and language inputs.
1285
- Args:
1286
- vision_x (torch.Tensor): Vision input
1287
- shape (B, T_img, F, C, H, W)
1288
- see documentation for forward
1289
- lang_x (torch.Tensor): Language input
1290
- shape (B, T_txt)
1291
- attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1292
- **kwargs: see generate documentation in Hugging Face CausalLM models.
1293
- Returns:
1294
- torch.Tensor: lang_x with generated tokens appended to it
1295
- """
1296
- num_beams = kwargs.pop("num_beams", 1)
1297
-
1298
- # convert pixels to vision tokens
1299
- vision_attention_mask = None
1300
- if vision_x is not None:
1301
- if self.image_aspect_ratio == 'anyres':
1302
- input_dict = dict(image=vision_x, image_size=image_size)
1303
- vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
1304
- else:
1305
- vision_features = self._encode_vision_x(vision_x=vision_x)
1306
- vision_attn_masks = None
1307
- # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
1308
- # Same for attention masks: [b, Np, v] -> [b*Np, v]
1309
- if self.anyres_patch_sampling:
1310
- split_sizes = [feature.shape[0] for feature in vision_features]
1311
- # Nested splits for multi-image samples.
1312
- if isinstance(vision_x[0], list):
1313
- nt_images = [len(images) for images in vision_x]
1314
- split_split_sizes = []
1315
- img_id = 0
1316
- for nt in nt_images:
1317
- split_split_sizes.append(split_sizes[img_id:img_id+nt])
1318
- img_id += nt
1319
- else:
1320
- nt_images = [1] * len(vision_x)
1321
- split_split_sizes = split_sizes
1322
- vision_features = torch.cat(vision_features, dim=0)
1323
- vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
1324
- vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1325
- vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
1326
-
1327
- # Post-processing: Split the batches into groups of patches and concatenate them together.
1328
- if self.anyres_patch_sampling:
1329
- assert isinstance(vision_x, list)
1330
- if isinstance(vision_x[0], list):
1331
- vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
1332
- vision_tokens = []
1333
-
1334
- for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
1335
- patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
1336
- flatten_vision_tokens = []
1337
- for image_vis_token in patch_vis_token_groups:
1338
- image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1339
- flatten_vision_tokens.append(image_vis_token)
1340
- vision_tokens_i = flatten_vision_tokens
1341
- vision_tokens.append(vision_tokens_i)
1342
- else:
1343
- vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1344
- vision_tokens = []
1345
- for patch_vis_tokens in vision_token_groups:
1346
- patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1347
- vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
1348
- else:
1349
- vision_tokens = None
1350
-
1351
- # fuse the vision and language tokens
1352
- # for xattn, vision_x and media_location are repeat_interleaved s.t.
1353
- # the total batch size is B * num_beams
1354
- new_inputs = self._prepare_inputs_for_forward(
1355
- vision_tokens=vision_tokens,
1356
- lang_x=lang_x,
1357
- attention_mask=attention_mask,
1358
- vision_attention_mask=vision_attention_mask,
1359
- past_key_values=past_key_values,
1360
- past_media_locations=past_media_locations,
1361
- past_vision_tokens=past_vision_tokens,
1362
- padding_side="left",
1363
- num_beams=num_beams,
1364
- )
1365
- if past_key_values is not None:
1366
- output = self.lang_model.generate(
1367
- **new_inputs,
1368
- past_key_values=past_key_values,
1369
- num_beams=num_beams,
1370
- use_cache=True,
1371
- **kwargs,
1372
- )
1373
- else:
1374
- output = self.lang_model.generate(
1375
- **new_inputs,
1376
- num_beams=num_beams,
1377
- use_cache=True,
1378
- **kwargs,
1379
- )
1380
- self._post_forward_hook()
1381
- return output