Created
June 14, 2021 05:42
-
-
Save mjhong0708/614fc5c4996f780d62e518a53996caa9 to your computer and use it in GitHub Desktop.
PINN torch code snippets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CollateTensors:
    """Collate function that zero-pads variable-length tensors in a batch.

    Pads every sample tensor along one dimension up to the longest sample
    in the batch, then stacks them. Intended for use as the ``collate_fn``
    of a ``torch.utils.data.DataLoader``.

    Adapted from https://discuss.pytorch.org/t/dataloader-for-various-length-of-data/6418/8
    """

    def __init__(self, dim=0):
        """
        :param dim: the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim

    @staticmethod
    def pad_tensor(vec, pad, dim):
        """Pad ``vec`` with zeros along ``dim`` so that ``vec.size(dim) == pad``.

        :param vec: tensor to pad
        :param pad: the size to pad to (must be >= ``vec.size(dim)``)
        :param dim: dimension to pad
        :return: padded tensor with the same dtype and device as ``vec``
        :raises ValueError: if ``vec`` is already longer than ``pad``
        """
        current = vec.size(dim)
        if current > pad:
            # Original code silently hit a negative-size tensor error here;
            # fail fast with a readable message instead.
            raise ValueError(f"Cannot pad tensor of size {current} to smaller size {pad} along dim {dim}")
        pad_size = list(vec.shape)
        pad_size[dim] = pad - current
        # new_zeros inherits dtype and device from vec; a plain torch.zeros
        # would create float32 CPU zeros and make torch.cat fail for e.g.
        # integer token-id batches or GPU tensors.
        return torch.cat([vec, vec.new_zeros(pad_size)], dim=dim)

    def pad_and_collate(self, batch):
        """Collate a batch of ``(tensor, label)`` pairs.

        :param batch: list of tuples (tensor, label)
        :return: tuple of padded, stacked batch data and batch labels
        """
        # find longest sequence
        max_len = max([t[0].shape[0] for t in batch])
        # pad according to max_len
        padded_tensors = [self.pad_tensor(t[0], max_len, self.dim) for t in batch]
        # stack all
        xs = torch.stack(padded_tensors, dim=0)
        ys = torch.tensor([t[1] for t in batch])
        return xs, ys

    def __call__(self, batch):
        return self.pad_and_collate(batch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Usage: