Skip to content

Instantly share code, notes, and snippets.

@tonyf
Last active June 3, 2024 18:14
Show Gist options
  • Save tonyf/2a652cc9fa525e79dff5711f6c353886 to your computer and use it in GitHub Desktop.
Save tonyf/2a652cc9fa525e79dff5711f6c353886 to your computer and use it in GitHub Desktop.
LightningDataModule + TorchData DataPipe & DataLoader2
from typing import Any, Dict, Optional

import lightning as L
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import Adapter
from torchdata.dataloader2.reading_service import (
    ReadingServiceInterface,
)
from torchdata.datapipes.iter import IterDataPipe


class IterDataPipeDataModule(L.LightningDataModule):
    """LightningDataModule that serves ``IterDataPipe`` datasets via ``DataLoader2``.

    Use :meth:`from_datasets` to build an instance whose ``train_dataloader`` /
    ``val_dataloader`` hooks return ``DataLoader2`` objects. ``state_dict`` and
    ``load_state_dict`` forward to those loaders so their state can be saved
    and restored together with the datamodule.
    """

    @classmethod
    def from_datasets(
        cls,
        train_dataset: Optional[IterDataPipe],
        val_dataset: Optional[IterDataPipe],
        train_reading_service: ReadingServiceInterface,
        val_reading_service: ReadingServiceInterface,
        adapters: Optional[list[Adapter]] = None,
        **datamodule_kwargs: Any,
    ) -> "IterDataPipeDataModule":
        """Create a datamodule whose loader hooks wrap the given datapipes.

        Args:
            train_dataset: Datapipe for training, or ``None`` to leave the
                ``train_dataloader`` hook untouched.
            val_dataset: Datapipe for validation, or ``None`` to leave the
                ``val_dataloader`` hook untouched.
            train_reading_service: Reading service handed to the training
                ``DataLoader2``.
            val_reading_service: Reading service handed to the validation
                ``DataLoader2``.
            adapters: Adapters passed as ``datapipe_adapter_fn`` to both
                loaders. ``None`` (the default) passes no adapters.
            **datamodule_kwargs: Forwarded to the class constructor.

        Returns:
            The configured datamodule instance.
        """

        def cached_loader(dataset: IterDataPipe, reading_service: ReadingServiceInterface):
            # Build the DataLoader2 lazily, once, and hand out the same
            # instance afterwards. state_dict()/load_state_dict() go through
            # these hooks; if every call produced a fresh loader, a saved
            # state would describe a throwaway object and a restored state
            # would be dropped before anything iterates over it.
            # NOTE(review): a loader that has been shut down is still returned
            # from the cache; confirm this is acceptable if the same
            # datamodule is reused across several trainer runs.
            loader: Optional[DataLoader2] = None

            def get_loader() -> DataLoader2:
                nonlocal loader
                if loader is None:
                    loader = DataLoader2(
                        dataset,
                        datapipe_adapter_fn=adapters,
                        reading_service=reading_service,
                    )
                return loader

            return get_loader

        datamodule = cls(**datamodule_kwargs)
        # Only override the hooks for which a dataset was actually supplied.
        if train_dataset is not None:
            datamodule.train_dataloader = cached_loader(  # type: ignore[method-assign]
                train_dataset, train_reading_service
            )
        if val_dataset is not None:
            datamodule.val_dataloader = cached_loader(  # type: ignore[method-assign]
                val_dataset, val_reading_service
            )
        return datamodule

    def state_dict(self) -> Dict[str, Any]:
        """Collect the state of the train and validation loaders.

        Returns:
            A dict with a ``"train"`` and/or ``"val"`` entry holding the
            corresponding loader's ``state_dict()``. An entry is omitted when
            the hook returns ``None``.
        """
        # NOTE(review): when a hook was not overridden, the base-class
        # implementation is called here; confirm how the installed Lightning
        # version behaves in that case (returning None vs. raising).
        state: Dict[str, Any] = {}
        if (train := self.train_dataloader()) is not None:
            state["train"] = train.state_dict()
        if (val := self.val_dataloader()) is not None:
            state["val"] = val.state_dict()
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore loader state previously produced by :meth:`state_dict`.

        Args:
            state_dict: Mapping that may contain ``"train"`` and/or ``"val"``
                entries. Missing entries are skipped.
        """
        if "train" in state_dict and (train := self.train_dataloader()) is not None:
            train.load_state_dict(state_dict["train"])
        # Use an explicit None check, matching the train branch, rather than
        # relying on the loader object's truthiness.
        if "val" in state_dict and (val := self.val_dataloader()) is not None:
            val.load_state_dict(state_dict["val"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment