@ashleve
Example of k-fold cross validation with PyTorch Lightning Datamodule
from typing import Optional

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from torch.utils.data import Dataset
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset


class ProteinsKFoldDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data/",
        k: int = 1,  # fold number
        split_seed: int = 12345,  # split needs to be always the same for correct cross validation
        num_splits: int = 10,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows accessing init params with the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # num_splits = 10 means our dataset will be split into 10 parts,
        # so we train on 90% of the data and validate on 10%
        assert 1 <= self.hparams.k <= self.hparams.num_splits, "incorrect fold number"

        # data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    @property
    def num_node_features(self) -> int:
        return 4

    @property
    def num_classes(self) -> int:
        return 2

    def setup(self, stage=None):
        if not self.data_train and not self.data_val:
            dataset_full = TUDataset(
                self.hparams.data_dir, name="PROTEINS", use_node_attr=True, transform=self.transforms
            )

            # choose the fold to train on
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = list(kf.split(dataset_full))
            train_indexes, val_indexes = all_splits[self.hparams.k]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()

            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )
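To illustrate what setup() gets back from scikit-learn (a standalone sketch, not part of the gist): KFold.split yields one (train_indexes, val_indexes) pair of index arrays per fold, so all_splits[k] selects the k-th 90/10 partition of the dataset.

from sklearn.model_selection import KFold

kf = KFold(n_splits=10, shuffle=True, random_state=12345)
all_splits = list(kf.split(list(range(100))))  # 100 placeholder samples standing in for dataset_full
train_indexes, val_indexes = all_splits[0]     # numpy index arrays for the first fold
print(len(train_indexes), len(val_indexes))    # -> 90 10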
ashleve commented Apr 21, 2021

How to use:

results = []
num_folds = 10
split_seed = 12345

for k in range(num_folds):
    datamodule = ProteinsKFoldDataModule(k=k, num_splits=num_folds, split_seed=split_seed, ...)
    datamodule.prepare_data()
    datamodule.setup()

    # here we train the model on the given split (see the sketch below)...
    ...

    results.append(score)

score = sum(results) / num_folds
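A minimal sketch of one way to fill in the elided training step, assuming a hypothetical LightningModule called ProteinsModel and a logged metric named "val/acc" (neither is part of this gist):

from pytorch_lightning import Trainer

model = ProteinsModel()  # hypothetical LightningModule, re-initialized from scratch for each fold
trainer = Trainer(max_epochs=50)
trainer.fit(model, datamodule=datamodule)
# validate() returns a list with one metrics dict per dataloader
score = trainer.validate(model, datamodule=datamodule)[0]["val/acc"]  # assumed metric key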

DavidGomez00 commented Dec 13, 2023

Hi, in ProteinsKFoldDataModule.__init__() you define

assert 1 <= self.hparams.k <= self.hparams.num_splits, "incorrect fold number"

but in the usage example you iterate with for k in range(num_folds): ..., where k starts at 0. This causes an error when instantiating the datamodule. Consider changing the lower bound of the assert to 0, since requiring k >= 1 would also skip the first split when the indexes are fetched at train_indexes, val_indexes = all_splits[self.hparams.k].
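In code, the suggested fix (zero-indexed folds, matching range(num_folds) in the usage example) would be:

# zero-indexed fold number: valid values are 0 .. num_splits - 1
assert 0 <= self.hparams.k < self.hparams.num_splits, "incorrect fold number"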
