# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Union, List, Optional

import numpy as np
import torch
from einops import rearrange, repeat
from torch import nn
import torch.nn.functional as F

from fourm.utils import get_sentinel_to_id_mapping, merge_span_masking
from fourm.utils.generation import cosine_schedule, linear_schedule, onex_temp_schedule, linear_temp_schedule, continue_schedule
from tqdm import tqdm
import copy


# Mask convention used throughout this file (see the `(~mask).sum()` token counts below):
#   input_mask[b, n]  == False -> token n is given to the encoder; True -> hidden from the encoder.
#   target_mask[b, n] == False -> token n is to be decoded;        True -> not a decoding target.


def empty_img_modality(mod_dict, key):
    """Mark every token of image-like modality `key` as "not an input, still to be decoded".

    Modifies the existing mask tensors in place and returns `mod_dict`.
    """
    # Input mask
    mod_dict[key]['input_mask'][:] = True

    # Target Mask
    mod_dict[key]['target_mask'][:] = False

    return mod_dict


def empty_seq_modality(mod_dict, key, s1_id=5):
    """Turn sequence modality `key` into an "empty" span-masked sequence.

    Args:
        mod_dict: Modality dictionary; `mod_dict[key]` must hold 'tensor', 'input_mask',
            'target_mask' and 'decoder_attention_mask' with at least 2 tokens on dim 1.
        key: Name of the modality to empty.
        s1_id: Token id of the first sentinel token ([S_1]); `s1_id + 1` is written as the
            closing sentinel ([S_2]).

    Returns:
        The same `mod_dict`. 'tensor', 'input_mask' and 'decoder_attention_mask' are updated
        in place; 'target_mask' is rebound to a new tensor.
    """
    # To create an empty sequence, we suppose an input budget of 1, and the rest assigned to targets

    # Input tensor
    # Input is [S_1], target is [S_1] ...... [S_2]
    # (so [S_1] [S_1] ..... [S_2] when combined)
    mod_dict[key]['tensor'][:] = 0
    mod_dict[key]['tensor'][:, [0, 1]] = s1_id  # s1_id is id of the first sentinel token ([S_1])
    mod_dict[key]['tensor'][:, -1] = s1_id + 1

    # Input mask
    # Set first token to input (i.e. 0), rest to target (i.e. 1)
    mod_dict[key]['input_mask'][:] = True
    mod_dict[key]['input_mask'][:, 0] = False

    # Target Mask
    mod_dict[key]['target_mask'] = ~mod_dict[key]['input_mask']

    # Decoder attn mask
    # WARNING: Not needed / used in GenerationSampler, where causal mask is enforced
    # First token is input, not part of target
    mod_dict[key]['decoder_attention_mask'][:] = 1
    mod_dict[key]['decoder_attention_mask'][:, 0] = 0

    return mod_dict


def empty_seq_emb_modality(mod_dict, key):
    """Replace sequence-embedding modality `key` by an all-zero "empty" version.

    The tensor is zeroed, only the first position stays visible to the encoder, and no
    position is a decoding target. Returns `mod_dict`.
    """
    # Tensor
    mod_dict[key]['tensor'] = torch.zeros_like(mod_dict[key]['tensor'])

    # Input mask
    mod_dict[key]['input_mask'] = torch.ones_like(mod_dict[key]['input_mask'])
    # It is crucial to specify the input mask as such, CFG won't work otherwise!
    mod_dict[key]['input_mask'][:, 0] = False

    # Target Mask
    mod_dict[key]['target_mask'] = torch.ones_like(mod_dict[key]['target_mask'])

    # Decoder attn mask
    mod_dict[key]['decoder_attention_mask'][:] = False

    return mod_dict


def init_empty_target_modality(mod_dict, modality_info, domain, batch_size, num_tokens, device):
    """
    Initializes an empty target modality dictionary for a given domain.
    Used to initialize target modality dictionaries for generation.

    Args:
        mod_dict: Modality dictionary to which the new entry `domain` is added.
        modality_info: Per-modality metadata; `modality_info[domain]['type']` selects the layout.
        domain: Name of the modality to create.
        batch_size: First dimension of the created tensors.
        num_tokens: Second dimension of the created tensors (raised to at least 2 for
            sequence types, which need room for the two sentinel tokens).
        device: Device on which the tensors are allocated.

    Returns:
        `mod_dict` with `mod_dict[domain]` (re)initialized.

    Raises:
        ValueError: If the modality type is not one of 'img', 'seq', 'seq_token', 'seq_emb'.
    """
    mod_type = modality_info[domain]['type']

    if mod_type == 'img':
        # Initialize mod dict
        mod_dict[domain] = {
            'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int64, device=device),
            'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
            'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
        }
        # Set it to the correct values
        mod_dict = empty_img_modality(mod_dict, domain)

    elif mod_type in ['seq', 'seq_token', 'seq_emb']:
        # Initialize mod dict
        num_tokens = max(num_tokens, 2)
        mod_dict[domain] = {
            'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int32, device=device),
            'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
            'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
            'decoder_attention_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
        }
        # Set it to the correct values
        if mod_type in ['seq', 'seq_token']:
            mod_dict = empty_seq_modality(mod_dict, domain)
        elif mod_type == 'seq_emb':
            mod_dict = empty_seq_emb_modality(mod_dict, domain)

    else:
        raise ValueError(f"Unsupported modality type '{mod_type}' for domain '{domain}'")

    return mod_dict


