import streamlit as st
import matplotlib.pyplot as plt
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoConfig

import gan_cls_768

device = "cuda" if torch.cuda.is_available() else "cpu"
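
# Pipeline: a RoBERTa text encoder turns the caption into a 768-dim sentence
# embedding, which conditions the GAN generator (gan_cls_768) together with a
# 100-dim noise vector to synthesize flower images.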

def clean(txt):
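    # Normalize a caption: lowercase it, trim whitespace, and drop a trailing period.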
    txt = txt.lower()
    txt = txt.strip()
    txt = txt.strip('.')
    return txt


max_len = 76

def tokenize(tokenizer, txt):
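    # Tokenize a caption to a fixed length of max_len tokens (padding and
    # truncating as needed) so the encoder always sees the same input shape.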
    return tokenizer(
        txt,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_offsets_mapping=False
    )


def encode(model, tokenizer, txt):
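    # Encode a caption with RoBERTa and return the last-hidden-state vector of
    # the first token (<s>) as a single 768-dim numpy sentence embedding.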
    txt = clean(txt)
    txt_tokenized = tokenize(tokenizer, txt)

    for k, v in txt_tokenized.items():
        txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]

    model.eval()
    with torch.no_grad():
        encoded = model(**txt_tokenized)

    return encoded.last_hidden_state.squeeze()[0].cpu().numpy()


@st.cache_resource
def get_model_roberta():
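    # Load the roberta-base tokenizer and encoder once; st.cache_resource keeps
    # them in memory across Streamlit reruns.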
    model_name = 'roberta-base'
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)

    return model, tokenizer


@st.cache_resource
def get_model_gan():
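    # The checkpoint was presumably saved from a DataParallel-wrapped model
    # (state-dict keys prefixed with "module."), so the generator is wrapped the
    # same way before loading (assumption). Weights are loaded on CPU and copied
    # onto the model's device by load_state_dict.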
    generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
    generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
    generator.eval()
    return generator
    


def generate_image(text, n):
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()

    # Embed the caption and add a batch dimension.
    embed = encode(model, tokenizer, text)
    right_embed = torch.FloatTensor(embed).unsqueeze(0).to(device)

    images = []
    with torch.no_grad():
        for _ in tqdm(range(n)):
            noise = torch.randn(1, 100, device=device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)

            for image in fake_images:
                # Map the generator output from [-1, 1] back to [0, 255] pixels.
                im = Image.fromarray(
                    image.mul(127.5).add(127.5).byte().permute(1, 2, 0).cpu().numpy()
                )
                images.append(im)
    return images




st.set_page_config(
    page_title="ImageGen",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    )


hide_st_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            header {visibility: hidden;}
            </style>
            """
st.markdown(hide_st_style, unsafe_allow_html=True)



examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "this flower has petals that are pink and bell shaped",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "delicated pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.", 
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "his flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "the petals on this flower are orange with a purple pistil.",
    
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
    ]



def app():

    st.title("Text to Flower")
    st.markdown(
        """
        **Demo for the paper:** *Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach*,
        presented at the International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023).
        """
    )

    
    
    se = st.selectbox("Select from examples", examples)

    row1_col1, row1_col2 = st.columns([2, 3])

    with row1_col1:
        caption = st.text_area("Write your flower description here:", se, height=120)
        
        
        backend = st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
        )

        

        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                
                
                # # gen all
                # for i in examples:
                #     imgs = generate_image(i, 1)
                #     st.markdown(i)
                    
                #     st.image(imgs[0])
                    
                imgs = generate_image(caption, 12)
                #ss = st.success("Scores predicted successfully!")
                
                with row1_col2:
                    st.markdown("Generated Flower Images:")
                    
                    fig, axes = plt.subplots(nrows=3, ncols=4)

                    for ax, img in zip(axes.flatten(), imgs):
                        ax.imshow(img)
                        ax.axis('off')
                    
                    fig.tight_layout()
                    st.pyplot(fig)
                    
                    
    
app()
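
# To launch the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py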

# Footer with links and credits (currently disabled):
# st.markdown("---")
# st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
# st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")