# AudioSep data/datamodules.py
# (web-viewer residue removed: commit "Initial commit" ae29df4 by badayvedat, 3.41 kB)
from typing import Dict, List, Optional, NoReturn
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from data.audiotext_dataset import AudioTextDataset
class DataModule(pl.LightningDataModule):
    r"""Lightning data module wrapping a single training dataset.

    To get one batch of data:

    code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break
    """

    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int
    ) -> None:
        r"""
        Args:
            train_dataset: Dataset object yielding dicts with 'text',
                'waveform' and 'modality' keys (consumed by ``collate_fn``).
            batch_size: number of samples per mini-batch.
            num_workers: number of DataLoader worker processes.
        """
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Module-level collate_fn: groups 'audio_text' samples and stacks waveforms.
        self.collate_fn = collate_fn

    def prepare_data(self) -> None:
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        # NOTE: return type fixed from NoReturn to None — NoReturn means
        # "never returns normally", which is not the case here.
        r"""Called on every device (every process in DDP).

        Make assignments here (val/train/test split). On multiple devices,
        each SegmentSampler samples a part of the mini-batch data.
        """
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader over ``self.train_dataset``."""
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True
        )
        return train_loader

    def val_dataloader(self):
        # No validation split is defined for this module.
        # val_split = Dataset(...)
        # return DataLoader(val_split)
        pass

    def test_dataloader(self):
        # No test split is defined for this module.
        # test_split = Dataset(...)
        # return DataLoader(test_split)
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        # BUG FIX: Lightning calls this hook as teardown(stage=...); the
        # original zero-argument signature raised TypeError when invoked
        # by the trainer. Default None keeps direct calls backward-compatible.
        # clean up after fit or test; called on every process in DDP
        pass
def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Only items whose ``'modality'`` is ``'audio_text'`` are kept; their
    ``'waveform'`` tensors are stacked into a batch and every other field
    is gathered into a plain list.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]

    Returns:
        data_dict: e.g.
            'audio_text': {
                'text': ['a sound of dog', ...]
                'waveform': (batch_size, 1, samples)
            }
    """
    # FIX: the original comprehension reused the name `at_data_dict` as its
    # loop variable, shadowing the result dict — it worked only because
    # Python-3 comprehensions have their own scope. Use distinct names.
    at_list_data_dict = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    # Empty selection yields {'audio_text': {}}, same as the original.
    if at_list_data_dict:
        for key in at_list_data_dict[0].keys():
            values = [data_dict[key] for data_dict in at_list_data_dict]
            # Waveforms (1, samples) stack to (batch_size, 1, samples);
            # all other fields (e.g. 'text', 'modality') stay Python lists.
            at_data_dict[key] = torch.stack(values) if key == 'waveform' else values

    data_dict = {
        'audio_text': at_data_dict
    }

    return data_dict