import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank


class BlenderDatasetBase():
    def setup(self, config, split):
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True
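        # Blender-synthetic renders come with an alpha channel, so a foreground mask is
        # always available (has_mask) and, by default, meant to be applied by the consumer.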

        with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            meta = json.load(f)

        if 'w' in meta and 'h' in meta:
            W, H = int(meta['w']), int(meta['h'])
        else:
            W, H = 800, 800

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif 'img_downscale' in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x'])
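        # focal follows the pinhole relation f = 0.5 * W / tan(0.5 * fov_x): camera_angle_x
        # is the horizontal field of view (in radians) stored by Blender, and the resized
        # width w is used so the intrinsics match the downscaled images.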

        self.directions = \
            get_ray_directions(self.w, self.h, self.focal, self.focal, self.w//2, self.h//2).to(self.rank)
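        # self.directions holds one camera-frame ray direction per pixel of the resized
        # image; it is shared across frames, since every frame uses the same intrinsics.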

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta['frames']):
            c2w = torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
            img = Image.open(img_path)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)
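            # TF.to_tensor gives a (C, H, W) float tensor in [0, 1]; permuting to (H, W, C)
            # lets the RGB and alpha channels be sliced along the last dimension below.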

            self.all_fg_masks.append(img[..., -1])
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = \
            torch.stack(self.all_c2w, dim=0).float().to(self.rank), \
            torch.stack(self.all_images, dim=0).float().to(self.rank), \
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank)


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
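        # Only the frame index is returned: the per-frame tensors (all_c2w, all_images,
        # all_fg_masks) are already preloaded on the target device, so the consuming system
        # is expected to gather the actual rays/pixels from them by index.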
        return {
            'index': index
        }


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
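        # An endless stream of empty dicts: the iterable dataset never exhausts an epoch,
        # and each training step is assumed to sample its own random ray batch from the
        # preloaded tensors rather than from the yielded item.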
        while True:
            yield {}


@datasets.register('blender')
class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, 'fit']:
            self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split)
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, 'test']:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, 'predict']:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)
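            # note: prediction reuses the training split, e.g. to re-render the training views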

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
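        # No explicit sampler is set; under distributed training, PyTorch Lightning is
        # expected to inject its own DistributedSampler for the map-style datasets.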
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
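

# ------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes an
# attribute-style config that supports both `cfg.key` access and `'key' in cfg` membership
# tests, e.g. an OmegaConf DictConfig; only fields actually read above are listed, and the
# concrete values (dataset path, near/far planes) are placeholders.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create({
#       'root_dir': '/path/to/nerf_synthetic/lego',
#       'img_wh': [800, 800],
#       'near_plane': 2.0,
#       'far_plane': 6.0,
#       'train_split': 'train',
#       'val_split': 'val',
#       'test_split': 'test',
#   })
#   dm = BlenderDataModule(config)
#   dm.setup('fit')
#   train_batch = next(iter(dm.train_dataloader()))   # placeholder dict from the iterable dataset
#   images = dm.val_dataset.all_images                # (N, h, w, 3) tensor of validation frames
# ------------------------------------------------------------------------------------------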