Spaces:
Sleeping
Sleeping
File size: 7,009 Bytes
3419697 6159ab9 3419697 6159ab9 3419697 6159ab9 3419697 6159ab9 3419697 6159ab9 3419697 6159ab9 3419697 c0b5fd1 3419697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import streamlit as st
import time
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from tqdm import tqdm
import gan_cls_768
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
# Torch device used for both the text encoder and the generator.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before it is tokenised.

    Lower-cases ``txt``, trims surrounding whitespace, then trims any leading
    or trailing periods. Whitespace that sat next to a removed period is kept,
    e.g. ``"a ."`` becomes ``"a "``.
    """
    return txt.lower().strip().strip('.')
# Every caption is padded / truncated to exactly this many tokens.
max_len = 76


def tokenize(tokenizer, txt):
    """Tokenise ``txt`` to a fixed length of ``max_len`` tokens.

    Args:
        tokenizer: a callable tokenizer (e.g. a Hugging Face tokenizer).
        txt: the caption text, passed positionally.

    Returns:
        Whatever ``tokenizer`` returns for the call; no tensors are requested,
        so for a Hugging Face tokenizer this holds plain Python lists.
    """
    options = {
        "max_length": max_len,
        "padding": "max_length",
        "truncation": True,
        "return_offsets_mapping": False,
    }
    return tokenizer(txt, **options)
def encode(model, tokenizer, txt):
    """Embed one caption with the text encoder and return a single vector.

    Steps:
        1. The caption is normalised with ``clean`` and tokenised with
           ``tokenize``.
        2. Every tokenizer field is converted to a ``torch.long`` tensor on
           ``device`` with a leading batch axis of size 1.
        3. The model is put in eval mode and run without gradient tracking.

    Returns:
        NumPy array. It is ``last_hidden_state`` squeezed and then indexed at
        position 0. Presumably this is the (hidden_size,) vector of the first
        (<s>) token of the single padded caption; TODO confirm for other
        encoders.
    """
    inputs = tokenize(tokenizer, clean(txt))
    # Replace each value in place (same keys), keeping the tokenizer's mapping type.
    for key in list(inputs.keys()):
        inputs[key] = torch.tensor(inputs[key], dtype=torch.long, device=device)[None]
    model.eval()
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    return hidden.squeeze()[0].cpu().numpy()
@st.cache_resource
def get_model_roberta():
    """Load the ``roberta-base`` tokenizer and encoder once per server process.

    ``st.cache_resource`` memoises the result, so reruns of the script reuse
    the same objects. The encoder config is loaded with
    ``output_hidden_states=True`` and the model is moved to ``device``.

    Returns:
        tuple: ``(model, tokenizer)``. The model comes first.
    """
    checkpoint = 'roberta-base'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    cfg = AutoConfig.from_pretrained(checkpoint, output_hidden_states=True)
    encoder = AutoModel.from_pretrained(checkpoint, config=cfg).to(device)
    return encoder, tok
@st.cache_resource
def get_model_gan():
    """Build the text-conditioned generator and load its trained weights (cached).

    The generator is wrapped in ``torch.nn.DataParallel`` *before* the state
    dict is loaded. Presumably this is because ``gen_125.pth`` was saved from
    a DataParallel model, so its keys carry a ``module.`` prefix; TODO
    confirm. Tensors are mapped to CPU while reading the checkpoint and are
    then copied into the parameters, which already live on ``device``.

    Returns:
        The DataParallel-wrapped generator, switched to eval mode.
    """
    net = gan_cls_768.generator().to(device)
    net = torch.nn.DataParallel(net)
    # NOTE(review): torch.load unpickles the file; only point this at a trusted
    # local checkpoint.
    weights = torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    net.eval()
    return net
def generate_image(text, n):
    """Generate ``n`` images conditioned on the caption ``text``.

    The caption is embedded once with the RoBERTa encoder (``encode``). The
    GAN generator is then run ``n`` times, each time with fresh Gaussian
    noise of shape (1, 100, 1, 1).

    Args:
        text: free-form flower description.
        n: number of generator runs; ``0`` returns an empty list.

    Returns:
        list of ``PIL.Image.Image``, one per image the generator returns.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()
    embed = encode(model, tokenizer, text)
    # Add a batch axis of size 1. ``Variable`` has been a no-op wrapper since
    # PyTorch 0.4, so plain tensors are used.
    right_embed = torch.FloatTensor(embed).unsqueeze(0).to(device)
    images = []
    # Inference only: without no_grad every generator call records an autograd
    # graph, wasting memory on each request.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Same RNG usage as before: sample on CPU, then move to device.
            noise = torch.randn(1, 100).to(device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # x * 127.5 + 127.5 maps [-1, 1] onto [0, 255]. The clamp stops
                # uint8 wrap-around if values ever fall outside that range.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # Assumes each image is channel-first (C, H, W); TODO confirm.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
# Configure the browser tab title/icon, layout and sidebar before drawing UI.
st.set_page_config(
    layout="centered",
    page_title="ImageGen",
    initial_sidebar_state="expanded",
    page_icon="🧊",
)
# Injected CSS that hides Streamlit's default main menu, footer and header.
hide_st_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden;}\n"
    "footer {visibility: hidden;}\n"
    "header {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_st_style, unsafe_allow_html=True)
# Example captions offered in the "Select from example" dropdown; the chosen
# one pre-fills the description text area.
examples = [
    "this flower has gorgeous purple petals and a long green pedicel",
    "this flower has gorgeous green petals and a long green pedicel",
    "a couple of thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "this flower has petals that are pink and bell shaped",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistil and pollen tube.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "delicate pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and very long stigmas.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "this flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "the petals on this flower are orange with a purple pistil.",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
]
def app():
    """Render the demo UI: caption input on the left, generated grid on the right.

    Nothing is generated until the "Generate" button is pressed. Then
    ``n_rows * n_cols`` images are produced for the caption and drawn as a
    matplotlib grid in the right-hand column.
    """
    st.title("Text to Flower")
    st.markdown(
        """
**Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
"""
    )
    example = st.selectbox("Select from example", examples)
    row1_col1, row1_col2 = st.columns([2, 3])
    # Grid shape drives how many images are requested, so the two cannot drift apart.
    n_rows, n_cols = 3, 4
    with row1_col1:
        caption = st.text_area("Write your flower description here:", example, height=120)
        # Only one backend exists; the selector is shown but its value is not used.
        st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa"], index=0
        )
        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, n_rows * n_cols)
            with row1_col2:
                st.markdown("Generated Flower Images:")
                fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
                for axis, img in zip(axes.flatten(), imgs):
                    axis.imshow(img)
                    axis.axis('off')
                fig.tight_layout()
                st.pyplot(fig)
                # Streamlit reruns this script on every interaction. Close the
                # figure once it is rendered so pyplot's open-figure registry
                # does not grow without bound.
                plt.close(fig)
# `streamlit run` executes this script as __main__, so the UI still launches
# there, while a plain `import` of the module no longer renders app().
if __name__ == "__main__":
    app()
# # Display a footer with links and credits
#st.markdown("---")
#st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
# #st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")
|