"""Generate images from a text caption with a RoBERTa text encoder and a conditional GAN generator."""
import functools

from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable  # noqa: F401  (no longer used by this script; kept for importers)
from PIL import Image
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"


def clean(txt):
    """Normalize a caption: lowercase, strip surrounding whitespace, then strip leading/trailing periods."""
    txt = txt.lower()
    txt = txt.strip()
    txt = txt.strip('.')
    return txt


# Number of tokens every caption is padded/truncated to before encoding.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenize ``txt`` padded/truncated to exactly ``max_len`` tokens (unbatched Python lists)."""
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False
    )


def encode(model_name, model, tokenizer, txt):
    """Return the encoder's hidden state for the first token of the cleaned caption as a NumPy array.

    ``model_name`` is accepted for interface compatibility but is not used.
    """
    txt = clean(txt)
    txt_tokenized = tokenize(tokenizer, txt)
    # Add a leading batch dimension of 1 and move every input onto the encoder's device.
    inputs = {
        k: torch.tensor(v, dtype=torch.long, device=device)[None]
        for k, v in txt_tokenized.items()
    }
    model.eval()
    with torch.no_grad():
        encoded = model(**inputs)
    # last_hidden_state is (batch, seq, hidden); take the first token of the single batch item.
    return encoded.last_hidden_state[0, 0].cpu().numpy()


model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)


@functools.lru_cache(maxsize=None)
def _load_generator(checkpoint):
    """Build the GAN generator, load its weights from ``checkpoint``, and cache it per path."""
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    # Deserialize onto CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict then copies the tensors into the parameters already on ``device``.
    generator.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')))
    generator.eval()
    return generator


def generate_image(text, n, checkpoint="./gen_125.pth"):
    """Generate ``n`` PIL images conditioned on the caption ``text``.

    Args:
        text: caption to condition on.
        n: number of generator calls (one fresh noise vector each).
        checkpoint: path of the generator state dict; the loaded generator is
            cached, so repeated calls do not re-read the file.

    Returns:
        A list of ``PIL.Image`` objects (empty when ``n`` is 0).
    """
    embed = encode(model_name, model, tokenizer, text)
    generator = _load_generator(checkpoint)
    # Shape (1, hidden): a batch of one caption embedding.
    right_embed = torch.as_tensor(embed, dtype=torch.float32).unsqueeze(0).to(device)

    images = []
    # Inference only: no autograd graph is needed for the generator forward pass.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Noise is drawn on CPU (same RNG stream as before) and reshaped to (1, 100, 1, 1).
            noise = torch.randn(1, 100).to(device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # Map generator output from [-1, 1] to [0, 255]; clamp so the uint8 cast cannot wrap.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # CHW -> HWC for PIL.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images


if __name__ == '__main__':
    n = 10
    imgs = generate_image('Red images', n)

    # Lay the images out in a 2-column grid sized to however many were generated.
    ncols = 2
    nrows = max(1, (len(imgs) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for axis in axes.flat:
        axis.axis('off')
    for axis, img in zip(axes.flat, imgs):
        axis.imshow(img)
    fig.tight_layout()
    plt.show()

    # while True:
    #     print('Type Caption: ')
    #     txt = input()
    #     print('Generating images...')
    #     generate_image(txt, n)
    #     print('Completed')