# T2I / T2I.py
# Uploaded by DataRaptor ("Upload 6 files", commit f8a1225, 2.53 kB).
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases *txt*, trims surrounding whitespace, then trims leading and
    trailing '.' characters. Whitespace exposed by the dot trim is kept.
    """
    return txt.lower().strip().strip('.')
# Fixed token length: every caption is padded or truncated to exactly this many tokens.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenise *txt* to a fixed length of ``max_len`` tokens.

    *tokenizer* is called with padding to ``max_length`` and truncation enabled,
    without offset mappings. Whatever the tokenizer returns is passed through
    unchanged. No ``return_tensors`` is requested, so values are presumably
    plain Python lists.
    """
    settings = {
        'max_length': max_len,
        'padding': 'max_length',
        'truncation': True,
        'return_offsets_mapping': False,
    }
    return tokenizer(txt, **settings)
def encode(model_name, model, tokenizer, txt):
    """Embed a caption as the final-layer hidden state of its first token.

    Args:
        model_name: unused; kept so existing callers keep working.
        model: model whose output exposes ``last_hidden_state``. It is switched
            to eval mode and run without gradient tracking.
        tokenizer: tokenizer passed to ``tokenize``.
        txt: raw caption; normalised with ``clean`` first.

    Returns:
        A 1-D NumPy array: the last hidden state at sequence position 0 of the
        single-item batch. For RoBERTa this is presumably the ``<s>`` token;
        confirm if the tokenizer changes.
    """
    txt = clean(txt)
    # Turn each tokenizer field into a (1, max_len) long tensor on the target device.
    batch = {
        key: torch.tensor(value, dtype=torch.long, device=device)[None]
        for key, value in tokenize(tokenizer, txt).items()
    }
    model.eval()
    with torch.no_grad():
        encoded = model(**batch)
    # Index batch 0, token 0 explicitly. `.squeeze()` would also drop a
    # length-1 sequence axis and turn the result into a scalar.
    return encoded.last_hidden_state[0, 0].cpu().numpy()
# Text encoder: the pretrained checkpoint whose embeddings condition the generator.
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Config requests all hidden states. Only last_hidden_state is read in this file.
_encoder_config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
model = AutoModel.from_pretrained(model_name, config=_encoder_config)
model = model.to(device)
del _encoder_config
def generate_image(text, n, checkpoint="./gen_125.pth"):
    """Sample images from the conditional generator for the caption *text*.

    Args:
        text: caption; embedded with ``encode`` using the module-level encoder.
        n: number of generator forward passes. Each pass uses a batch of one
            noise vector.
        checkpoint: path of the generator state dict. Defaults to the path
            the script always used.

    Returns:
        A list of ``PIL.Image`` objects, one per generated sample, in
        generation order.
    """
    embed = encode(model_name, model, tokenizer, text)

    # Weights are loaded into a DataParallel wrapper. Presumably the checkpoint
    # keys carry the "module." prefix; confirm before unwrapping.
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    generator.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')))
    generator.eval()

    # Conditioning vector with a leading batch axis of 1, on the target device.
    right_embed = torch.tensor(embed, dtype=torch.float32).unsqueeze(0).to(device)

    images = []
    # Inference only: skip building an autograd graph for every sample.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Sample on the CPU RNG and then move, as before, so seeded runs reproduce.
            noise = torch.randn(1, 100).to(device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # x * 127.5 + 127.5 maps [-1, 1] onto [0, 255]. Clamp so any
                # out-of-range value saturates instead of wrapping in the uint8 cast.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # Channels-first -> channels-last for PIL.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
if __name__ == '__main__':
    # Demo: generate n samples for a fixed caption and show them in a grid.
    n = 10
    imgs = generate_image('Red images', n)

    ncols = 2
    nrows = max(1, -(-n // ncols))  # ceil(n / ncols); 5 rows for the default n
    # squeeze=False always yields a 2-D array of axes, so flatten() is safe for any n.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for idx, axis in enumerate(axes.flatten()):
        # Leave any surplus grid cells blank rather than indexing past imgs.
        if idx < len(imgs):
            axis.imshow(imgs[idx])
        axis.axis('off')
    fig.tight_layout()
    plt.show()

    # Interactive mode (disabled). generate_image needs the image count as well.
    # while True:
    #     print('Type Caption: ')
    #     txt = input()
    #     print('Generating images...')
    #     generate_image(txt, n)
    #     print('Completed')