def init_full_input_modality(mod_dict, modality_info, domain, device, eos_id=3):
    """Configure the masks of `domain` so that the whole modality is used as encoder input.

    Missing 'input_mask', 'target_mask' and 'decoder_attention_mask' entries are created first.

    Args:
        mod_dict: Modality dictionary containing `mod_dict[domain]['tensor']`
            (and 'mask_valid' for 'seq_emb' modalities).
        modality_info: Per-modality metadata ('type', and 'patch_size' for rgb domains).
        domain: Name of the modality. Domains starting with 'rgb' are treated as
            (B, C, H, W) pixel tensors whose token count is derived from the patch size.
        device: Device on which missing masks are allocated.
        eos_id: End-of-sequence token id used for 'seq' / 'seq_token' modalities.

    Returns:
        The same `mod_dict`, updated.
    """
    if domain.startswith('rgb'):
        batch_size, _, H, W = mod_dict[domain]['tensor'].shape
        patch_size = modality_info[domain]['patch_size']
        num_tokens = (H // patch_size) * (W // patch_size)
        shape = (batch_size, num_tokens)
    else:
        shape = mod_dict[domain]['tensor'].shape

    if 'input_mask' not in mod_dict[domain]:
        mod_dict[domain]['input_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)
    if 'target_mask' not in mod_dict[domain]:
        mod_dict[domain]['target_mask'] = torch.ones(shape, dtype=torch.bool, device=device)
    if 'decoder_attention_mask' not in mod_dict[domain]:
        mod_dict[domain]['decoder_attention_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)

    if modality_info[domain]['type'] == 'img':
        mod_dict[domain]['input_mask'][:] = False
        mod_dict[domain]['target_mask'][:] = True

    elif modality_info[domain]['type'] in ['seq', 'seq_token']:
        if eos_id in mod_dict[domain]['tensor']:
            # Column of the first EOS hit returned by torch.where; it is applied to every row,
            # i.e. this assumes the EOS position is shared across the batch.
            eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()
        else:
            # No EOS anywhere: write one at position 0 so that exactly one token is visible
            mod_dict[domain]['tensor'][:, 0] = eos_id
            eos_idx = 0
        # Everything up to and including EOS is input, everything after is hidden
        mod_dict[domain]['input_mask'][:, :eos_idx + 1] = False
        mod_dict[domain]['input_mask'][:, eos_idx + 1:] = True
        mod_dict[domain]['target_mask'][:] = True

    elif modality_info[domain]['type'] in ['seq_emb']:
        # T5 caption has the valid mask saved alongside the embeddings
        mod_dict[domain]['input_mask'] = ~mod_dict[domain]['mask_valid']
        mod_dict[domain]['target_mask'] = torch.ones_like(mod_dict[domain]['mask_valid'])
        mod_dict[domain]['decoder_attention_mask'] = torch.zeros_like(mod_dict[domain]['mask_valid'])

    return mod_dict


def custom_text(sample, input_text, eos_token, key, device, text_tokenizer, target_max_len=50, start_token="[S_1]"):
    """Tokenize `input_text` and store it, followed by a padded target placeholder, in `sample[key]`.

    The stored sequence is `input_ids` followed by the encoding of
    `start_token [PAD] * (target_max_len - 2) eos_token`.

    Args:
        sample: Modality dictionary to write into; `sample[key]` is overwritten.
        input_text: Text to encode as the input part.
        eos_token: Token string appended at the end of the target placeholder.
        key: Name of the modality entry to create.
        device: Device the resulting tensors are moved to.
        text_tokenizer: Tokenizer whose `encode(text).ids` yields the token ids.
        target_max_len: Number of whitespace-separated placeholder tokens in the target part.
        start_token: Token string that opens the target placeholder.

    Returns:
        The same `sample` dict with `sample[key]` holding 'tensor', 'input_mask',
        'target_mask' and 'decoder_attention_mask', each with a leading batch dim of 1.
    """
    input_ids = text_tokenizer.encode(input_text).ids
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    target_text = [start_token]
    target_text.extend(["[PAD]"] * (target_max_len - 2))
    target_text.append(eos_token)
    target_text = " ".join(target_text)
    target_ids = text_tokenizer.encode(target_text).ids
    target_ids = torch.tensor(target_ids).unsqueeze(0)

    all_ids = torch.cat([input_ids, target_ids], dim=1)

    # False over the input part, True over the placeholder part
    input_mask = torch.cat([
        torch.zeros_like(input_ids, dtype=torch.bool),
        torch.ones_like(target_ids, dtype=torch.bool),
    ], dim=1)

    # True over the input part, False over the placeholder part
    target_mask = torch.cat([
        torch.ones_like(input_ids, dtype=torch.bool),
        torch.zeros_like(target_ids, dtype=torch.bool),
    ], dim=1)

    sample[key] = {}
    sample[key]['tensor'] = all_ids.to(device)
    sample[key]['input_mask'] = input_mask.to(device)
    sample[key]['target_mask'] = target_mask.to(device)
    sample[key]['decoder_attention_mask'] = torch.zeros(all_ids.shape, dtype=torch.bool, device=device)

    return sample


def expand_to_batch(mod_dict, batch_size):
    """Repeat every batch-size-1 tensor in `mod_dict` along dim 0 to `batch_size`.

    Only the entries 'tensor', 'input_mask', 'target_mask', 'decoder_attention_mask' and
    'mask_valid' of each modality are considered; other keys are left untouched.

    Raises:
        ValueError: If one of those tensors has a batch size that is neither 1 nor `batch_size`.
    """
    batched_keys = ['tensor', 'input_mask', 'target_mask', 'decoder_attention_mask', 'mask_valid']
    for mod, d in mod_dict.items():
        for k, v in d.items():
            if k in batched_keys:
                B = v.shape[0]
                if B == 1:
                    mod_dict[mod][k] = repeat(v, "1 ... -> b ...", b=batch_size)
                elif B != batch_size:
                    raise ValueError(f"Invalid batch size: {B} instead of {batch_size}")

    return mod_dict


def build_chained_generation_schedules(
        cond_domains: List[str],
        target_domains: List[str],
        tokens_per_target: List[int],
        autoregression_schemes: List[str],
        decoding_steps: List[int],
        token_decoding_schedules: List[str],
        temps: List[float],
        temp_schedules: List[str],
        cfg_scales: List[float],
        cfg_schedules: List[str],
        cfg_grow_conditioning: bool = False,
        modality_info: Optional[dict] = None,
    ):
    """
    Builds a list of chained generation schedules, where each schedule is a tuple of the form:
    (target_modality, schema, number of decoded tokens, temperature, guidance_scale, cfg_cond_domains)

    Args:
        cond_domains: List of conditioning domains
        target_domains: List of target domains
        tokens_per_target: List of number of tokens to decode for each target domain
        autoregression_schemes: List of autoregression schemes for each target domain. maskgit, roar, or autoregressive
        decoding_steps: List of number of maskgit steps for each target domain (if applicable)
        token_decoding_schedules: List of maskgit token schedules for each target domain (if applicable). cosine or linear
        temps: List of starting temperatures for each target domain
        temp_schedules: List of temperature schedules for each target domain. linear, constant, or onex:{min_t}:{power}
        cfg_scales: List of classifier-free guidance scales for each target domain. Each entry is
            either a single number or a list/tuple of numbers (one weight per condition).
        cfg_schedules: List of classifier-free guidance schedules for each target domain. constant or cosine
        cfg_grow_conditioning: After every completed modality, add them to classifier-free guidance conditioning
        modality_info: Dictionary with metadata for each modality, optionally used to verify that the schedule is compatible with the modality

    Returns:
        List of dicts with keys 'target_domain', 'scheme', 'num_tokens', 'temperature',
        'cfg_scale' and 'cfg_cond_domains', one per decoding step (a single entry with
        'num_tokens' None for autoregressive targets).

    Raises:
        ValueError: On an unknown scheme, token schedule, temperature schedule, guidance
            schedule, or an unsupported guidance scale type.
        NotImplementedError: For the 'cosine' guidance schedule.
    """
    # List of {target_modality, schema, number of decoded tokens, temperature, guidance_scale, cfg_cond_domains} dicts
    chained_schedules = []

    # Work on a copy so that growing the conditioning never mutates the caller's list
    cond_domains = cond_domains.copy()

    for target_idx in range(len(target_domains)):

        scheme = autoregression_schemes[target_idx]
        target_domain = target_domains[target_idx]
        ntoks = tokens_per_target[target_idx]
        maskgit_token_schedule_name = token_decoding_schedules[target_idx]
        temp = temps[target_idx]
        temp_schedule_name = temp_schedules[target_idx]
        cfg_scale = cfg_scales[target_idx]
        cfg_schedule_name = cfg_schedules[target_idx]

        # Auto-regressive (caption, detection, ...)
        if scheme == 'autoregressive':
            chained_schedules.append({
                'target_domain': target_domain,
                'scheme': scheme,
                'num_tokens': None,
                'temperature': temp,
                'cfg_scale': cfg_scale,
                'cfg_cond_domains': cond_domains.copy()
            })
            # NOTE(review): this `continue` also skips the cfg_grow_conditioning update below,
            # so autoregressive targets are never added to cond_domains - confirm this is intended.
            continue

        # Use modality info for (optional) assert if provided
        if modality_info is not None:
            assert modality_info[target_domain]['type'] not in ['seq', 'seq_token'], f'Illegal autoregressive scheme {scheme} for target domain {target_domain}'

        # Token schedule
        if scheme == 'maskgit':
            # MaskGIT token schedule setup
            num_steps = decoding_steps[target_idx]
            if maskgit_token_schedule_name == 'cosine':
                token_schedule = cosine_schedule(num_steps, (ntoks))
            elif maskgit_token_schedule_name == 'linear':
                token_schedule = linear_schedule(num_steps, (ntoks))
            else:
                raise ValueError(f'Illegal MaskGIT token schedule {maskgit_token_schedule_name}')
        elif scheme == 'roar':
            # ROAR token schedule setup (one-by-one, but random order)
            num_steps = decoding_steps[target_idx]
            token_schedule = linear_schedule(num_steps, ntoks)
        else:
            raise ValueError(f'Illegal decoding scheme {scheme}')

        # Temperature schedule
        if temp_schedule_name == 'linear':
            temp_schedule = linear_temp_schedule(temp, token_schedule)
        elif temp_schedule_name == 'constant':
            temp_schedule = temp * np.ones(num_steps)
        elif 'onex' in temp_schedule_name:
            # onex temperature schedule has to be formatted like onex:{min_t}:{power}
            min_t, power = [float(f) for f in temp_schedule_name.split(':')[1:]]
            temp_schedule = onex_temp_schedule(max_t=temp, min_t=min_t, token_schedule=token_schedule, power=power)
        else:
            raise ValueError(f'Illegal temperature schedule {temp_schedule_name}')

        # Classifier-free guidance scale schedule
        if cfg_schedule_name == 'constant':
            if isinstance(cfg_scale, (list, tuple)):
                # One weight per condition: rows are steps, columns are conditions
                cfg_schedule = np.array(cfg_scale) * np.ones(num_steps).reshape(-1, 1)
            elif isinstance(cfg_scale, (int, float, np.number)):
                # Ints are accepted too; previously they left cfg_schedule unbound (or stale
                # from the previous target domain).
                cfg_schedule = cfg_scale * np.ones(num_steps)
            else:
                raise ValueError(f'Illegal guidance scale {cfg_scale}')
        elif cfg_schedule_name == 'cosine':
            raise NotImplementedError()
        else:
            raise ValueError(f'Illegal guidance schedule {cfg_schedule_name}')

        # Concatenate schedule for this modality with previous ones
        schedule = [
            {
                'target_domain': target_domain,
                'scheme': scheme,
                'num_tokens': tok,
                'temperature': step_temp,
                'cfg_scale': cfg,
                'cfg_cond_domains': cond_domains.copy()
            }
            for tok, step_temp, cfg in zip(token_schedule, temp_schedule, cfg_schedule)
        ]
        chained_schedules.extend(schedule)

        # Optionally add this new modality to the ones affected by classifier-free guidance
        if cfg_grow_conditioning:
            cond_domains.append(target_domain)

    return chained_schedules


class GenerationSampler(nn.Module):
    """Sampler that wraps a trained 4M model for generation use cases.
    Implements standard autoregressive, MaskGIT, and ROAR generation schemes with chaining and weighted guidance."""

    def __init__(self, model):
        """Store the wrapped model; all forward passes below go through `self.model`."""
        super().__init__()
        self.model = model

    def top_k_top_p_filtering(self, logits, top_k=0.0, top_p=0.0):
        """Mask out unlikely tokens by setting their logits to -inf (top-k and/or nucleus filtering).

        Args:
            logits: Tensor whose last dimension is the vocabulary. It is modified IN PLACE.
            top_k: If an int > 0, keep that many highest logits; if a float > 0, keep that
                fraction of the vocabulary (at least one token). 0 disables top-k.
            top_p: If > 0, keep the smallest set of tokens whose cumulative probability
                exceeds `top_p`. 0 disables nucleus filtering.

        Returns:
            The same `logits` tensor, filtered.

        Raises:
            ValueError: If `top_k` is positive but neither an int nor a float.
        """
        # Compatible with batching
        # From https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        if top_k > 0.0:
            vocab_size = logits.shape[-1]
            if isinstance(top_k, int):
                k = min(top_k, vocab_size)
            elif isinstance(top_k, float):
                # A tiny fraction would round down to k == 0 and break the `[..., -1]`
                # indexing below, so always keep at least one token.
                k = max(1, min(int(top_k * vocab_size), vocab_size))
            else:
                raise ValueError(f"Invalid value for top_k: {top_k}")

            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cum_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the removal mask from sorted order back to vocabulary order
            restore_indices = torch.argsort(sorted_indices, dim=-1)
            indices_to_remove = torch.gather(sorted_indices_to_remove, dim=-1, index=restore_indices)
            logits[indices_to_remove] = float("-inf")

        return logits

    def sample_tokens(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
        """Sample one token per row of 2-D `logits` (rows, vocab).

        With a temperature of (almost) zero the argmax is taken and every reported
        probability is 1; otherwise logits are filtered, temperature-scaled and sampled
        with `torch.multinomial`.

        Returns:
            (samples, sampled_probs): the chosen token ids and the probability each was
            drawn with, one entry per row.
        """
        if np.isclose(temperature, 0, atol=1e-10):
            samples = torch.argmax(logits, dim=-1)
            # Since argmax is used, all sampled_probs will be 1 as we're selecting the max probability
            sampled_probs = torch.ones_like(samples, dtype=torch.float32)
        else:
            filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            samples = torch.multinomial(probs, 1)[:, 0]
            sampled_probs = probs[torch.arange(len(samples)), samples]
        return samples, sampled_probs

    def sample_tokens_batched(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
        """Like `sample_tokens`, but also accepts 3-D logits (B, N, vocab).

        3-D inputs are flattened to (B*N, vocab), sampled, and reshaped back to (B, N).
        Inputs with fewer dimensions are passed straight to `sample_tokens`.
        """
        if logits.ndim > 2:
            B, N = logits.shape[0], logits.shape[1]
            logits = rearrange(logits, 'b n v -> (b n) v')
            samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
            samples = rearrange(samples, '(b n) -> b n', b=B, n=N)
            sampled_probs = rearrange(sampled_probs, '(b n) -> b n', b=B, n=N)
            return samples, sampled_probs
        else:
            return self.sample_tokens(logits, temperature, top_k, top_p)

    def select_tokens(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
        """Sample a token for every row of 2-D `logits` and keep the `num_select` most confident ones.

        Returns:
            (top_samples, top_indices), plus the full `samples` tensor as a third element
            when `return_all_samples` is True. `top_indices` index the rows of `logits`.
        """
        samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
        top_indices = torch.topk(sampled_probs, num_select)[1]
        top_samples = samples[top_indices]
        if return_all_samples:
            return top_samples, top_indices, samples
        else:
            return top_samples, top_indices

    def select_tokens_batched(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
        """Batched version of `select_tokens` for 3-D logits (B, N, vocab).

        Returns:
            (top_samples, top_indices) of shape (B, num_select), plus all samples (B, N)
            when `return_all_samples` is True. Non-batched logits fall back to `select_tokens`.
        """
        if logits.ndim > 2:
            samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k, top_p)  # both of shape (B, N)
            top_indices = torch.topk(sampled_probs, num_select, dim=-1)[1]
            # Need to switch to gather instead of indexing here
            top_samples = torch.gather(samples, dim=-1, index=top_indices)
            if return_all_samples:
                return top_samples, top_indices, samples
            else:
                return top_samples, top_indices
        else:
            # The unbatched counterpart is select_tokens; sample_tokens does not accept
            # num_select / return_all_samples.
            return self.select_tokens(logits, num_select, temperature, top_k, top_p, return_all_samples)

    def forward_mask_encoder_generation(self, encoder_mod_dict):
        """Modification of forward_mask_encoder adapted for generation, with support for batching

        Concatenates the embedded encoder modalities, keeps only the visible tokens
        (mask == False), optionally prepends register tokens, and zeroes out whatever
        masked positions remain.

        Returns:
            (encoder_tokens, encoder_emb, encoder_mask, mod_mask), where `encoder_mask`
            has been reshaped to 'b 1 n2'.
        """
        # Form input
        B = next(iter(encoder_mod_dict.values()))['tensor'].shape[0]

        encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.model.cat_encoder_tensors(encoder_mod_dict)

        # Take max num encoder of tokens (although assuming it's the same everywhere would be better)
        num_encoder_tokens = (~encoder_mask_all.reshape(B, -1)).sum(dim=1).max()

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_encoder_tokens]

        encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
                                      index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
        encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        if self.model.num_register_tokens > 0:
            # NOTE(review): `prompt_tokens` is not set in this class's __init__; presumably the
            # register tokens live on `self.model` - confirm the attribute name before relying
            # on models with num_register_tokens > 0.
            prompt_tokens = repeat(self.prompt_tokens, '() n d -> b n d', b=B)
            # We add prompt tokens at the beginning of the sequence
            encoder_tokens = torch.cat([prompt_tokens, encoder_tokens], dim=1)
            encoder_emb = torch.cat([torch.zeros_like(prompt_tokens), encoder_emb], dim=1)
            encoder_mask = torch.cat([torch.zeros((B, prompt_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
            mod_mask = torch.cat([torch.full((B, prompt_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)

        encoder_tokens[encoder_mask] = 0.
        encoder_emb[encoder_mask] = 0.
        mod_mask[encoder_mask] = -1
        # Mask could be of shape 'b n1 n2' but not needed for masked_fill
        # This means this mask can then be re-used for decoder cross-attention
        encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')

        return encoder_tokens, encoder_emb, encoder_mask, mod_mask

    def forward_mask_decoder_maskgit(self, mod_dict, target_mod, seed=None):
        """Modification of forward_mask_decoder for MaskGIT generation, with support for batching

        Builds mask-token decoder inputs for every position of `target_mod` that is still
        a target (target_mask == False), in increasing position order.

        Args:
            mod_dict: Dict holding the embedded target modality (keys 'x', 'emb', 'ids',
                'target_mask' are read).
            target_mod: Name of the modality to decode.
            seed: If given, `torch.manual_seed(seed)` is called first.

        Returns:
            (decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos); `mod_pos` holds
            the original position of each kept token.
        """
        if seed is not None:
            torch.manual_seed(seed)
        d = mod_dict[target_mod]
        decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_tokens_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)  # Added: Expansion for batching
        num_decoder_tokens = (~decoder_mask_all[0]).sum()  # Adapted for batching / Assumes num_decoder_tokens is the same across the batch

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos

    def forward_mask_decoder_roar(self, mod_dict, target_mod, num_select, seed=None):
        """Modification of forward_mask_decoder for ROAR generation, with support for batching

        Same as `forward_mask_decoder_maskgit`, except that the remaining target positions
        are visited in random order and at most `num_select` of them are kept. The random
        order is drawn once and shared by all batch elements.

        Returns:
            (decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos).
        """
        if seed is not None:
            torch.manual_seed(seed)
        d = mod_dict[target_mod]
        decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_tokens_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)  # Added: Expansion for batching
        # Only keep the first num_select tokens
        num_decoder_tokens = min(num_select, (~decoder_mask_all[0]).sum())  # Adapted for batching / Assumes num_decoder_tokens is the same across the batch

        # Add a small random number to the mask so they get sorted in a random way, but keeping the masked tokens first
        mask_rand = torch.rand(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_rand, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        # Only keep the first num_select_tokens
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos

    def forward_mask_decoder_autoregressive(self, mod_dict, target_mod, seed=None):
        """Select the decoder inputs for autoregressive generation, with support for batching.

        Unlike the MaskGIT / ROAR variants this returns the token ids themselves
        (`decoder_ids`) rather than mask-token embeddings.

        Returns:
            (decoder_ids, decoder_emb, decoder_mask, mod_mask, mod_pos).
        """
        # Adapted for batching
        if seed is not None:
            torch.manual_seed(seed)
        # This is the concatenation part
        d = mod_dict[target_mod]
        decoder_ids_all = d['ids']
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_ids_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)
        num_decoder_tokens = (~decoder_mask_all[0]).sum()  # Adapted for batching, but assumes num_decoder_tokens is the same across the batch

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        # Same as in forward_mask_decoder
        decoder_ids = torch.gather(decoder_ids_all, dim=1, index=ids_keep)
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_ids[decoder_mask] = 0
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_ids, decoder_emb, decoder_mask, mod_mask, mod_pos

    def merge_sequences(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
        """Merge the predicted span tokens back into the (unbatched) input sequence of `target_mod`.

        Args:
            mod_dict: Modality dictionary; `mod_dict[target_mod]` must hold a single sequence
                ('tensor' and 'input_mask').
            pred_ids: Predicted token ids for that sequence.
            target_mod: Name of the sequence modality.
            text_tokenizer: Tokenizer used to look up sentinel token ids.
            default_sentinel: Sentinel used as the input when no token is marked as input.

        Returns:
            `mod_dict` with `mod_dict[target_mod]` replaced by the merged sequence (batch
            dim 1), marked as fully visible input (input_mask all False) and no targets
            (target_mask all True). All tensors are on the original device.
        """
        device = mod_dict[target_mod]['tensor'].device

        # Get input ids
        input_ids = mod_dict[target_mod]['tensor'].squeeze().detach().cpu()
        input_ids = input_ids[mod_dict[target_mod]['input_mask'].squeeze().detach().cpu() == 0]
        input_ids = input_ids.tolist()

        if len(input_ids) == 0:
            input_ids = [text_tokenizer.get_vocab()[default_sentinel]]

        # Get predicted ids
        pred_ids = pred_ids.squeeze().detach().cpu().tolist()
        if isinstance(pred_ids, int):
            pred_ids = [pred_ids]

        # Get sentinel ids using the tokenizer
        sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
        # Perform merging
        merged_ids = merge_span_masking(input_ids, pred_ids, sentinel_ids)
        merged_ids = torch.tensor(merged_ids).unsqueeze(0)
        # Create new dict
        new_input_mask = torch.zeros_like(merged_ids, dtype=torch.bool)
        new_target_mask = torch.ones_like(merged_ids, dtype=torch.bool)
        new_dict = {
            'tensor': merged_ids.to(device),
            'input_mask': new_input_mask.to(device),
            'target_mask': new_target_mask.to(device),
        }
        # Derive from the on-device target mask so every entry lives on the same device
        new_dict['decoder_attention_mask'] = torch.zeros_like(new_dict['target_mask'], dtype=torch.bool)

        mod_dict[target_mod] = new_dict
        return mod_dict

    def merge_sequences_batched(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
        """Batched wrapper around `merge_sequences`.

        Each batch element is merged on its own, then all results are right-padded to the
        longest merged sequence ([PAD] for tokens, True for both masks) and stacked.

        Returns:
            `mod_dict` with `mod_dict[target_mod]` replaced by the batched merged sequences.
        """
        # Unbatches and calls merge sequence per batch, then regroups it into a batch

        pad_id = text_tokenizer.token_to_id("[PAD]")

        device = mod_dict[target_mod]['tensor'].device

        tensors = torch.split(mod_dict[target_mod]['tensor'], 1)
        input_masks = torch.split(mod_dict[target_mod]['input_mask'], 1)
        pred_ids = torch.split(pred_ids, 1)

        input_dicts = [
            {target_mod: {'tensor': t, 'input_mask': im}}
            for t, im in zip(tensors, input_masks)
        ]

        merged_tensors = []
        merged_input_masks = []
        merged_target_masks = []
        merged_seq_lens = []
        for input_d, pi in zip(input_dicts, pred_ids):
            # Output of merge_sequences is mod_dict with modified target mod
            merged_d = self.merge_sequences(input_d, pi, target_mod, text_tokenizer, default_sentinel)[target_mod]
            merged_tensors.append(merged_d['tensor'])
            merged_input_masks.append(merged_d['input_mask'])
            # Use the target mask produced by merge_sequences (not the input mask), so the
            # batched result agrees with the unbatched one.
            merged_target_masks.append(merged_d['target_mask'])
            merged_seq_lens.append(merged_d['tensor'].shape[1])

        max_seq_len = max(merged_seq_lens)

        for i, seq_len in enumerate(merged_seq_lens):
            # Right pad all tensors
            p1d = (0, max_seq_len - seq_len)
            merged_tensors[i] = F.pad(merged_tensors[i], p1d, "constant", pad_id)
            merged_input_masks[i] = F.pad(merged_input_masks[i], p1d, "constant", True)
            merged_target_masks[i] = F.pad(merged_target_masks[i], p1d, "constant", True)

        new_dict = {
            'tensor': torch.cat(merged_tensors, dim=0).to(device),
            'input_mask': torch.cat(merged_input_masks, dim=0).to(device),
            'target_mask': torch.cat(merged_target_masks, dim=0).to(device),
        }
        new_dict['decoder_attention_mask'] = torch.zeros_like(new_dict['target_mask'], dtype=torch.bool)

        mod_dict[target_mod] = new_dict
        return mod_dict

    def forward_enc_dec_maskgit_batched(self, mod_dict, target_mod, seed=None):
        """Run one encoder + decoder pass and return logits for the remaining targets of `target_mod`.

        Returns:
            (logits, mod_pos): logits reshaped to (B, N, -1) for the N remaining target
            tokens, and the original position of each of those tokens.
        """
        # Encoder
        encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
                            for mod, d in mod_dict.items()
                            if mod in self.model.encoder_embeddings}
        encoder_tokens, encoder_emb, encoder_mask, _ = self.forward_mask_encoder_generation(encoder_mod_dict)
        x = encoder_tokens + encoder_emb
        x = self.model.forward_encoder(x, encoder_mask)

        # Decoder
        context = self.model.decoder_proj_context(x) + encoder_emb
        decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
        decoder_tokens, decoder_emb, _, decoder_mod_mask, mod_pos = self.forward_mask_decoder_maskgit(decoder_mod_dict, target_mod, seed=seed)
        y = decoder_tokens + decoder_emb
        y = self.model.forward_decoder(y, context, encoder_mask, None)
        B, N, _ = y.shape
        logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
        logits = logits.reshape(B, N, -1)

        return logits, mod_pos

    def maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
        """Perform one unguided MaskGIT step: decode the `num_select` most confident tokens.

        The selected tokens are written into `mod_dict[target_mod]['tensor']` and their
        positions become inputs (input_mask False) and stop being targets (target_mask True).

        Returns:
            The updated `mod_dict`.
        """
        logits, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)

        # MaskGIT sampling
        top_samples, top_indices = self.select_tokens_batched(logits, num_select,
                                                              temperature=temperature, top_k=top_k, top_p=top_p)
        # Update mod dict
        # We rely on gather / scatter for batched operations
        top_pos = torch.gather(mod_pos, -1, top_indices)  # (B, num_select)
        mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
        mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
        mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))

        return mod_dict

    def guided_maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
                                    conditioning=[], guidance_scale=1.0, seed=None, write_all_predictions=False):
        """Perform one MaskGIT step with classifier-free guidance.

        Args:
            mod_dict: Modality dictionary; updated and returned.
            target_mod: Name of the modality being decoded.
            num_select: Number of tokens to commit in this step.
            temperature, top_k, top_p: Sampling parameters (see `sample_tokens`).
            conditioning: Names of the modalities that are emptied for the unconditional
                pass. Only iterated, never modified.
            guidance_scale: Weight applied to (conditional - unconditional) logits.
            seed: Optional seed forwarded to both forward passes.
            write_all_predictions: If True, the sampled token of every remaining target
                position is written into the tensor (the masks are still only updated for
                the `num_select` committed positions).

        Returns:
            The updated `mod_dict`.
        """
        ### 1 - First pass, with conditioning
        logits_cond, _ = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)

        ### 2 - Second pass, without conditioning
        mod_dict_uncond = copy.deepcopy(mod_dict)
        for mod in conditioning:
            if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
                mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
            elif self.model.modality_info[mod]['type'] in ['seq_emb']:
                mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
            else:
                mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
        logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict_uncond, target_mod, seed=seed)

        ### 3 - Classifier-free guidance
        logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale

        ### 4 - MaskGIT sampling
        top_samples, top_indices, all_samples = self.select_tokens_batched(
            logits, num_select,
            temperature=temperature, top_k=top_k, top_p=top_p,
            return_all_samples=True
        )

        ### 5 - Update mod dict
        # We rely on gather / scatter for batched operations
        top_pos = torch.gather(mod_pos, -1, top_indices)  # (B, num_select)
        if write_all_predictions:
            # Per-row scatter: `tensor[:, mod_pos] = all_samples` would broadcast across the
            # batch and write every sample's predictions into every row when B > 1.
            target_tensor = mod_dict[target_mod]['tensor']
            mod_dict[target_mod]['tensor'] = torch.scatter(target_tensor, -1, mod_pos, all_samples.to(target_tensor.dtype))
        else:
            mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
        mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
        mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))

        return mod_dict

    def multi_guided_maskgit_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, num_select,
                                          temperature, top_k, top_p, seed=None, write_all_predictions=False):
        """Perform one MaskGIT step guided by several weighted conditions.

        Args:
            uncond_dict: Modality dictionary for the unconditional pass.
            cond_dicts: List of modality dictionaries, one per guided condition.
            cond_weights: Guidance weight of each entry in `cond_dicts`.
            target_mod: Name of the modality being decoded.
            num_select: Number of tokens to commit in this step.
            temperature, top_k, top_p: Sampling parameters (see `sample_tokens`).
            seed: Optional seed forwarded to every forward pass.
            write_all_predictions: If True, `uncond_dict` receives the sampled token of every
                remaining target position; the conditional dicts always receive only the
                committed tokens.

        Returns:
            (uncond_dict, cond_dicts), all updated with the newly committed tokens.
        """
        ### 1 - Conditional forward passes (one for each guided condition)
        logits_cond_all = []
        for cond_dict in cond_dicts:
            logits_cond_i, _ = self.forward_enc_dec_maskgit_batched(cond_dict, target_mod, seed=seed)
            logits_cond_all.append(logits_cond_i)

        ### 2 - Unconditional forward pass
        logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(uncond_dict, target_mod, seed=seed)

        ### 3 Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)}
        # See https://arxiv.org/abs/2206.01714
        logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0)

        ### 4 - MaskGIT sampling
        top_samples, top_indices, all_samples = self.select_tokens_batched(
            logits, num_select,
            temperature=temperature, top_k=top_k, top_p=top_p,
            return_all_samples=True
        )

        ### 5 - Update mod dict with newly generated tokens
        # We rely on gather / scatter for batched operations
        top_pos = torch.gather(mod_pos, -1, top_indices)  # (B, num_select)
        if write_all_predictions:
            # Per-row scatter: `tensor[:, mod_pos] = all_samples` would broadcast across the
            # batch and write every sample's predictions into every row when B > 1.
            target_tensor = uncond_dict[target_mod]['tensor']
            uncond_dict[target_mod]['tensor'] = torch.scatter(target_tensor, -1, mod_pos, all_samples.to(target_tensor.dtype))
        else:
            uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, top_pos, top_samples)
        uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
        uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
        # Update conditioning dicts
        for i in range(len(cond_dicts)):
            cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, top_pos, top_samples)
            cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
            cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))

        return uncond_dict, cond_dicts

    def forward_enc_dec_roar_batched(self, mod_dict, target_mod, num_select, seed=None):
        """Run one encoder + decoder pass for ROAR and return logits for `num_select` random targets.

        Returns:
            (logits, mod_pos): logits reshaped to (B, N, -1) for the N selected target
            tokens, and the original position of each of those tokens.
        """
        # Encoder
        encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
                            for mod, d in mod_dict.items()
                            if mod in self.model.encoder_embeddings}
        encoder_tokens, encoder_emb, encoder_mask, _ = self.forward_mask_encoder_generation(encoder_mod_dict)
        x = encoder_tokens + encoder_emb
        x = self.model.forward_encoder(x, encoder_mask)

        # Decoder
        context = self.model.decoder_proj_context(x) + encoder_emb
        decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
        decoder_tokens, decoder_emb, _, decoder_mod_mask, mod_pos = self.forward_mask_decoder_roar(decoder_mod_dict, target_mod, num_select, seed=seed)
        y = decoder_tokens + decoder_emb
        y = self.model.forward_decoder(y, context, encoder_mask, None)
        B, N, _ = y.shape
        logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
        logits = logits.reshape(B, N, -1)

        return logits, mod_pos

    def roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
        """ROAR = Random Order Autoregression

        Performs one unguided ROAR step: samples a token for each of the (up to)
        `num_select` randomly chosen target positions and commits all of them.

        Returns:
            The updated `mod_dict`.
        """
        logits, mod_pos = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)

        # Simple sampling
        samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)

        # Update mod dict
        # We rely on scatter for batched operations
        select_pos = mod_pos
        mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples)
        mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
        mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))

        return mod_dict

    def guided_roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
                                 conditioning=[], guidance_scale=1.0, seed=None):
        """ROAR = Random Order Autoregression"""

        ### 1 - First pass, with conditioning
        logits_cond, _ = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)

        ### 2 - Second pass, without conditioning
        mod_dict_uncond = copy.deepcopy(mod_dict)
        for mod in conditioning:
            if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
                mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
            elif self.model.modality_info[mod]['type'] in ['seq_emb']:
                mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
            else:
                mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
        logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(mod_dict_uncond, target_mod, num_select, seed=seed)

        ### 3 - Classifier-free guidance
        logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale

        ### 4 - Simple sampling
        samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)

        ### 5
