Skip to content

Instantly share code, notes, and snippets.

@yukw777
Created August 27, 2020 15:35
Show Gist options
  • Save yukw777/085b799ac11570e589a1ef27b463b366 to your computer and use it in GitHub Desktop.
Save yukw777/085b799ac11570e589a1ef27b463b366 to your computer and use it in GitHub Desktop.
PyTorch Lightning Data Module Example
class DataModule(pl.LightningDataModule):
    """LightningDataModule serving ``*.gz`` files from three directories.

    Each directory is resolved with ``to_absolute_path`` and globbed for
    ``*.gz`` files at construction time. The file lists are wrapped in
    ``Dataset`` objects in :meth:`setup` and served through ``DataLoader``
    instances whose keyword arguments come from the per-split configs.

    Args:
        train_data_dir: Directory holding the training ``*.gz`` files.
        val_data_dir: Directory holding the validation ``*.gz`` files.
        test_data_dir: Directory holding the test ``*.gz`` files.
        train_dataloader_conf: Mapping unpacked as keyword arguments into the
            training ``DataLoader``. An empty ``OmegaConf`` config is used
            when omitted.
        val_dataloader_conf: Same as above, for the validation ``DataLoader``.
        test_dataloader_conf: Same as above, for the test ``DataLoader``.
    """

    def __init__(
        self,
        train_data_dir: str,
        val_data_dir: str,
        test_data_dir: str,
        train_dataloader_conf: Optional[DictConfig] = None,
        val_dataloader_conf: Optional[DictConfig] = None,
        test_dataloader_conf: Optional[DictConfig] = None,
    ):
        super().__init__()
        self.train_filenames = self._find_gz_files(train_data_dir)
        self.val_filenames = self._find_gz_files(val_data_dir)
        self.test_filenames = self._find_gz_files(test_data_dir)
        self.train_dataloader_conf = train_dataloader_conf or OmegaConf.create()
        self.val_dataloader_conf = val_dataloader_conf or OmegaConf.create()
        self.test_dataloader_conf = test_dataloader_conf or OmegaConf.create()

    @staticmethod
    def _find_gz_files(data_dir: str) -> list:
        """Return the ``*.gz`` paths directly inside ``data_dir``, sorted."""
        pattern = os.path.join(to_absolute_path(data_dir), "*.gz")
        # glob.glob yields paths in an arbitrary, filesystem-dependent order.
        # Sorting keeps the file (and therefore sample) order reproducible
        # across machines and runs.
        return sorted(glob.glob(pattern))

    # no need for prepare_data: the file lists are gathered in __init__.

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        ``"fit"`` builds the train and val datasets, ``"validate"`` builds the
        val dataset, ``"test"`` builds the test dataset, and ``None`` builds
        all three.

        Args:
            stage: Name of the stage being prepared, or ``None`` for all.
        """
        if stage in ("fit", None):
            # Only the training split is constructed with transform=True.
            self.train = Dataset(self.train_filenames, transform=True)
        # "validate" is accepted so that a validation-only run (a stage some
        # Lightning versions pass to setup) still gets self.val before
        # val_dataloader is called.
        if stage in ("fit", "validate", None):
            self.val = Dataset(self.val_filenames)
        if stage in ("test", None):
            self.test = Dataset(self.test_filenames)

    def train_dataloader(self):
        """Return a DataLoader over the training dataset built in ``setup``."""
        return DataLoader(self.train, **self.train_dataloader_conf)

    def val_dataloader(self):
        """Return a DataLoader over the validation dataset built in ``setup``."""
        return DataLoader(self.val, **self.val_dataloader_conf)

    def test_dataloader(self):
        """Return a DataLoader over the test dataset built in ``setup``."""
        return DataLoader(self.test, **self.test_dataloader_conf)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment