Skip to content

Instantly share code, notes, and snippets.

@mjhong0708
Created June 14, 2021 05:42
Show Gist options
  • Save mjhong0708/614fc5c4996f780d62e518a53996caa9 to your computer and use it in GitHub Desktop.
PINN torch code snippets
class CollateTensors:
    """
    Pads and collates multidimensional tensors of variable length in a batch,
    padding each sample up to the batch maximum along a chosen dimension.
    Adapted from https://discuss.pytorch.org/t/dataloader-for-various-length-of-data/6418/8
    """

    def __init__(self, dim=0):
        """
        :param dim: the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim

    @staticmethod
    def pad_tensor(vec, pad, dim):
        """
        Pads a tensor with zeros along ``dim`` up to length ``pad``.
        Adapted from https://discuss.pytorch.org/t/dataloader-for-various-length-of-data/6418/8

        :param vec: tensor to pad
        :param pad: the size to pad to (must be >= vec.size(dim))
        :param dim: dimension to pad
        :return: padded tensor, same dtype and device as ``vec``
        """
        pad_size = list(vec.shape)
        pad_size[dim] = pad - vec.size(dim)
        # new_zeros keeps dtype and device of `vec`; plain torch.zeros(...)
        # would build a float32 CPU tensor, and torch.cat then fails for
        # integer tensors and mixes devices for CUDA inputs.
        return torch.cat([vec, vec.new_zeros(pad_size)], dim=dim)

    def pad_and_collate(self, batch):
        """
        Collates a batch of data.

        :param batch: list of tuples (tensor, label)
        :return: tuple of padded batch data and batch labels
        """
        # Find the longest sequence along the configured padding dimension.
        # (The original read shape[0] unconditionally, which is wrong for
        # any self.dim != 0.)
        max_len = max(t[0].size(self.dim) for t in batch)
        # Pad each sample up to max_len along self.dim.
        padded_tensors = [self.pad_tensor(t[0], max_len, self.dim) for t in batch]
        # Stack into one batch tensor; the batch axis is prepended (dim=0).
        xs = torch.stack(padded_tensors, dim=0)
        ys = torch.tensor([t[1] for t in batch])
        return xs, ys

    def __call__(self, batch):
        return self.pad_and_collate(batch)
@mjhong0708
Copy link
Author

Usage:

from torch.utils.data import DataLoader

dataset = ... # Dataset object
dataloader = DataLoader(dataset, batch_size=16, collate_fn=CollateTensors(dim=0))

for batch_data, batch_labels in dataloader:
    ... # do some training

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment