Skip to content

Instantly share code, notes, and snippets.

@Quentin18
Created March 16, 2022 17:26
Show Gist options
  • Save Quentin18/412fc0a66d92fe3323cb093d5ffdf0bc to your computer and use it in GitHub Desktop.
Save Quentin18/412fc0a66d92fe3323cb093d5ffdf0bc to your computer and use it in GitHub Desktop.
Split dataset function for PyTorch
import math
from fractions import Fraction
from typing import Optional, Tuple

import torch
from torch.utils.data import Dataset, random_split
def train_test_split(
    dataset: Dataset,
    test_ratio: float,
    seed: Optional[int] = None,
) -> Tuple[Dataset, Dataset]:
    """Split a dataset into random, non-overlapping train and test subsets.

    The train subset receives ``floor((1 - test_ratio) * len(dataset))``
    samples and the test subset receives the remaining samples.

    Args:
        dataset (Dataset): dataset to split; must support ``len()``.
        test_ratio (float): proportion of samples assigned to the test
            subset, in the closed interval [0, 1].
        seed (int, optional): seed for a dedicated generator, making the
            split reproducible. Defaults to None, in which case PyTorch's
            global generator is used (so the split follows
            ``torch.manual_seed``).

    Returns:
        Tuple[Dataset, Dataset]: train and test subsets.

    Raises:
        ValueError: if ``test_ratio`` is not within [0, 1].
    """
    ratio = float(test_ratio)
    # Written as a positive range check so that NaN is rejected too.
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(
            f"test_ratio must be between 0 and 1, got {test_ratio!r}"
        )

    # A freshly constructed torch.Generator always starts from the same fixed
    # default seed, so using an unseeded one would make every "random" split
    # identical. Without an explicit seed, use the global generator instead.
    if seed is None:
        generator = torch.default_generator
    else:
        generator = torch.Generator().manual_seed(seed)

    # Compute sizes exactly from the ratio's decimal value: in floating point
    # (1 - 0.8) * 10 evaluates to 1.9999999999999996, which truncation would
    # turn into a train size of 1 instead of 2.
    n_samples = len(dataset)
    exact_ratio = Fraction(repr(ratio))
    train_size = math.floor((1 - exact_ratio) * n_samples)
    test_size = n_samples - train_size

    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=generator
    )
    return train_dataset, test_dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment