# Provenance (copied from the hosting page): author hjc-owo, commit 966ae59 "init repo", file size 7.11 kB, virus scan: clean.
import numpy as np
import torch
from skimage import io, transform, color
from torch.utils.data import Dataset
class SalObjDataset(Dataset):
    """Dataset over in-memory images with optional label masks read from disk.

    For every index a ``{'imidx', 'image', 'label'}`` sample dict is built,
    passed through ``transform`` (when one is set), and only the resulting
    ``'image'`` entry is returned to the caller.

    Args:
        imgs_list: indexable collection of images; each item is converted
            with ``np.array`` on access.
        lbl_name_list: label file paths read with ``io.imread``. When it is
            empty, an all-zero label with the image's shape is used instead.
        transform: optional callable applied to the sample dict.
    """

    def __init__(self, imgs_list, lbl_name_list, transform=None):
        self.imgs_list = imgs_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, idx):
        img = np.array(self.imgs_list[idx])
        index_arr = np.array([idx])

        # Load the mask from disk, or fall back to zeros when no labels exist.
        if len(self.label_name_list) > 0:
            raw_mask = io.imread(self.label_name_list[idx])
        else:
            raw_mask = np.zeros(img.shape)

        # Reduce the mask to 2-D: keep channel 0 of a 3-D mask, keep a 2-D
        # mask unchanged, and use zeros for any other rank.
        mask_rank = len(raw_mask.shape)
        if mask_rank == 3:
            mask = raw_mask[:, :, 0]
        elif mask_rank == 2:
            mask = raw_mask
        else:
            mask = np.zeros(raw_mask.shape[0:2])

        # Give image and mask a trailing channel axis so both are (H, W, C).
        if len(mask.shape) == 2:
            if len(img.shape) == 3:
                mask = mask[:, :, np.newaxis]
            elif len(img.shape) == 2:
                img = img[:, :, np.newaxis]
                mask = mask[:, :, np.newaxis]

        sample = {'imidx': index_arr, 'image': img, 'label': mask}
        if self.transform:
            sample = self.transform(sample)
        return sample['image']
class RescaleT(object):
    """Resize the ``image`` and ``label`` of a sample dict to a fixed size.

    Args:
        output_size: an ``int`` N resizes both arrays to a square N x N output
            (the aspect ratio is NOT preserved); a ``(height, width)`` tuple
            resizes to exactly that shape.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_size(self):
        """Return the ``(height, width)`` that samples are resized to."""
        if isinstance(self.output_size, int):
            return self.output_size, self.output_size
        new_h, new_w = self.output_size
        return int(new_h), int(new_w)

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        size = self._target_size()
        # With skimage's default preserve_range=False, integer images are
        # converted to floats (e.g. uint8 [0, 255] -> [0, 1]).
        img = transform.resize(image, size, mode='constant')
        # Nearest-neighbour (order=0) with preserve_range keeps the label's
        # original values instead of interpolating/rescaling them.
        lbl = transform.resize(label, size, mode='constant', order=0,
                               preserve_range=True)
        return {'imidx': imidx, 'image': img, 'label': lbl}
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    ``sample`` is a dict with ``'imidx'``, ``'image'`` and ``'label'`` arrays;
    ``image`` and ``label`` must be channel-last (H, W, C). The returned dict
    has the same keys with ``torch`` tensor values, and image and label are
    transposed to (C, H, W).

    Image normalization depends on ``flag``:

    * ``2`` -- six channels: RGB followed by CIE-Lab (``color.rgb2lab``).
      Each channel is min-max scaled to [0, 1], then standardized to zero
      mean and unit standard deviation.
    * ``1`` -- three Lab channels, min-max scaled and then standardized.
    * any other value -- three RGB channels. The image is divided by its
      maximum, then normalized with the per-channel mean (0.485, 0.456, 0.406)
      and std (0.229, 0.224, 0.225). A single-channel image uses the first
      mean/std pair for all three output channels.

    Single-channel images are replicated to three channels in every mode.
    The label is divided by its maximum unless that maximum is below 1e-6.
    An all-zero image and constant channels (zero range or zero std) are left
    unscaled rather than turned into NaN by a division by zero.
    """

    _RGB_MEAN = np.array([0.485, 0.456, 0.406])
    _RGB_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _replicate_gray(image):
        """Return a 3-channel image, copying channel 0 if ``image`` has one channel."""
        if image.shape[2] != 1:
            return image
        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        rgb[:, :, :] = image[:, :, :1]  # broadcast channel 0 into all three
        return rgb

    @staticmethod
    def _minmax_per_channel(x):
        """Scale each channel of an (H, W, C) array to [0, 1]; constant channels map to 0."""
        lo = x.min(axis=(0, 1))
        span = x.max(axis=(0, 1)) - lo
        return (x - lo) / np.where(span == 0, 1, span)

    @staticmethod
    def _standardize_per_channel(x):
        """Give each channel zero mean and unit std; zero-std channels are only centered."""
        mean = x.mean(axis=(0, 1))
        std = x.std(axis=(0, 1))
        return (x - mean) / np.where(std == 0, 1, std)

    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Scale the label so its peak is 1; an (almost) empty mask is kept as is.
        label_peak = np.max(label)
        if label_peak >= 1e-6:
            label = label / label_peak

        if self.flag == 2:  # RGB + Lab colors
            rgb = self._replicate_gray(image)
            lab = color.rgb2lab(rgb)
            stacked = np.concatenate([rgb[:, :, :3], lab], axis=2)
            tmp_img = self._standardize_per_channel(self._minmax_per_channel(stacked))
        elif self.flag == 1:  # Lab color only
            lab = color.rgb2lab(self._replicate_gray(image))
            tmp_img = self._standardize_per_channel(self._minmax_per_channel(lab))
        else:  # RGB with fixed per-channel mean/std
            peak = np.max(image)
            if peak != 0:  # an all-black image would otherwise become NaN
                image = image / peak
            tmp_img = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmp_img[:, :, :] = (image[:, :, :1] - self._RGB_MEAN[0]) / self._RGB_STD[0]
            else:
                tmp_img[:, :, :] = (image[:, :, :3] - self._RGB_MEAN) / self._RGB_STD

        tmp_img = tmp_img.transpose((2, 0, 1))
        tmp_lbl = label.transpose((2, 0, 1))
        return {'imidx': torch.from_numpy(imidx), 'image': torch.from_numpy(tmp_img),
                'label': torch.from_numpy(tmp_lbl)}