# (Page-scrape residue, not part of the module: "Spaces: Running")
import numpy as np
import torch
from skimage import io, transform, color
from torch.utils.data import Dataset
class SalObjDataset(Dataset):
    """Dataset of in-memory images with optional on-disk label masks.

    Every item is assembled into a ``{'imidx', 'image', 'label'}`` sample,
    passed through ``transform`` when one is set, and only the sample's
    ``'image'`` entry is returned to the caller.

    Args:
        imgs_list: Indexable collection of images; each entry is converted
            with ``np.array`` when it is accessed.
        lbl_name_list: Label locations readable by ``io.imread``, indexed in
            parallel with ``imgs_list``. When empty, an all-zero label
            derived from the image shape is used instead.
        transform: Optional callable that maps a sample dict to a sample dict.
    """

    def __init__(self, imgs_list, lbl_name_list, transform=None):
        self.imgs_list = imgs_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.imgs_list)

    def __getitem__(self, idx):
        """Build the sample for ``idx`` and return its ``'image'`` entry.

        Args:
            idx: Position in ``imgs_list`` (and in the label list, if any).

        Returns:
            ``sample['image']`` after the optional transform has run. Without
            a transform this is the image array, with a trailing channel axis
            added when the image was 2-D.
        """
        image = np.array(self.imgs_list[idx])
        imidx = np.array([idx])

        # No label files supplied: stand in an all-zero label shaped like
        # the image so downstream transforms still receive a 'label' entry.
        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Collapse the label to a single plane: first channel of a 3-D
        # label, the label itself when 2-D, zeros for any other rank.
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3
        else:
            label = np.zeros(label_3.shape[0:2])

        # Give the label (and a 2-D image) an explicit trailing channel
        # axis so both arrays are H x W x C.
        if label.ndim == 2:
            if image.ndim == 3:
                label = label[:, :, np.newaxis]
            elif image.ndim == 2:
                image = image[:, :, np.newaxis]
                label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample['image']
class RescaleT(object):
    """Resize a sample's image and label to a fixed output shape.

    The aspect ratio is not preserved: both arrays are resized straight to
    the target shape.

    Args:
        output_size: An ``int`` for a square ``output_size x output_size``
            result, or a ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_shape(self):
        """Return the ``(rows, cols)`` shape both arrays are resized to."""
        if isinstance(self.output_size, int):
            return (self.output_size, self.output_size)
        # A tuple is taken as (height, width). Previously the tuple was
        # passed to resize as (tuple, tuple), which is not a valid shape.
        rows, cols = self.output_size
        return (int(rows), int(cols))

    def __call__(self, sample):
        """Resize ``sample['image']`` and ``sample['label']``.

        Args:
            sample: Dict with ``'imidx'``, ``'image'`` and ``'label'`` entries.

        Returns:
            A new dict with ``'imidx'`` passed through and the resized
            ``'image'`` and ``'label'`` arrays.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        out_shape = self._target_shape()

        # The image uses resize's default interpolation and value handling.
        img = transform.resize(image, out_shape, mode='constant')
        # The label is resized with order=0 and preserve_range=True so its
        # original values are carried over rather than blended or rescaled.
        lbl = transform.resize(label, out_shape, mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    The image is normalised according to ``flag`` and both image and label
    are moved from H x W x C to C x H x W before being wrapped as tensors.

    Args:
        flag: Colour handling for the image.
            ``2`` - six channels: RGB followed by Lab, each min-max scaled
            and then standardised per channel.
            ``1`` - three Lab channels, min-max scaled and then standardised
            per channel.
            anything else - three RGB channels, divided by the image maximum
            and normalised with fixed per-channel mean/std constants.
    """

    # Fixed per-channel constants used by the RGB path (presumably the
    # ImageNet mean/std - confirm against the model's training setup).
    _RGB_MEAN = (0.485, 0.456, 0.406)
    _RGB_STD = (0.229, 0.224, 0.225)

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _peak_normalize(arr):
        """Divide ``arr`` by its maximum unless that maximum is below 1e-6."""
        peak = np.max(arr)
        if peak < 1e-6:
            # An (almost) all-zero array would divide by ~0 and produce
            # NaN/inf, so it is returned untouched.
            return arr
        return arr / peak

    @staticmethod
    def _min_max_scale(channel):
        """Scale one channel to [0, 1]; a constant channel becomes zeros."""
        lowest = np.min(channel)
        span = np.max(channel) - lowest
        if span == 0:
            # Every value equals the minimum, so 0 is the scaled value;
            # dividing would give 0/0 = NaN.
            return np.zeros(channel.shape)
        return (channel - lowest) / span

    @staticmethod
    def _standardize(channel):
        """Shift one channel to zero mean and unit std; constant -> zeros."""
        spread = np.std(channel)
        if spread == 0:
            return np.zeros(channel.shape)
        return (channel - np.mean(channel)) / spread

    @staticmethod
    def _as_three_channels(image):
        """Replicate a single-channel image to three channels.

        Images whose last axis is not of length 1 are returned unchanged.
        """
        if image.shape[2] != 1:
            return image
        stacked = np.zeros((image.shape[0], image.shape[1], 3))
        for c in range(3):
            stacked[:, :, c] = image[:, :, 0]
        return stacked

    def __call__(self, sample):
        """Normalise the sample and convert it to tensors.

        Args:
            sample: Dict with ``'imidx'``, ``'image'`` (H x W x C) and
                ``'label'`` (H x W x C) ndarrays.

        Returns:
            Dict with the same keys holding ``torch`` tensors; image and
            label are laid out as C x H x W.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        label = self._peak_normalize(label)

        if self.flag == 2:  # with rgb and Lab colors
            rgb = self._as_three_channels(image)
            lab = color.rgb2lab(rgb)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            # Channels 0-2 hold scaled RGB, channels 3-5 hold scaled Lab.
            for c in range(3):
                tmpImg[:, :, c] = self._min_max_scale(rgb[:, :, c])
                tmpImg[:, :, c + 3] = self._min_max_scale(lab[:, :, c])
            for c in range(6):
                tmpImg[:, :, c] = self._standardize(tmpImg[:, :, c])
        elif self.flag == 1:  # with Lab color
            tmpImg = color.rgb2lab(self._as_three_channels(image))
            for c in range(3):
                tmpImg[:, :, c] = self._min_max_scale(tmpImg[:, :, c])
            for c in range(3):
                tmpImg[:, :, c] = self._standardize(tmpImg[:, :, c])
        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            # Same guard as the label: an all-black image must not turn
            # into NaN through a division by zero.
            image = self._peak_normalize(image)
            if image.shape[2] == 1:
                # Grayscale: the first channel's constants are reused for
                # all three output channels.
                gray = (image[:, :, 0] - self._RGB_MEAN[0]) / self._RGB_STD[0]
                for c in range(3):
                    tmpImg[:, :, c] = gray
            else:
                for c in range(3):
                    tmpImg[:, :, c] = (
                        (image[:, :, c] - self._RGB_MEAN[c]) / self._RGB_STD[c]
                    )

        # H x W x C -> C x H x W
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            'imidx': torch.from_numpy(imidx),
            'image': torch.from_numpy(tmpImg),
            'label': torch.from_numpy(tmpLbl),
        }