# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines the Layout Transformer for predicting the layout of keywords.
# ------------------------------------------

import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class TextConditioner(nn.Module):

    def __init__(self):
        super(TextConditioner, self).__init__()
        self.transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

        # freeze the CLIP text encoder
        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def forward(self, prompt_list):
        batch_encoding = self.tokenizer(
            prompt_list,
            truncation=True,
            max_length=77,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        text_embedding = self.transformer(batch_encoding["input_ids"].cuda())
        # (batch, 77, 768) token embeddings / (batch, 77) attention mask
        return text_embedding.last_hidden_state.cuda(), batch_encoding["attention_mask"].cuda()


class LayoutTransformer(nn.Module):

    def __init__(self, layer_number=2):
        super(LayoutTransformer, self).__init__()

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer = nn.TransformerEncoder(
            self.encoder_layer, num_layers=layer_number
        )
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        self.decoder_transformer = nn.TransformerDecoder(
            self.decoder_layer, num_layers=layer_number
        )

        self.mask_embedding = nn.Embedding(2, 512)
        self.length_embedding = nn.Embedding(256, 512)
        self.width_embedding = nn.Embedding(256, 512)
        self.position_embedding = nn.Embedding(256, 512)
        self.state_embedding = nn.Embedding(256, 512)
        self.match_embedding = nn.Embedding(256, 512)

        # embeddings for the quantized box coordinates (x, y, w, h) in [0, 512)
        self.x_embedding = nn.Embedding(512, 512)
        self.y_embedding = nn.Embedding(512, 512)
        self.w_embedding = nn.Embedding(512, 512)
        self.h_embedding = nn.Embedding(512, 512)
        self.encoder_target_embedding = nn.Embedding(256, 512)

        self.input_layer = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )
        self.output_layer = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

    def forward(self, x, length, width, mask, state, match, target, right_shifted_boxes, train=False, encoder_embedding=None):

        # detect whether the encoder embedding is cached
        if encoder_embedding is None:
            # augmentation: jitter the keyword widths during training
            if train:
                width = width + torch.randint(-3, 3, (width.shape[0], width.shape[1])).cuda()

            x = self.input_layer(x)  # (1, 77, 512)
            width_embedding = self.width_embedding(torch.clamp(width, 0, 255).long())  # (1, 77, 512)
            encoder_target_embedding = self.encoder_target_embedding(target[:, :, 0].long())  # (1, 77, 512)
            pe_embedding = self.position_embedding(torch.arange(77).cuda()).unsqueeze(0)  # (1, 77, 512)

            # combine all the embeddings (1, 77, 512)
            total_embedding = x + width_embedding + pe_embedding + encoder_target_embedding
            total_embedding = total_embedding.permute(1, 0, 2)  # (77, 1, 512)
            encoder_embedding = self.transformer(total_embedding)  # (77, 1, 512)

        # quantize the right-shifted boxes to integer coordinates in [0, 511]
        right_shifted_boxes_resize = (right_shifted_boxes * 512).long()  # (1, 8, 4)
        right_shifted_boxes_resize = torch.clamp(right_shifted_boxes_resize, 0, 511)  # (1, 8, 4)

        # decoder positional embedding
        pe_decoder = torch.arange(8).cuda()  # (8,)
        pe_embedding_decoder = self.position_embedding(pe_decoder).unsqueeze(0)  # (1, 8, 512)
        decoder_input = pe_embedding_decoder \
            + self.x_embedding(right_shifted_boxes_resize[:, :, 0]) \
            + self.y_embedding(right_shifted_boxes_resize[:, :, 1]) \
            + self.w_embedding(right_shifted_boxes_resize[:, :, 2]) \
            + self.h_embedding(right_shifted_boxes_resize[:, :, 3])  # (1, 8, 512)
        decoder_input = decoder_input.permute(1, 0, 2)  # (8, 1, 512)

        # generate a triangular (causal) mask for autoregressive decoding;
        # use a local name so the unused `mask` argument is not shadowed
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(8).cuda()  # (8, 8)

        decoder_result = self.decoder_transformer(decoder_input, encoder_embedding, tgt_mask=tgt_mask)  # (8, 1, 512)
        decoder_result = decoder_result.permute(1, 0, 2)  # (1, 8, 512)
        box_prediction = self.output_layer(decoder_result)  # (1, 8, 4)

        return box_prediction, encoder_embedding
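

# ------------------------------------------
# Usage sketch (not part of the original file): a hedged example of how
# TextConditioner and LayoutTransformer might be wired together for one
# decoding step. The prompt string, the all-zero right-shifted boxes, and
# the dummy width/target inputs are illustrative assumptions only; the
# real pipeline fills these from tokenized keywords and decoded boxes.
# ------------------------------------------
if __name__ == '__main__':
    conditioner = TextConditioner().cuda()
    layout_model = LayoutTransformer().cuda()

    # encode one prompt into (1, 77, 768) CLIP token embeddings
    text_embedding, attention_mask = conditioner(["a poster that says 'hello world'"])

    batch, seq = text_embedding.shape[0], text_embedding.shape[1]
    width = torch.zeros(batch, seq).cuda()                 # per-token keyword widths (assumed)
    target = torch.zeros(batch, seq, 1).cuda()             # per-token keyword indicator (assumed)
    right_shifted_boxes = torch.zeros(batch, 8, 4).cuda()  # boxes start empty, filled step by step

    # first step computes encoder_embedding; pass it back in on later steps
    # to skip re-encoding (the unused length/mask/state/match slots take None)
    box_prediction, encoder_embedding = layout_model(
        text_embedding, None, width, None, None, None, target,
        right_shifted_boxes, train=False, encoder_embedding=None,
    )
    print(box_prediction.shape)  # (1, 8, 4) normalized (x, y, w, h) predictions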