- Update mod dict # We rely on gather / scatter for batched operations select_pos = mod_pos mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples) mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) return mod_dict def multi_guided_roar_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temperature, top_k, top_p, seed=None): ### 1 - Conditional forward passes (one for each guided condition) logits_cond_all = [] for cond_dict in cond_dicts: logits_cond_i, _ = self.forward_enc_dec_roar_batched(cond_dict, target_mod, num_select, seed=seed) logits_cond_all.append(logits_cond_i) ### 2 - Unconditional forward pass logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(uncond_dict, target_mod, num_select, seed=seed) ### 3 Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)} # See https://arxiv.org/abs/2206.01714 logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0) ### 4 - Simple sampling samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p) ### 5 - Update mod dict # We rely on gather / scatter for batched operations select_pos = mod_pos uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, select_pos, samples) uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) # Update conditioning dicts for i in 
range(len(cond_dicts)): cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, select_pos, samples) cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) return uncond_dict, cond_dicts def autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float, use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, seed=None): # Encoder encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) for mod, d in mod_dict.items() if mod in self.model.encoder_embeddings} encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) x = encoder_tokens + encoder_emb x = self.model.forward_encoder(x, encoder_mask) # B, N, D # Get batch size B = x.shape[0] # Decoder context = self.model.decoder_proj_context(x) + encoder_emb decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict, target_mod, seed=seed) device = decoder_ids.device seq_len = self.model.modality_info[target_mod]['max_tokens'] if use_eos and eos_token is None: # The eos_token is the final sentinel token provided eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all if use_eos: eos_token = eos_token.to(device) # If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token) out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device) # Set decoder_tokens to None, we do not use them for decoding decoder_ids = None # If all samples of the batch have 
eos, return early if use_eos and (out == eos_token).any(dim=-1).all(): return out y_emb = decoder_emb[:, :seq_len] seq_len = y_emb.shape[1] # Auto-regressive decoding and sampling for i in range(seq_len): cur_len = out.shape[1] # Convert ids into word embeddings and add corresponding posembs + modemb y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len] # Build causal mask causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1) causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B) y = self.model.forward_decoder(y, context, encoder_mask, causal_mask) logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask[:, :cur_len])[target_mod] logits = rearrange(logits, "(b n) d -> b n d", b=B, n=cur_len) last_logits = logits[:, -1] # Sample token for the newly generated logit if np.isclose(temperature, 0, atol=1e-10): sample = torch.argmax(last_logits, dim=-1, keepdim=True) else: filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) if use_eos and (out == eos_token).any(dim=-1).all(): break mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer) return mod_dict def guided_autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float, use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, conditioning=[], guidance_scale=1.0, seed=None): ### 1 - Encoder forward pass, with conditioning # Encoder encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) for mod, d in mod_dict.items() if mod in self.model.encoder_embeddings} encoder_tokens, encoder_emb, encoder_mask_cond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) x = encoder_tokens + encoder_emb x = self.model.forward_encoder(x, encoder_mask_cond) # B, N, D # Get batch size B = 
x.shape[0] # Decoder context_cond = self.model.decoder_proj_context(x) + encoder_emb decoder_mod_dict_cond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_cond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_cond, target_mod, seed=seed) device = decoder_ids.device seq_len = self.model.modality_info[target_mod]['max_tokens'] ### 2 - Encoder forward pass, without conditioning mod_dict_uncond = copy.deepcopy(mod_dict) for mod in conditioning: if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']: mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod) elif self.model.modality_info[mod]['type'] in ['seq_emb']: mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod) else: mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod) # Encoder encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) for mod, d in mod_dict_uncond.items() if mod in self.model.encoder_embeddings} encoder_tokens, encoder_emb, encoder_mask_uncond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) x = encoder_tokens + encoder_emb x = self.model.forward_encoder(x, encoder_mask_uncond) # B, N, D # Decoder context_uncond = self.model.decoder_proj_context(x) + encoder_emb decoder_mod_dict_uncond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_uncond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_uncond, target_mod, seed=seed) if use_eos and eos_token is None: # The eos_token is the final sentinel token provided eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all if use_eos: eos_token = eos_token.to(device) # If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token) out = decoder_ids[:, :1] if start_tokens is None else 
start_tokens.to(device) # Set decoder_tokens to None, we do not use them for decoding decoder_ids = None # If all samples of the batch have eos, return early if use_eos and (out == eos_token).any(dim=-1).all(): return out y_emb = decoder_emb[:, :seq_len] seq_len = y_emb.shape[1] ### 3 - Auto-regressive decoding and sampling for i in range(seq_len): cur_len = out.shape[1] # Convert ids into word embeddings and add corresponding posembs + modemb y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len] # Build causal mask causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1) causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B) ### 3a - Decoder forward pass, with conditioning y_cond = self.model.forward_decoder(y, context_cond, encoder_mask_cond, causal_mask) logits_cond = self.model.forward_logits(y_cond, decoder_mod_dict_cond, decoder_mod_mask_cond[:, :cur_len])[target_mod] logits_cond = rearrange(logits_cond, "(b n) d -> b n d", b=B, n=cur_len) last_logits_cond = logits_cond[:, -1] ### 3b - Decoder forward pass, without conditioning y_uncond = self.model.forward_decoder(y, context_uncond, encoder_mask_uncond, causal_mask) logits_uncond = self.model.forward_logits(y_uncond, decoder_mod_dict_uncond, decoder_mod_mask_uncond[:, :cur_len])[target_mod] logits_uncond = rearrange(logits_uncond, "(b n) d -> b n d", b=B, n=cur_len) last_logits_uncond = logits_uncond[:, -1] ### 3c - Classifier-free guidance last_logits = last_logits_uncond + (last_logits_cond - last_logits_uncond) * guidance_scale # Sample token for the newly generated logit if np.isclose(temperature, 0, atol=1e-10): sample = torch.argmax(last_logits, dim=-1, keepdim=True) else: filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) if use_eos and (out == eos_token).any(dim=-1).all(): break mod_dict 
= self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer) return mod_dict @torch.no_grad() def generate(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None): """ Generates a sequence of tokens from the input modalities. :param mod_dict: Dictionary of modalities. :param schedule: Schedule of modalities to use. List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}. :param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering). :param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering). :param text_tokenizer: Text tokenizer. :param verbose: Whether to print progress. :param seed: Random seed. :return: Generated mod dict. """ # Input embedding -> tokenizes the modalities - Many are placeholder for now mod_dict = copy.deepcopy(mod_dict) for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): target_mod = schedule_step_info['target_domain'] temp = schedule_step_info['temperature'] cfg_scale = schedule_step_info.get('cfg_scale', 1.0) cfg_conditioning = schedule_step_info.get('cfg_cond_domains', []) seed_i = seed + step if seed is not None else None if self.model.modality_info[target_mod]['type'] == 'img': scheme = schedule_step_info['scheme'] num_select = schedule_step_info['num_tokens'] if scheme.lower() == 'maskgit': if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.maskgit_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, seed=seed_i ) else: mod_dict = self.guided_maskgit_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i ) elif scheme.lower() == 'roar': if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.roar_step_batched( mod_dict, target_mod, num_select, 
temperature=temp, top_k=top_k, top_p=top_p, seed=seed_i ) else: mod_dict = self.guided_roar_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i ) else: raise ValueError("Invalid sampling scheme") elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']: if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.autoregressive_step_batched( mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, text_tokenizer=text_tokenizer, seed=seed_i ) else: mod_dict = self.guided_autoregressive_step_batched( mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, text_tokenizer=text_tokenizer, conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i ) else: raise ValueError("Invalid schedule") return mod_dict @torch.no_grad() def generate_iter(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None): """ Iterator that generates a sequence of tokens from the input modalities step by step. :param mod_dict: Dictionary of modalities. :param schedule: Schedule of modalities to use. List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}. :param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering). :param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering). :param text_tokenizer: Text tokenizer. :param verbose: Whether to print progress. :param seed: Random seed. :return: Iterator of generated mod dict. 
""" # Input embedding -> tokenizes the modalities - Many are placeholder for now mod_dict = copy.deepcopy(mod_dict) for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): target_mod = schedule_step_info['target_domain'] temp = schedule_step_info['temperature'] cfg_scale = schedule_step_info.get('cfg_scale', 1.0) cfg_conditioning = schedule_step_info.get('cfg_cond_domains', []) seed_i = seed + step if seed is not None else None if self.model.modality_info[target_mod]['type'] == 'img': scheme = schedule_step_info['scheme'] num_select = schedule_step_info['num_tokens'] if scheme.lower() == 'maskgit': if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.maskgit_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, seed=seed_i ) else: mod_dict = self.guided_maskgit_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i, write_all_predictions=True ) elif scheme.lower() == 'roar': if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.roar_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, seed=seed_i ) else: mod_dict = self.guided_roar_step_batched( mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i ) else: raise ValueError("Invalid sampling scheme") elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']: if cfg_scale == 1.0 or len(cfg_conditioning) == 0: mod_dict = self.autoregressive_step_batched( mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, text_tokenizer=text_tokenizer, seed=seed_i ) else: mod_dict = self.guided_autoregressive_step_batched( mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, text_tokenizer=text_tokenizer, conditioning=cfg_conditioning, guidance_scale=cfg_scale, 
seed=seed_i ) else: raise ValueError("Invalid schedule") yield mod_dict @torch.no_grad() def generate_multi_guided(self, uncond_dict, cond_dicts, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None): # Generation function for multiple weighted conditions # To detect when a modality has finished generating, we keep track of the current target modality cur_target_mod = schedule[0]['target_domain'] uncond_dict = copy.deepcopy(uncond_dict) cond_dicts = copy.deepcopy(cond_dicts) # Add the to-be-generated modality to the conditional dicts for i in range(len(cond_dicts)): cond_dicts[i][cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod]) for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): target_mod = schedule_step_info['target_domain'] temp = schedule_step_info['temperature'] num_select = schedule_step_info['num_tokens'] cond_weights = schedule_step_info['cfg_scale'] # Once a modality is fully generated, add it as a new condition if cur_target_mod != target_mod: for i in range(len(cond_dicts)): # Remove the previously generated modality from the conditionings del cond_dicts[i][cur_target_mod] # Add the next modality to be generated to the conditionings cond_dicts[i][target_mod] = copy.deepcopy(uncond_dict[target_mod]) # Remove the fully generated modality from the unconditional dict inputs uncond_dict[cur_target_mod]['input_mask'][:] = True # Add the previously generated modality as an additional condition new_cond = {} new_cond[cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod]) new_cond[cur_target_mod]['input_mask'][:] = False new_cond[cur_target_mod]['target_mask'][:] = True new_cond[target_mod] = copy.deepcopy(uncond_dict[target_mod]) cond_dicts.append(new_cond) cur_target_mod = target_mod if self.model.modality_info[target_mod]['type'] == 'img': scheme = schedule_step_info['scheme'] if scheme.lower() == 'maskgit': uncond_dict, cond_dicts = self.multi_guided_maskgit_step_batched( uncond_dict, 
cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed ) elif scheme.lower() == 'roar': uncond_dict, cond_dicts = self.multi_guided_roar_step_batched( uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed ) else: raise ValueError("Invalid sampling scheme") else: raise NotImplementedError("Only image modalities are supported for now") return uncond_dict @torch.no_grad() def generate_sam_dense(self, mod_dict, schedule, text_tokenizer, batch_size=16, key='sam_instance', top_k=0.0, top_p=0.0, seed=None, verbose=False): # Generation function for dense SAM instance prediction device = mod_dict[list(mod_dict.keys())[0]]['tensor'].device mod_dict = copy.deepcopy(mod_dict) # Repeat the input batch to match the batch size expanded_batch = expand_to_batch(copy.deepcopy(mod_dict), batch_size=batch_size) # Filter the schedule to only include the key domain schedule = [s for s in schedule if s['target_domain'] == key] out_dict = self.generate( expanded_batch, schedule, text_tokenizer=text_tokenizer, verbose=verbose, seed=seed, top_p=top_p, top_k=top_k, ) # Merge the batch generated sequences into one sequence sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values()) merged_seq = [] for i in range(batch_size): input_seq = out_dict[key]['tensor'][i] input_seq = input_seq[out_dict[key]['input_mask'][i] == 0] input_seq = input_seq.tolist() target_seq = out_dict[key]['tensor'][i] target_seq = target_seq[out_dict[key]['target_mask'][i] == 0] target_seq = target_seq.tolist() merged_seq.extend(merge_span_masking(input_seq, target_seq, sentinel_ids=sentinel_ids)) merged_seq = torch.tensor(merged_seq, device=device).unsqueeze(0) mod_dict[key] = { 'tensor': merged_seq, 'input_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device), 'target_mask': torch.ones(merged_seq.shape, dtype=torch.bool, device=device), 'decoder_attention_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device), } 
return mod_dict