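# T2I / app.py — uploaded by DataRaptor ("Upload app.py", commit 6159ab9, 7.01 kB).
# (Header recovered from the web file-viewer page this copy was saved from;
# the viewer's buttons and badges were dropped.)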
import streamlit as st
import time
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
# Target device for both models: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases the text, trims surrounding whitespace, then removes any
    leading/trailing full stops, e.g. ``"  A Red Rose. "`` -> ``"a red rose"``.
    Whitespace that becomes exposed once the dots are removed is not trimmed
    again (``"a . "`` -> ``"a "``).
    """
    return txt.lower().strip().strip('.')
# Fixed token length for every caption: inputs are padded or truncated to it.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenise ``txt`` to exactly ``max_len`` tokens.

    ``tokenizer`` is called with ``txt`` as its first positional argument and
    max-length padding plus truncation enabled; offset mappings are not
    requested. Whatever the tokenizer returns is passed straight back.
    """
    options = dict(
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False,
    )
    return tokenizer(txt, **options)
def encode(model, tokenizer, txt):
    """Embed one caption with a transformer encoder.

    Steps:
    1. Normalise the caption with ``clean``.
    2. Tokenise it with ``tokenize`` (fixed length ``max_len``).
    3. Turn every tokenizer field into a ``torch.long`` tensor on ``device``,
       with a leading batch axis of size one.
    4. Run ``model`` in eval mode without gradient tracking.

    Returns:
        NumPy array: row 0 of the squeezed ``last_hidden_state``. With a
        one-caption batch, this is the final-layer vector of the first token
        (RoBERTa's start-of-sequence token, presumably used as a sentence
        embedding).
    """
    batch = {
        field: torch.tensor(values, dtype=torch.long, device=device)[None]
        for field, values in tokenize(tokenizer, clean(txt)).items()
    }
    model.eval()
    with torch.no_grad():
        outputs = model(**batch)
    first_token = outputs.last_hidden_state.squeeze()[0]
    return first_token.cpu().numpy()
@st.cache_resource
def get_model_roberta():
    """Load the ``roberta-base`` encoder and its tokenizer.

    Wrapped in ``st.cache_resource`` so the download/initialisation happens
    once and the same objects are reused across Streamlit reruns. The encoder
    is moved to ``device``.

    Returns:
        ``(model, tokenizer)``.
    """
    checkpoint = 'roberta-base'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    cfg = AutoConfig.from_pretrained(checkpoint, output_hidden_states=True)
    encoder = AutoModel.from_pretrained(checkpoint, config=cfg)
    return encoder.to(device), tok
@st.cache_resource
def get_model_gan():
    """Build the text-conditioned generator and load its trained weights.

    Cached with ``st.cache_resource`` so ``./gen_125.pth`` is read once and
    the same module is reused across reruns. The network is wrapped in
    ``torch.nn.DataParallel`` *before* loading. Presumably the checkpoint was
    saved from a DataParallel model (keys prefixed with ``module.``); verify
    before removing the wrapper.

    Returns:
        The wrapped generator, in eval mode.
    """
    net = gan_cls_768.generator().to(device)
    wrapped = torch.nn.DataParallel(net)
    # Deserialise onto the CPU so loading also works on CPU-only hosts;
    # load_state_dict then copies the values into the model's parameters.
    weights = torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    wrapped.load_state_dict(weights)
    wrapped.eval()
    return wrapped
def generate_image(text, n):
    """Generate flower images conditioned on a text description.

    The description is embedded once with the cached RoBERTa encoder (via
    ``encode``, which also normalises it with ``clean``). The cached GAN
    generator is then called ``n`` times, each time with that embedding and
    a fresh Gaussian noise vector of size 100.

    Args:
        text: Free-form flower description.
        n: Number of generator calls, and so the number of images when each
            call yields a single image (an empty list when ``n <= 0``).

    Returns:
        A list of ``PIL.Image.Image`` objects, one per image produced.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()
    embed = encode(model, tokenizer, text)
    # Add a batch axis of one: shape (1, embedding_size).
    right_embed = torch.FloatTensor(embed).unsqueeze(0).to(device)

    images = []
    # Pure inference: generator.eval() does not stop autograd, so without
    # no_grad every call would build a graph that is never used.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            noise = torch.randn(1, 100).to(device)
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # Map [-1, 1] -> [0, 255]. This assumes a tanh-style output
                # range (TODO: confirm in gan_cls_768). The clamp keeps the
                # uint8 cast from wrapping if the output ever leaves that range.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # Move the first axis (channels, presumably) last, as PIL expects.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
# ---------------------------------------------------------------------------
# Page-level setup, executed at module level before app() draws anything.
# Streamlit expects set_page_config to precede other rendering calls.
# ---------------------------------------------------------------------------
_page_config = {
    "page_title": "ImageGen",
    "page_icon": "🧊",
    "layout": "centered",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_config)

# Raw HTML/CSS that hides Streamlit's main menu, footer and header.
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(hide_st_style, unsafe_allow_html=True)

# Sample captions offered in the "Select from example" dropdown. The chosen one
# pre-fills the description box. They are shown to users verbatim.
examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "this flower has petals that are pink and bell shaped",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "delicated pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "his flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "the petals on this flower are orange with a purple pistil.",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
]
def app():
    """Render the text-to-flower demo page.

    The page has two columns:
    - Left: the caption editor (pre-filled from the selected example), the
      model picker and the Generate button.
    - Right: after a click, a grid of images produced by ``generate_image``.

    Each rendered figure is closed once handed to Streamlit, so repeated
    clicks do not accumulate open matplotlib figures.
    """
    st.title("Text to Flower")
    st.markdown(
        """
**Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
"""
    )
    se = st.selectbox("Select from example", examples)
    row1_col1, row1_col2 = st.columns([2, 3])
    # Results grid shape; exactly enough images are generated to fill it.
    nrows, ncols = 3, 4

    with row1_col1:
        caption = st.text_area("Write your flower description here:", se, height=120)
        # Only one backend exists, so the picker is shown but its value is unused.
        st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
        )
        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, nrows * ncols)
            # Switch the render target to the right-hand column for the results.
            with row1_col2:
                st.markdown("Generated Flower Images:")
                fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
                for axis, img in zip(axes.flatten(), imgs):
                    axis.imshow(img)
                    axis.axis('off')
                fig.tight_layout()
                st.pyplot(fig)
                # pyplot keeps every figure alive until it is closed; without
                # this, a long-running server leaks one figure per click.
                plt.close(fig)
# Streamlit executes this script top to bottom, so calling app() at module
# level is what renders the page.
app()
# NOTE(review): the commented-out footer below credits a Kaggle essay-scoring
# dataset, which looks left over from another app. Confirm before re-enabling.
# # Display a footer with links and credits
#st.markdown("---")
#st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
# #st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")