MultiMAE / utils /taskonomy /taskonomy_dataset.py
Bachmann Roman Christian
Initial commit
3b49518
raw
history blame
No virus
2.37 kB
import os
import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
from .transforms import task_transform
class TaskonomyDataset(Dataset):
def __init__(self,
data_root,
tasks,
split='train',
variant='tiny',
image_size=256,
max_images=None):
"""
Taskonomy dataloader.
Args:
data_root: Root of Taskonomy data directory
tasks: List of tasks. Any of ['rgb', 'depth_euclidean', 'depth_zbuffer',
'edge_occlusion', 'edge_texture', 'keypoints2d', 'keypoints3d', 'normal',
'principal_curvature', 'reshading', 'mask_valid'].
split: One of {'train', 'val', 'test'}
variant: One of {'debug', 'tiny', 'medium', 'full', 'fullplus'}
image_size: Target image size
max_images: Optional subset selection
"""
super(TaskonomyDataset, self).__init__()
self.data_root = data_root
self.tasks = tasks
self.split = split
self.variant = variant
self.image_size=image_size
self.max_images = max_images
self.image_ids = pd.read_csv(
os.path.join(os.path.dirname(__file__), 'splits', f'{self.variant}_{self.split}.csv')
).to_numpy()
if isinstance(self.max_images, int):
self.image_ids = self.image_ids[:self.max_images]
print(f'Initialized TaskonomyDataset with {len(self.image_ids)} images from variant {self.variant} in split {self.split}.')
def __len__(self):
return len(self.image_ids)
def __getitem__(self, index):
# building / point / view
building, point, view = self.image_ids[index]
result = {}
for task in self.tasks:
task_id = 'depth_zbuffer' if task == 'mask_valid' else task
path = os.path.join(
self.data_root, task, building, f'point_{point}_view_{view}_domain_{task_id}.png'
)
img = Image.open(path)
# Perform transformations
img = task_transform(img, task=task, image_size=self.image_size)
result[task] = img
return result