import numpy as np
import torch
from skimage import io, transform, color
from torch.utils.data import Dataset

# Per-channel offsets/scales applied in the plain-RGB branch of ToTensorLab.
# NOTE(review): these look like the commonly used ImageNet mean/std values.
_RGB_MEAN = (0.485, 0.456, 0.406)
_RGB_STD = (0.229, 0.224, 0.225)


def _rescale_to_unit_range(channel):
    """Min-max scale one 2-D channel to [0, 1]; a constant channel maps to zeros."""
    lowest = np.min(channel)
    span = np.max(channel) - lowest
    if span == 0:
        # Constant channel: the plain formula would compute 0 / 0 (NaN).
        return np.zeros(channel.shape)
    return (channel - lowest) / span


def _standardize(channel):
    """Subtract the mean and divide by the std; a zero-std channel maps to zeros."""
    sigma = np.std(channel)
    if sigma == 0:
        # Zero spread: the plain formula would divide by zero.
        return np.zeros(channel.shape)
    return (channel - np.mean(channel)) / sigma


def _as_three_channels(image):
    """Return an H x W x 3 array, replicating the single channel of a gray image.

    Images whose third dimension is not 1 are returned unchanged.
    """
    if image.shape[2] != 1:
        return image
    replicated = np.zeros((image.shape[0], image.shape[1], 3))
    for channel in range(3):
        replicated[:, :, channel] = image[:, :, 0]
    return replicated


class SalObjDataset(Dataset):
    """Dataset over in-memory images with optional label image files.

    Args:
        imgs_list: Indexable collection of images; each item is converted
            with ``np.array``.
        lbl_name_list: Indexable collection of label image paths read with
            ``skimage.io.imread``. When empty, an all-zero label shaped like
            the image is used instead.
        transform: Optional callable applied to the sample dict
            ``{'imidx', 'image', 'label'}``; it must return a dict that
            contains an ``'image'`` entry.
    """

    def __init__(self, imgs_list, lbl_name_list, transform=None):
        self.imgs_list = imgs_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.imgs_list)

    def __getitem__(self, idx):
        """Build the sample for ``idx`` and return only its ``'image'`` entry.

        The label and index are still assembled because the transform
        pipeline consumes the full sample dict.
        """
        image = np.array(self.imgs_list[idx])
        imidx = np.array([idx])

        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Reduce the label to a single 2-D plane (first channel when 3-D).
        label = np.zeros(label_3.shape[0:2])
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3

        # Give image and label an explicit trailing channel axis.
        if image.ndim == 3 and label.ndim == 2:
            label = label[:, :, np.newaxis]
        elif image.ndim == 2 and label.ndim == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample['image']


class RescaleT(object):
    """Resize the image and label of a sample to a fixed size.

    Args:
        output_size: ``int`` for a square ``output_size x output_size``
            result (the aspect ratio is not preserved), or a
            ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized ``'image'`` and ``'label'``."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if isinstance(self.output_size, int):
            target_shape = (self.output_size, self.output_size)
        else:
            # Previously a tuple was accepted but then passed to resize as
            # (tuple, tuple); use it as the (height, width) target instead.
            new_h, new_w = (int(v) for v in self.output_size)
            target_shape = (new_h, new_w)

        img = transform.resize(image, target_shape, mode='constant')
        # order=0 with preserve_range=True: no interpolation between label
        # values and no rescaling of their range.
        lbl = transform.resize(label, target_shape, mode='constant',
                               order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    Args:
        flag: Selects the colour representation of the output image:
            ``2`` -> 6 channels (RGB followed by Lab), each min-max scaled
            and then standardized; ``1`` -> 3 Lab channels, each min-max
            scaled and then standardized; any other value -> 3 RGB channels
            divided by the image maximum and normalized with fixed
            per-channel mean/std constants.
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        """Return ``{'imidx', 'image', 'label'}`` as tensors in C x H x W layout.

        ``sample['image']`` and ``sample['label']`` must be 3-D (H x W x C)
        arrays.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Scale the label by its maximum unless it is (numerically) all zero.
        label_peak = np.max(label)
        if not (label_peak < 1e-6):
            label = label / label_peak

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            rgb = _as_three_channels(image)
            lab = color.rgb2lab(rgb)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            for channel in range(3):
                tmpImg[:, :, channel] = _standardize(
                    _rescale_to_unit_range(rgb[:, :, channel]))
                tmpImg[:, :, channel + 3] = _standardize(
                    _rescale_to_unit_range(lab[:, :, channel]))

        elif self.flag == 1:  # with Lab color
            tmpImg = color.rgb2lab(_as_three_channels(image))
            for channel in range(3):
                tmpImg[:, :, channel] = _standardize(
                    _rescale_to_unit_range(tmpImg[:, :, channel]))

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image_peak = np.max(image)
            if image_peak != 0:
                # An all-zero image is left as is instead of computing 0 / 0.
                image = image / image_peak
            if image.shape[2] == 1:
                # Gray image: every output channel uses the first channel's
                # mean/std constants.
                gray = (image[:, :, 0] - _RGB_MEAN[0]) / _RGB_STD[0]
                for channel in range(3):
                    tmpImg[:, :, channel] = gray
            else:
                for channel in range(3):
                    tmpImg[:, :, channel] = (
                        (image[:, :, channel] - _RGB_MEAN[channel])
                        / _RGB_STD[channel])

        # H x W x C -> C x H x W
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx),
                'image': torch.from_numpy(tmpImg),
                'label': torch.from_numpy(tmpLbl)}