"""Streamlit demo: generate flower images from a text description.

A RoBERTa encoder embeds the description and a conditional convolutional GAN
(``gan_cls_768.generator``) turns that embedding plus random noise into images.
"""
import time

import matplotlib.pyplot as plt
import streamlit as st
import torch
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

import gan_cls_768

device = "cuda" if torch.cuda.is_available() else "cpu"

# Token length every description is padded/truncated to.
max_len = 76
# Size of the latent noise vector fed to the generator. Presumably this must
# match gan_cls_768.generator's expected input; confirm against that module.
NOISE_DIM = 100
# Number of images produced per click; must fit the 3 x 4 display grid.
NUM_IMAGES = 12


def clean(txt):
    """Lower-case ``txt`` and strip surrounding whitespace and periods."""
    txt = txt.lower()
    txt = txt.strip()
    txt = txt.strip('.')
    return txt


def tokenize(tokenizer, txt):
    """Tokenize ``txt``, padding/truncating to exactly ``max_len`` tokens."""
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False,
    )


def encode(model, tokenizer, txt):
    """Embed ``txt`` and return the hidden state at the first token position.

    Args:
        model: Hugging Face encoder (see ``get_model_roberta``).
        tokenizer: Matching tokenizer.
        txt: Raw description; normalised with ``clean`` first.

    Returns:
        1-D numpy array of length ``model.config.hidden_size``.
    """
    txt = clean(txt)
    txt_tokenized = tokenize(tokenizer, txt)
    # Each field becomes a (1, max_len) long tensor on the model's device.
    inputs = {
        k: torch.tensor(v, dtype=torch.long, device=device)[None]
        for k, v in txt_tokenized.items()
    }
    model.eval()
    with torch.no_grad():
        encoded = model(**inputs)
    # last_hidden_state is (1, max_len, hidden): squeeze drops the batch axis,
    # [0] keeps the first token's vector.
    return encoded.last_hidden_state.squeeze()[0].cpu().numpy()


@st.cache_resource
def get_model_roberta():
    """Load (once per Streamlit process) the roberta-base model and tokenizer."""
    model_name = 'roberta-base'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name, output_hidden_states=True),
    ).to(device)
    return model, tokenizer


@st.cache_resource
def get_model_gan():
    """Load (once per Streamlit process) the trained GAN generator in eval mode."""
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    # Weights are loaded onto CPU first, then copied into the device-resident
    # parameters. Loading into the DataParallel wrapper presumably means the
    # checkpoint keys carry the "module." prefix; confirm if loading fails.
    generator.load_state_dict(
        torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    )
    generator.eval()
    return generator


def generate_image(text, n):
    """Generate ``n`` flower images conditioned on ``text``.

    Args:
        text: Free-form flower description.
        n: Number of images to generate; one generator pass each.

    Returns:
        list of ``PIL.Image.Image``.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()

    embed = encode(model, tokenizer, text)
    # (1, hidden) conditioning vector shared by every sample.
    right_embed = torch.as_tensor(embed, dtype=torch.float32).unsqueeze(0).to(device)

    images = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for _ in tqdm(range(n)):
            noise = torch.randn(1, NOISE_DIM).to(device)
            noise = noise.view(noise.size(0), NOISE_DIM, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # x * 127.5 + 127.5 maps [-1, 1] to [0, 255] (assumes the
                # generator emits values in [-1, 1], e.g. via tanh; confirm).
                # Clamp so stray values saturate instead of wrapping in the
                # uint8 cast.
                pixels = (image * 127.5 + 127.5).clamp(0, 255).byte()
                # CHW -> HWC for PIL.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images


st.set_page_config(
    page_title="ImageGen",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
)

# NOTE(review): this string holds no CSS, so the call below has no visible
# effect. Presumably it was meant to carry a <style> block; confirm.
hide_st_style = """ """
st.markdown(hide_st_style, unsafe_allow_html=True)

examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "delicated pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this flower has petals that are pink and bell shaped",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "his flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "a velvet large flower with a dark marking and a green stem.",
    "this flower is yellow in color, and has petals that are very skinny.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are pink and has red stamen",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "this flower is yellow in color, with petals that are skinny and pointed.",
    "the petals on this flower are orange with a purple pistil.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
    "this flower has petals that are red and are very thin",
    "a flower with many folded over bright yellow petals",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
    "this flower is yellow and brown in color, with petals that are oval shaped.",
    "this flower has petals that are white and has a yellow stigma",
    "a flower with folded open and back red petals with black spots and think red anther",
    "this flower has large light red petals and a few white stamen in the center",
    "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
    "this flower is a beauty with light red leaves in an equal circle.",
    "a flower with an open conical red petal and white anther supported by red filaments",
    "this flower is red in color, with petals that are bell shaped.",
    "the petals of this flower are yellow with a long stigma",
]


def _show_image_grid(images, nrows=3, ncols=4):
    """Render ``images`` row-major in an ``nrows`` x ``ncols`` grid via Streamlit."""
    # squeeze=False always yields a 2-D array of Axes, so .flat works for any grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for axis in axes.flat:
        axis.axis('off')
    for image, axis in zip(images, axes.flat):
        axis.imshow(image)
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # click on "Generate" would leak one figure.
    plt.close(fig)


def app():
    """Render the text-to-flower demo page."""
    st.title("Text to Flower")
    st.markdown(
        """
        **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
        Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
        """
    )
    se = st.selectbox("Select from example", examples)

    row1_col1, row1_col2 = st.columns([2, 3])
    with row1_col1:
        caption = st.text_area("Write your flower description here:", se, height=120)
        # Only one model is offered; the selection is displayed but not read yet.
        st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
        )
        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, NUM_IMAGES)
            # Results only exist after a click, so the grid is drawn here.
            with row1_col2:
                st.markdown("Generated Flower Images:")
                _show_image_grid(imgs)


app()

# Display a footer with links and credits
st.markdown("---")
st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")