# model/layout_generator.py
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file aims to predict the layout of keywords in user prompts.
# ------------------------------------------
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import re
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from PIL import Image, ImageDraw, ImageFont
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic
from model.layout_transformer import LayoutTransformer, TextConditioner
from termcolor import colored

# load the layout Transformer and its pretrained checkpoint
model = LayoutTransformer().cuda().eval()
model.load_state_dict(torch.load('textdiffuser-ckpt/layout_transformer.pth'))

# load the text encoder and the CLIP tokenizer
text_encoder = TextConditioner().cuda().eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
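
# note: both modules are placed on the default CUDA device, and the checkpoint
# path 'textdiffuser-ckpt/layout_transformer.pth' is resolved relative to the
# current working directory, so this module is expected to run from the
# repository root.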


def process_caption(font_path, caption, keywords):
    # keep only letters and digits, replacing punctuation with spaces;
    # remove this statement if punctuation should be painted as well
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)

    # tokenize the caption into ids and get the number of valid tokens
    caption_words = tokenizer([caption], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    caption_words_ids = caption_words['input_ids'] # (1, 77)
    length = caption_words['length'] # (1, )

    # convert the ids back to token strings and strip the '</w>' end-of-word markers
    words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
    words = [i.replace('</w>', '') for i in words]
    words_valid = words[:int(length)]
    # store the box coordinates and state of each token
    info_array = np.zeros((77, 5)) # (77, 5)

    # split the caption into words and lower-case them
    caption_split = caption.split()
    caption_split = [i.lower() for i in caption_split]

    start_dic = {} # maps each word index to the index of its first token
    state_list = [] # 0: start of a word, 1: middle of a word, 2: special token
    word_match_list = [] # the index of the caption word each token belongs to
    current_caption_index = 0
    current_match = ''
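
    # the loop below greedily re-assembles sub-word tokens into caption words:
    # tokens are concatenated into current_match until they spell out the next
    # word of the caption, which assigns every token position a word index.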
    for i in range(int(length)):
        # the first and the last token are special tokens (start/end of text)
        if i == 0 or i == int(length) - 1:
            state_list.append(2)
            word_match_list.append(127)
            continue
        if current_match == '':
            state_list.append(0)
            start_dic[current_caption_index] = i
        else:
            state_list.append(1)
        current_match += words_valid[i]
        word_match_list.append(current_caption_index)
        if current_match == caption_split[current_caption_index]:
            current_match = ''
            current_caption_index += 1

    # pad both lists to the full sequence length with the sentinel value 127
    while len(state_list) < 77:
        state_list.append(127)
    while len(word_match_list) < 77:
        word_match_list.append(127)
    # character length and rendered pixel width of the word each token belongs to
    length_list = []
    width_list = []
    for i in range(len(word_match_list)):
        if word_match_list[i] == 127:
            # special or padding token: no associated word
            length_list.append(0)
            width_list.append(0)
        else:
            length_list.append(len(caption.split()[word_match_list[i]]))
            width_list.append(get_width(font_path, caption.split()[word_match_list[i]]))

    # defensive padding: both lists already have length 77 at this point
    while len(length_list) < 77:
        length_list.append(127)
        width_list.append(0)

    length_list = torch.Tensor(length_list).long() # (77, )
    width_list = torch.Tensor(width_list).long() # (77, )
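
    # note: the layout Transformer consumes these per-token lengths and widths,
    # presumably so the predicted boxes can be sized to fit the rendered keywords.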
    boxes = []
    duplicate_dict = {} # tracks keywords that appear more than once in the caption
    for keyword in keywords:
        keyword = keyword.lower()
        if keyword in caption_split:
            if keyword not in duplicate_dict:
                # first occurrence of this keyword
                duplicate_dict[keyword] = caption_split.index(keyword)
                index = caption_split.index(keyword)
            else:
                # find the next occurrence after the one used previously
                # (the slice starts at duplicate_dict[keyword]+1, so the offset
                # must be shifted by that amount to get an absolute index)
                if duplicate_dict[keyword] + 1 < len(caption_split) and keyword in caption_split[duplicate_dict[keyword]+1:]:
                    index = duplicate_dict[keyword] + 1 + caption_split[duplicate_dict[keyword]+1:].index(keyword)
                    duplicate_dict[keyword] = index
                else:
                    continue

            # flag the first token of the keyword and attach a placeholder box
            if index not in start_dic: # the word fell beyond the 77-token window
                continue
            index = start_dic[index]
            info_array[index][0] = 1
            box = [0, 0, 0, 0]
            boxes.append(list(box))
            info_array[index][1:] = box

    # the layout model predicts at most 8 boxes: truncate or pad accordingly
    boxes_length = len(boxes)
    if boxes_length > 8:
        boxes = boxes[:8]
        boxes_length = 8
    while len(boxes) < 8:
        boxes.append([0, 0, 0, 0])

    return caption, length_list, width_list, torch.from_numpy(info_array), words, torch.Tensor(state_list).long(), torch.Tensor(word_match_list).long(), torch.Tensor(boxes), boxes_length
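
# the tuple returned by process_caption, in order: the cleaned caption, per-token
# word lengths (77,), per-token rendered widths (77,), the keyword indicator/box
# array (77, 5), the raw token strings, the token states (77,), the token-to-word
# indices (77,), the padded target boxes (8, 4), and the number of keyword boxes.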


def get_layout_from_prompt(args):
    font_path = args.font_path
    keywords = get_key_words(args.prompt)
    print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')

    text_embedding, mask = text_encoder(args.prompt) # (1, 77, 768) / (1, 77)

    # process all the relevant information of the caption
    caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, args.prompt, keywords)
    target = target.cuda().unsqueeze(0) # (1, 77, 5)
    width_list = width_list.cuda().unsqueeze(0) # (1, 77)
    length_list = length_list.cuda().unsqueeze(0) # (1, 77)
    state_list = state_list.cuda().unsqueeze(0) # (1, 77)
    word_match_list = word_match_list.cuda().unsqueeze(0) # (1, 77)

    # right-shift the target boxes by one position for autoregressive decoding
    padding = torch.zeros(1, 1, 4).cuda()
    boxes = boxes.unsqueeze(0).cuda()
    right_shifted_boxes = torch.cat([padding, boxes[:, 0:-1, :]], 1) # (1, 8, 4)
    # inference: decode the keyword boxes one step at a time
    return_boxes = []
    with torch.no_grad():
        for box_index in range(boxes_length):
            if box_index == 0:
                encoder_embedding = None
            output, encoder_embedding = model(text_embedding, length_list, width_list, mask, state_list, word_match_list, target, right_shifted_boxes, train=False, encoder_embedding=encoder_embedding)
            output = torch.clamp(output, min=0, max=1) # (1, 8, 4)

            # adjust the current box to avoid overlap with earlier ones
            output = adjust_overlap_box(output, box_index) # (1, 8, 4)

            # autoregressive decoding: feed the prediction back as the input for
            # the next step (the guard avoids indexing past the last position)
            if box_index + 1 < right_shifted_boxes.shape[1]:
                right_shifted_boxes[:, box_index + 1, :] = output[:, box_index, :]
            xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
            return_boxes.append([xmin, ymin, xmax, ymax])
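
    # the predicted coordinates are normalized to [0, 1]; everything below maps
    # them onto the 512 x 512 canvas.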
    # print the predicted location of each keyword; this assumes every keyword
    # was matched in the caption (a skipped keyword would shift the alignment
    # between keywords and boxes)
    print('index\tkeyword\tx_min\ty_min\tx_max\ty_max')
    for index, box in enumerate(return_boxes):
        x_min = int(box[0] * 512)
        y_min = int(box[1] * 512)
        x_max = int(box[2] * 512)
        y_max = int(box[3] * 512)
        print(f'{index}\t{keywords[index]}\t{x_min}\t{y_min}\t{x_max}\t{y_max}')
    # paint the layout: render every keyword inside its predicted box
    render_image = Image.new('RGB', (512, 512), (255, 255, 255))
    draw = ImageDraw.Draw(render_image)
    segmentation_mask = Image.new('L', (512, 512), 0)
    segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)

    for index, box in enumerate(return_boxes):
        box = [int(i * 512) for i in box] # de-normalize to pixel coordinates
        xmin, ymin, xmax, ymax = box
        width = xmax - xmin
        height = ymax - ymin
        text = keywords[index]

        # choose a font size that fits the box (see util.adjust_font_size)
        font_size = adjust_font_size(args, width, height, draw, text)
        font = ImageFont.truetype(args.font_path, font_size)

        # draw.rectangle([xmin, ymin, xmax, ymax], outline=(255, 0, 0)) # uncomment to visualize the boxes
        draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))
        # paint character-level segmentation masks; the getsize-based measuring
        # follows https://github.com/python-pillow/Pillow/issues/3921 and needs
        # Pillow < 10 (getsize was removed in Pillow 10; see getbbox/getlength)
        char_boxes = []
        for i, char in enumerate(text):
            bottom_1 = font.getsize(text[i])[1]
            right, bottom_2 = font.getsize(text[:i+1])
            bottom = min(bottom_1, bottom_2)
            width, height = font.getmask(char).size
            right += xmin
            bottom += ymin
            top = bottom - height
            left = right - width
            char_box = (left, top, right, bottom)
            char_boxes.append(char_box)

            # each character class is written into the mask with its own integer label
            char_index = alphabet_dic[char]
            segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor=0.9), fill=char_index)
print(f'{colored("[√]", "green")} Layout is successfully generated')
return render_image, segmentation_mask
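

if __name__ == '__main__':
    # minimal usage sketch (an assumption, not part of the original module):
    # get_layout_from_prompt reads args.prompt and args.font_path and forwards
    # args to adjust_font_size, so a simple namespace is enough here.
    from types import SimpleNamespace

    args = SimpleNamespace(
        prompt='a poster of the movie "hello world"',
        font_path='assets/font/Arial.ttf', # hypothetical path; point at any .ttf file
    )
    render_image, segmentation_mask = get_layout_from_prompt(args)
    render_image.save('layout.png')
    segmentation_mask.save('segmentation_mask.png')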