DataRaptor committed on
Commit
f8a1225
1 Parent(s): 3face26

Upload 6 files

Files changed (6)
  1. T2I.py +112 -0
  2. gan_cls_768.py +151 -0
  3. gen_125.pth +3 -0
  4. main.py +239 -0
  5. requirements.txt +9 -0
  6. run.sh +1 -0
T2I.py ADDED
@@ -0,0 +1,112 @@
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
+ import torch
+ from tqdm import tqdm
+ import gan_cls_768
+ from torch.autograd import Variable
+ from PIL import Image
+ import matplotlib.pyplot as plt
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def clean(txt):
+     txt = txt.lower()
+     txt = txt.strip()
+     txt = txt.strip('.')
+     return txt
+
+
+ max_len = 76
+
+
+ def tokenize(tokenizer, txt):
+     return tokenizer(
+         txt,
+         max_length=max_len,
+         padding='max_length',
+         truncation=True,
+         return_offsets_mapping=False
+     )
+
+
+ def encode(model_name, model, tokenizer, txt):
+     txt = clean(txt)
+     txt_tokenized = tokenize(tokenizer, txt)
+
+     for k, v in txt_tokenized.items():
+         txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]
+
+     model.eval()
+     with torch.no_grad():
+         encoded = model(**txt_tokenized)
+
+     # The hidden state of the first token (<s>) serves as the 768-dim sentence embedding.
+     return encoded.last_hidden_state.squeeze()[0].cpu().numpy()
+
+
+ model_name = 'roberta-base'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModel.from_pretrained(
+     model_name,
+     config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)
+
+
+ def generate_image(text, n):
+     embed = encode(model_name, model, tokenizer, text)
+
+     # The checkpoint was saved from a DataParallel-wrapped model, so the same
+     # wrapper is applied before loading the state dict.
+     generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
+     generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
+     generator.eval()
+
+     embed2 = torch.FloatTensor(embed)
+     embed2 = embed2.unsqueeze(0)
+     right_embed = Variable(embed2.float()).to(device)
+
+     l = []
+     for i in tqdm(range(n)):
+         noise = Variable(torch.randn(1, 100)).to(device)
+         noise = noise.view(noise.size(0), 100, 1, 1)
+         fake_images = generator(right_embed, noise)
+
+         for idx, image in enumerate(fake_images):
+             # Rescale the generator's tanh output from [-1, 1] to [0, 255].
+             im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
+             l.append(im)
+     return l
+
+
+ if __name__ == '__main__':
+     n = 10
+     imgs = generate_image('Red images', n)
+
+     fig, axes = plt.subplots(nrows=5, ncols=2)
+     axes = axes.flatten()
+
+     for idx, ax in enumerate(axes):
+         ax.imshow(imgs[idx])
+         ax.axis('off')
+
+     fig.tight_layout()
+     plt.show()
+
+     # while True:
+     #     print('Type Caption: ')
+     #     txt = input()
+     #     print('Generating images...')
+     #     generate_image(txt)
+     #     print('Completed')
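Note: a minimal usage sketch for the `generate_image` helper above; the caption and output filename are illustrative, not part of the commit.

    # Illustrative only: any flower description can serve as the caption.
    from T2I import generate_image

    imgs = generate_image('this flower has petals that are yellow', 4)
    imgs[0].save('flower_0.png')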
gan_cls_768.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ import torch.nn as nn
+
+
+ class Concat_embed4(nn.Module):
+     """Project the 768-dim text embedding down to 128 dims and tile it over
+     the 4x4 discriminator feature map before channel-wise concatenation."""
+
+     def __init__(self, embed_dim, projected_embed_dim):
+         super(Concat_embed4, self).__init__()
+         self.projection = nn.Sequential(
+             nn.Linear(in_features=embed_dim, out_features=embed_dim),
+             nn.BatchNorm1d(num_features=embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Linear(in_features=embed_dim, out_features=embed_dim),
+             nn.BatchNorm1d(num_features=embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+         )
+
+     def forward(self, inp, embed):
+         projected_embed = self.projection(embed)
+         replicated_embed = projected_embed.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
+         hidden_concat = torch.cat([inp, replicated_embed], 1)
+         return hidden_concat
+
+
+ class generator(nn.Module):
+     def __init__(self):
+         super(generator, self).__init__()
+         self.image_size = 64
+         self.num_channels = 3
+         self.noise_dim = 100
+         self.embed_dim = 768
+         self.projected_embed_dim = 128
+         self.latent_dim = self.noise_dim + self.projected_embed_dim
+         self.ngf = 64
+
+         self.projection = nn.Sequential(
+             nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
+             nn.BatchNorm1d(num_features=self.embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
+             nn.BatchNorm1d(num_features=self.embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Linear(in_features=self.embed_dim, out_features=self.projected_embed_dim),
+             nn.BatchNorm1d(num_features=self.projected_embed_dim),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         )
+
+         self.netG = nn.ModuleList([
+             nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
+             nn.BatchNorm2d(self.ngf * 8),
+             nn.ReLU(True),
+
+             # state size. (ngf*8) x 4 x 4
+             nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ngf * 4),
+             nn.ReLU(True),
+
+             # state size. (ngf*4) x 8 x 8
+             nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ngf * 2),
+             nn.ReLU(True),
+
+             # state size. (ngf*2) x 16 x 16
+             nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ngf),
+             nn.ReLU(True),
+
+             # state size. (ngf) x 32 x 32
+             nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
+             nn.Tanh()
+             # state size. (num_channels) x 64 x 64
+         ])
+
+     def forward(self, embed_vector, z):
+         projected_embed = self.projection(embed_vector)
+         out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
+         for m in self.netG:
+             out = m(out)
+         return out
+
+
+ class discriminator(nn.Module):
+     def __init__(self):
+         super(discriminator, self).__init__()
+         self.image_size = 64
+         self.num_channels = 3
+         self.embed_dim = 768
+         self.projected_embed_dim = 128
+         self.ndf = 64
+         self.B_dim = 128
+         self.C_dim = 16
+
+         self.netD_1 = nn.Sequential(
+             # input is (nc) x 64 x 64
+             nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             # state size. (ndf) x 32 x 32
+             # SelfAttention(self.ndf),
+             nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ndf * 2),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             # state size. (ndf*2) x 16 x 16
+             nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ndf * 4),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             # state size. (ndf*4) x 8 x 8
+             nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(self.ndf * 8),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+
+         self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)
+
+         self.netD_2 = nn.Sequential(
+             # state size. (ndf*8 + projected_embed_dim) x 4 x 4
+             nn.Conv2d(self.ndf * 8 + self.projected_embed_dim,
+                       self.ndf * 8, 1, 1, 0, bias=False),
+             nn.BatchNorm2d(self.ndf * 8),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, inp, embed):
+         x_intermediate = self.netD_1(inp)
+         x = self.projector(x_intermediate, embed)
+         x = self.netD_2(x)
+
+         return x.view(-1, 1).squeeze(1), x_intermediate
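Note: a quick shape check for the generator, assuming the constructor defaults above (768-dim text embedding, 100-dim noise, 64x64 RGB output):

    import torch
    import gan_cls_768

    G = gan_cls_768.generator().eval()
    embed = torch.randn(2, 768)        # batch of sentence embeddings
    z = torch.randn(2, 100, 1, 1)      # latent noise
    with torch.no_grad():
        out = G(embed, z)              # concat -> (2, 228, 1, 1), then upsampled
    print(out.shape)                   # expected: torch.Size([2, 3, 64, 64])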
gen_125.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd835271c23087d4cf25a974b51e6680592d906da9cd20159a060123fdc7b8c5
+ size 23668507
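Note: this is a Git LFS pointer, not the checkpoint itself. After cloning the repository, the ~23 MB weights file is fetched with:

    git lfs pull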
main.py ADDED
@@ -0,0 +1,239 @@
+ import streamlit as st
+ import torch
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from torch.autograd import Variable
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
+
+ import gan_cls_768
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def clean(txt):
+     txt = txt.lower()
+     txt = txt.strip()
+     txt = txt.strip('.')
+     return txt
+
+
+ max_len = 76
+
+
+ def tokenize(tokenizer, txt):
+     return tokenizer(
+         txt,
+         max_length=max_len,
+         padding='max_length',
+         truncation=True,
+         return_offsets_mapping=False
+     )
+
+
+ def encode(model, tokenizer, txt):
+     txt = clean(txt)
+     txt_tokenized = tokenize(tokenizer, txt)
+
+     for k, v in txt_tokenized.items():
+         txt_tokenized[k] = torch.tensor(v, dtype=torch.long, device=device)[None]
+
+     model.eval()
+     with torch.no_grad():
+         encoded = model(**txt_tokenized)
+
+     # The hidden state of the first token (<s>) serves as the 768-dim sentence embedding.
+     return encoded.last_hidden_state.squeeze()[0].cpu().numpy()
+
+
+ @st.cache_resource
+ def get_model_roberta():
+     # Cached across Streamlit reruns so the encoder is loaded only once per session.
+     model_name = 'roberta-base'
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(
+         model_name,
+         config=AutoConfig.from_pretrained(model_name, output_hidden_states=True)).to(device)
+     return model, tokenizer
+
+
+ @st.cache_resource
+ def get_model_gan():
+     # The checkpoint was saved from a DataParallel-wrapped model, so the same
+     # wrapper is applied before loading the state dict.
+     generator = torch.nn.DataParallel(gan_cls_768.generator().to(device))
+     generator.load_state_dict(torch.load("./gen_125.pth", map_location=torch.device('cpu')))
+     generator.eval()
+     return generator
+
+
+ def generate_image(text, n):
+     model, tokenizer = get_model_roberta()
+     generator = get_model_gan()
+
+     embed = encode(model, tokenizer, text)
+     embed2 = torch.FloatTensor(embed)
+     embed2 = embed2.unsqueeze(0)
+     right_embed = Variable(embed2.float()).to(device)
+
+     l = []
+     for i in tqdm(range(n)):
+         noise = Variable(torch.randn(1, 100)).to(device)
+         noise = noise.view(noise.size(0), 100, 1, 1)
+         fake_images = generator(right_embed, noise)
+
+         for idx, image in enumerate(fake_images):
+             # Rescale the generator's tanh output from [-1, 1] to [0, 255].
+             im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
+             l.append(im)
+     return l
+
+
+ st.set_page_config(
+     page_title="ImageGen",
+     page_icon="🧊",
+     layout="centered",
+     initial_sidebar_state="expanded",
+ )
+
+ # Hide Streamlit's default menu, footer, and header chrome.
+ hide_st_style = """
+     <style>
+     #MainMenu {visibility: hidden;}
+     footer {visibility: hidden;}
+     header {visibility: hidden;}
+     </style>
+ """
+ st.markdown(hide_st_style, unsafe_allow_html=True)
+
+ examples = [
+     "this petal has gorgeous purple petals and a long green pedicel",
+     "this petal has gorgeous green petals and a long green pedicel",
+     "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
+     "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistol and pollen tube.",
+     "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
+     "delicated pink petals clumped on one green pedicel with small sepals.",
+     "the flower has big yellow upright petals attached to a thick vine",
+     "these bright flowers have many yellow strip petals and stamen.",
+     "a large red flower with black dots and a very long stigmas.",
+     "this flower has petals that are pink and bell shaped",
+     "this flower has petals that are yellow and has black lines",
+     "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
+     "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
+     "this flower has petals that are white and has a yellow style",
+     "his flower has petals that are orange and are very thin",
+     "a flower with singular conical purple petal and large white pistil.",
+     "this flower is yellow in color, and has petals that are very skinny.",
+     "a velvet large flower with a dark marking and a green stem.",
+     "this flower is yellow in color, and has petals that are very skinny.",
+     "the flower has bright yellow soft petals with yellow stamens.",
+     "this flower has petals that are pink and has red stamen",
+     "this flower has petals that are purple and have dark lines",
+     "this purple flower has pointy short petals and green sepal.",
+     "this flower has petals that are purple and has a yellow style",
+     "this flower is yellow in color, with petals that are skinny and pointed.",
+     "the petals on this flower are orange with a purple pistil.",
+     "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
+     "this purple color flower has the simple row of petals arranged in the circle with the red color pistils at the center",
+     "this flower has petals that are red and are very thin",
+     "a flower with many folded over bright yellow petals",
+     "a flower with no visible petals and purple pistils in the center.",
+     "a star shaped flower with five white petals with purple lines running through them.",
+     "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
+     "this flower features a purple stigma surrounded by pointed waxy orange petals.",
+     "this flower is yellow and brown in color, with petals that are oval shaped.",
+     "this flower has petals that are white and has a yellow stigma",
+     "a flower with folded open and back red petals with black spots and think red anther",
+     "this flower has large light red petals and a few white stamen in the center",
+     "this flower has bright orange tubular petals rising out of a thick receptacle on a green pedicel.",
+     "this flower is a beauty with light red leaves in an equal circle.",
+     "a flower with an open conical red petal and white anther supported by red filaments",
+     "this flower is red in color, with petals that are bell shaped.",
+     "the petals of this flower are yellow with a long stigma",
+ ]
+
+
+ def app():
+     st.title("Text to Flower")
+     st.markdown(
+         """
+         **Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
+         Presented at the *International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)*.
+         """
+     )
+
+     se = st.selectbox("Select from example", examples)
+
+     row1_col1, row1_col2 = st.columns([2, 3])
+
+     with row1_col1:
+         caption = st.text_area("Write your flower description here:", se, height=120)
+         backend = st.selectbox(
+             "Select a Model", ["Convolutional GAN with RoBERTa"], index=0
+         )
+
+     if st.button("Generate", type="primary"):
+         with st.spinner("Generating Flower Images..."):
+             imgs = generate_image(caption, 12)
+
+         with row1_col2:
+             st.markdown("Generated Flower Images:")
+
+             fig, axes = plt.subplots(nrows=3, ncols=4)
+             axes = axes.flatten()
+
+             for idx, ax in enumerate(axes):
+                 ax.imshow(imgs[idx])
+                 ax.axis('off')
+
+             fig.tight_layout()
+             st.pyplot(fig)
+
+     # with row1_col2:
+     #     img1 = Image.open('./images/t2i/1.jpg')
+     #     img2 = Image.open('./images/t2i/2.jpg')
+     #     img3 = Image.open('./images/t2i/3.jpg')
+     #     img4 = Image.open('./images/t2i/4.jpg')
+     #     cont = st.container()
+     #     with cont:
+     #         st.write("This is a container with a caption like a button.")
+     #         col1, col2, col3, col4 = st.columns(4)
+     #         with col1:
+     #             st.image(img1, width=128)
+     #         with col2:
+     #             st.image(img2, width=128)
+     #         with col3:
+     #             st.image(img3, width=128)
+     #         with col4:
+     #             st.image(img4, width=128)
+
+
+ app()
+
+ # Display a footer with links and credits
+ st.markdown("---")
+ st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).")
+ # st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")
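Note: a possible simplification, not in the committed code: sample all n noise vectors as one batch so the generator runs a single forward pass instead of n. A sketch under the same shapes used in generate_image:

    noise = torch.randn(n, 100, 1, 1, device=device)
    embeds = right_embed.repeat(n, 1)      # tile the (1, 768) caption embedding
    with torch.no_grad():
        fake = generator(embeds, noise)    # (n, 3, 64, 64) in one pass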
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit==1.21.0
+ Pillow
+ torch==2.0.1
+ numpy
+ transformers==4.30.2
+ tokenizers==0.13.3
+ matplotlib==3.7.1
+ tqdm  # imported directly by T2I.py and main.py
run.sh ADDED
@@ -0,0 +1 @@
+ streamlit run main.py --server.runOnSave True
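Note: `--server.runOnSave True` makes Streamlit rerun the app whenever main.py is saved, which is handy during development; it can be dropped for deployment.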