tfwang commited on
Commit
d7db483
1 Parent(s): 31b9256

Upload image_datasets_sketch.py

Browse files
glide_text2im/image_datasets_sketch.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ from PIL import Image
5
+ import blobfile as bf
6
+ from mpi4py import MPI
7
+ import numpy as np
8
+ from torch.utils.data import DataLoader, Dataset
9
+ import os
10
+ import torchvision.transforms as transforms
11
+ import torch as th
12
+ from .degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
13
+ from functools import partial
14
+ import cv2
15
+
16
+ from PIL import PngImagePlugin
17
+ LARGE_ENOUGH_NUMBER = 100
18
+ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
19
+
20
+ def load_data_sketch(
21
+ *,
22
+ data_dir,
23
+ batch_size,
24
+ image_size,
25
+ class_cond=False,
26
+ deterministic=False,
27
+ random_crop=False,
28
+ random_flip=True,
29
+ train=True,
30
+ low_res = 0,
31
+ uncond_p = 0,
32
+ mode = ''
33
+ ):
34
+ """
35
+ For a dataset, create a generator over (images, kwargs) pairs.
36
+
37
+ Each images is an NCHW float tensor, and the kwargs dict contains zero or
38
+ more keys, each of which map to a batched Tensor of their own.
39
+ The kwargs dict can be used for class labels, in which case the key is "y"
40
+ and the values are integer tensors of class labels.
41
+
42
+ :param data_dir: a dataset directory.
43
+ :param batch_size: the batch size of each returned pair.
44
+ :param image_size: the size to which images are resized.
45
+ :param class_cond: if True, include a "y" key in returned dicts for class
46
+ label. If classes are not available and this is true, an
47
+ exception will be raised.
48
+ :param deterministic: if True, yield results in a deterministic order.
49
+ :param random_crop: if True, randomly crop the images for augmentation.
50
+ :param random_flip: if True, randomly flip the images for augmentation.
51
+ """
52
+ if not data_dir:
53
+ raise ValueError("unspecified data directory")
54
+ with open(data_dir) as f:
55
+ all_files = f.read().splitlines()
56
+
57
+ print(len(all_files))
58
+ classes = None
59
+ if class_cond:
60
+ # Assume classes are the first part of the filename,
61
+ # before an underscore.
62
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
63
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
64
+ classes = [sorted_classes[x] for x in class_names]
65
+ dataset = ImageDataset(
66
+ image_size,
67
+ all_files,
68
+ classes=classes,
69
+ shard=MPI.COMM_WORLD.Get_rank(),
70
+ num_shards=MPI.COMM_WORLD.Get_size(),
71
+ random_crop=random_crop,
72
+ random_flip=train,
73
+ down_sample_img_size = low_res,
74
+ uncond_p = uncond_p,
75
+ mode = mode,
76
+ )
77
+ if deterministic:
78
+ loader = DataLoader(
79
+ dataset, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=True, pin_memory=False
80
+ )
81
+ else:
82
+ loader = DataLoader(
83
+ dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=False
84
+ )
85
+ while True:
86
+ yield from loader
87
+
88
+ def _list_image_files_recursively(data_dir):
89
+ results = []
90
+ for entry in sorted(bf.listdir(data_dir)):
91
+ full_path = bf.join(data_dir, entry)
92
+ ext = entry.split(".")[-1]
93
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
94
+ results.append(full_path)
95
+ elif bf.isdir(full_path):
96
+ results.extend(_list_image_files_recursively(full_path))
97
+ return results
98
+
99
+ class ImageDataset(Dataset):
100
+ def __init__(
101
+ self,
102
+ resolution,
103
+ image_paths,
104
+ classes=None,
105
+ shard=0,
106
+ num_shards=1,
107
+ random_crop=False,
108
+ random_flip=True,
109
+ down_sample_img_size = 0,
110
+ uncond_p = 0,
111
+ mode = '',
112
+ ):
113
+ super().__init__()
114
+ self.crop_size = 256
115
+ self.resize_size = 256
116
+ self.local_images = image_paths[shard:][::num_shards]
117
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
118
+ self.random_crop = random_crop
119
+ self.random_flip = random_flip
120
+
121
+ self.down_sample_img = partial(degradation_fn_bsr_light, sf=resolution//down_sample_img_size) if down_sample_img_size else None
122
+ self.uncond_p = uncond_p
123
+ self.mode = mode
124
+ self.resolution = resolution
125
+
126
+ def __len__(self):
127
+ return len(self.local_images)
128
+
129
+ def __getitem__(self, idx):
130
+ if self.mode == 'coco-edge':
131
+ path = self.local_images[idx].replace('COCO-STUFF', 'COCO-Sketch')[:-4] + '.png'
132
+ path2 = path.replace('_img', '_sketch')
133
+ elif self.mode == 'flickr-edge':
134
+ path = self.local_images[idx].replace('images', 'img256')[:-4] + '.png'
135
+ path2 = path.replace('img256', 'sketch256')
136
+
137
+
138
+ with bf.BlobFile(path, "rb") as f:
139
+ pil_image = Image.open(f)
140
+ pil_image.load()
141
+ pil_image = pil_image.convert("RGB")
142
+
143
+
144
+ with bf.BlobFile(path2, "rb") as f:
145
+ pil_image2 = Image.open(f)
146
+ pil_image2.load()
147
+ pil_image2 = pil_image2.convert("L")
148
+
149
+
150
+ params = get_params(pil_image2.size, self.resize_size, self.crop_size)
151
+ transform_label = get_transform(params, self.resize_size, self.crop_size, method=Image.NEAREST, crop =self.random_crop, flip=self.random_flip)
152
+ label_pil = transform_label(pil_image2)
153
+
154
+ im_dist = cv2.distanceTransform(255-np.array(label_pil), cv2.DIST_L1, 3)
155
+ im_dist = np.clip((im_dist) , 0, 255).astype(np.uint8)
156
+ im_dist = Image.fromarray(im_dist).convert("RGB")
157
+
158
+ label_tensor = get_tensor()(im_dist)[:1]
159
+ label_tensor_ori = get_tensor()(label_pil.convert('RGB'))
160
+
161
+ transform_image = get_transform( params, self.resize_size, self.crop_size, crop =self.random_crop, flip=self.random_flip)
162
+ image_pil = transform_image(pil_image)
163
+ if self.resolution < 256:
164
+ image_pil = image_pil.resize((self.resolution, self.resolution), Image.BICUBIC)
165
+ image_tensor = get_tensor()(image_pil)
166
+
167
+ if self.down_sample_img:
168
+ image_pil = np.array(image_pil).astype(np.uint8)
169
+ down_sampled_image = self.down_sample_img(image=image_pil)["image"]
170
+ down_sampled_image = get_tensor()(down_sampled_image)
171
+ data_dict = {"ref":label_tensor, "low_res":down_sampled_image, "ref_ori":label_tensor_ori, "path": path}
172
+ return image_tensor, data_dict
173
+
174
+ if random.random() < self.uncond_p:
175
+ label_tensor = th.ones_like(label_tensor)
176
+ data_dict = {"ref":label_tensor, "ref_ori":label_tensor_ori, "path": path}
177
+
178
+ return image_tensor, data_dict
179
+
180
+ def get_params( size, resize_size, crop_size):
181
+ w, h = size
182
+ new_h = h
183
+ new_w = w
184
+
185
+ ss, ls = min(w, h), max(w, h) # shortside and longside
186
+ width_is_shorter = w == ss
187
+ ls = int(resize_size * ls / ss)
188
+ ss = resize_size
189
+ new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
190
+
191
+ x = random.randint(0, np.maximum(0, new_w - crop_size))
192
+ y = random.randint(0, np.maximum(0, new_h - crop_size))
193
+
194
+ flip = random.random() > 0.5
195
+ return {'crop_pos': (x, y), 'flip': flip}
196
+
197
+
198
+ def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop = True):
199
+ transform_list = []
200
+
201
+ transform_list.append(transforms.Lambda(lambda img: __scale(img, crop_size, method)))
202
+
203
+ if flip:
204
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
205
+
206
+ return transforms.Compose(transform_list)
207
+
208
+ def get_tensor(normalize=True, toTensor=True):
209
+ transform_list = []
210
+ if toTensor:
211
+ transform_list += [transforms.ToTensor()]
212
+
213
+ if normalize:
214
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
215
+ (0.5, 0.5, 0.5))]
216
+ return transforms.Compose(transform_list)
217
+
218
+ def normalize():
219
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
220
+
221
+
222
+ def __scale(img, target_width, method=Image.BICUBIC):
223
+ return img.resize((target_width, target_width), method)
224
+
225
+ def __flip(img, flip):
226
+ if flip:
227
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
228
+ return img