PyTorch Utilities

"""
Utility functions for PyTorch data manipulation and device handling.
This module provides various utility functions for handling PyTorch datasets, tensors, and devices. Functions include:
- Splitting datasets into train and test sets
- Concatenating multiple datasets
- Generating random batches from a dataset
- Converting data to a specific device
- Retrieving the default device
- Getting GPU device names
- Getting the name of a specific device
"""
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2023, maharjun
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
from typing import List

import torch
import torch.utils.data


def train_test_data_split(dataset: torch.utils.data.Dataset, train_fraction: float, generator: torch.Generator):
    """
    Splits the given dataset into train and test sets based on the specified train_fraction.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Dataset to be split.
    train_fraction: float
        Fraction of the dataset to be used for training. Should be between 0 and 1.
    generator: torch.Generator
        Random number generator.

    Returns
    -------
    tuple
        A tuple containing two datasets, the train and test sets.
    """
    n_all = len(dataset)
    shuffle_inds = torch.randperm(n_all, generator=generator)
    n_train = int(n_all * train_fraction)
    return (dataset.__class__(*dataset[shuffle_inds[:n_train]]),
            dataset.__class__(*dataset[shuffle_inds[n_train:]]))
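
# Example (an illustrative sketch, not part of the original gist): an 80/20
# split of a TensorDataset with a seeded generator. This assumes, as the
# function does, that the dataset supports tensor indexing and that its
# constructor accepts the resulting tuple of tensors (TensorDataset does both).
#
#   gen = torch.Generator().manual_seed(0)
#   full = torch.utils.data.TensorDataset(torch.randn(100, 3), torch.randn(100))
#   train_set, test_set = train_test_data_split(full, train_fraction=0.8, generator=gen)
#   assert len(train_set) == 80 and len(test_set) == 20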


def concatenate_data(datasets: List[torch.utils.data.Dataset]):
    """
    Concatenates the given list of datasets.

    Parameters
    ----------
    datasets: List[torch.utils.data.Dataset]
        List of datasets to concatenate.

    Returns
    -------
    torch.utils.data.Dataset
        The concatenated dataset.
    """
    assert len(datasets) > 0, "At least one dataset should be given to concatenate"
    all_data_tuples = [dset[:] for dset in datasets]
    data_tuple_len = len(all_data_tuples[0])
    cat_data_tuple = tuple(torch.cat([x[i] for x in all_data_tuples])
                           for i in range(data_tuple_len))
    return datasets[0].__class__(*cat_data_tuple)
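
# Example (illustrative, assuming TensorDataset-style datasets whose
# __getitem__ accepts a slice and returns a tuple of tensors):
#
#   a = torch.utils.data.TensorDataset(torch.zeros(10, 2))
#   b = torch.utils.data.TensorDataset(torch.ones(5, 2))
#   combined = concatenate_data([a, b])   # TensorDataset with 15 rows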


def random_batch_input(dataset: torch.utils.data.Dataset, batch_size: int, generator: torch.Generator):
    """
    Generator that yields batches of data from the given dataset with random shuffling.

    Unlike torch's built-in data loading, this generator reshuffles on every
    epoch while guaranteeing that a batch spanning an epoch boundary never
    contains the same sample twice, all while maintaining a constant batch size.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Dataset to generate batches from.
    batch_size: int
        Number of samples per batch.
    generator: torch.Generator
        Random number generator.

    Yields
    ------
    tuple
        A tuple containing a data batch and a boolean flag indicating whether the epoch has ended.
    """
    num_data = len(dataset)
    assert num_data > 0, "The dataset must contain at least one sample"
    assert batch_size <= num_data, "The batch size must be less than or equal to the size of the dataset"

    current_cursor = 0
    shuffle_inds = torch.randperm(num_data, device=generator.device, generator=generator)
    epoch_ended = False
    while True:
        end_cursor = current_cursor + batch_size
        end_cursor_first = min(end_cursor, num_data)
        batch_indices = shuffle_inds[current_cursor:end_cursor_first]
        if end_cursor >= num_data:
            # The batch wraps around the epoch boundary. Reshuffle so that the
            # remainder of the batch is drawn only from samples already served
            # earlier this epoch, ensuring the same elements are not taken again.
            perm1 = torch.randperm(current_cursor, device=generator.device, generator=generator)
            new_shuffle_inds = shuffle_inds.detach().clone()
            new_shuffle_inds[:current_cursor] = new_shuffle_inds[:current_cursor][perm1]
            perm2 = torch.randperm(2*num_data - end_cursor, device=generator.device, generator=generator)
            new_shuffle_inds[end_cursor-num_data:] = new_shuffle_inds[end_cursor-num_data:][perm2]
            batch_indices_second = new_shuffle_inds[:end_cursor-num_data]
            batch_indices = torch.cat([batch_indices, batch_indices_second], dim=0)
            shuffle_inds = new_shuffle_inds
            epoch_ended = True
        data_batch = dataset[batch_indices]
        yield data_batch, epoch_ended
        epoch_ended = False
        current_cursor = end_cursor % num_data
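
# Example (illustrative sketch): drawing batches until one full epoch has been
# served. The generator is infinite, so the epoch_ended flag is the stopping
# criterion here.
#
#   gen = torch.Generator().manual_seed(0)
#   data = torch.utils.data.TensorDataset(torch.arange(10))
#   for (batch,), epoch_ended in random_batch_input(data, batch_size=4, generator=gen):
#       print(batch)
#       if epoch_ended:
#           break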


def convert_data_to_device(data, device: torch.device):
    """
    Converts the data (dict or tensor) to the specified device.

    Parameters
    ----------
    data: dict or torch.Tensor
        Data to be converted; can be a (possibly nested) dict or a tensor.
    device: torch.device
        Target device.

    Returns
    -------
    Data converted to the target device.
    """
    if isinstance(data, dict):
        # Recurse into dictionaries so that nested structures are handled too
        return {key: convert_data_to_device(val, device) for key, val in data.items()}
    elif torch.is_tensor(data):
        return data.to(device=device)
    else:
        raise TypeError("Require either tensor or dict to convert")
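
# Example (illustrative): moving a nested dict of tensors to a device in one
# call. The 'cuda:0' target is an assumption and requires a GPU.
#
#   batch = {'x': torch.randn(4, 3), 'meta': {'mask': torch.ones(4)}}
#   batch_gpu = convert_data_to_device(batch, torch.device('cuda:0'))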


def get_default_device():
    """
    Returns the default device available for PyTorch.

    Returns
    -------
    torch.device
        Default device for PyTorch.
    """
    # Creating a small tensor and reading off its device respects whatever
    # default device torch is currently configured with.
    return torch.as_tensor([0., 1.0]).device
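
# Example (illustrative): on a plain setup this returns the CPU device; after
# torch.set_default_device('cuda') (available in torch >= 2.0, requires a GPU)
# it returns the CUDA device instead.
#
#   torch.set_default_device('cuda')
#   assert get_default_device().type == 'cuda'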


def get_gpu_name_if_available(gpu_index=None):
    """
    Returns the GPU device name if available, otherwise returns 'cpu'.

    Parameters
    ----------
    gpu_index: int, optional
        Index of the GPU device to use, if available.

    Returns
    -------
    str
        GPU device name or 'cpu' if GPU is not available.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda and gpu_index is not None:
        device_name = 'cuda:{}'.format(gpu_index)
    elif use_cuda:
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    return device_name
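
# Example (illustrative): on a machine with CUDA available,
# get_gpu_name_if_available(1) returns 'cuda:1'; without CUDA it returns 'cpu'.
# The name can be passed straight to torch.device:
#
#   device = torch.device(get_gpu_name_if_available())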


def get_device_name(device: torch.device):
    """
    Returns the name of the specified device object.

    Parameters
    ----------
    device: torch.device
        The device object.

    Returns
    -------
    str
        The name of the device.
    """
    assert isinstance(device, torch.device), "device must be a torch.device"
    # Check `is not None` explicitly: a device index of 0 is falsy but valid.
    if device.index is not None:
        return f'{device.type}:{device.index}'
    else:
        return device.type
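
# Example (illustrative): round-tripping a device through its name. Note that
# an explicit index of 0 is preserved, which a plain truthiness check on
# device.index would silently drop.
#
#   assert get_device_name(torch.device('cuda', 0)) == 'cuda:0'
#   assert get_device_name(torch.device('cpu')) == 'cpu'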