Skip to content

Instantly share code, notes, and snippets.

@zxteloiv
Created January 12, 2021 10:28
Show Gist options
  • Save zxteloiv/edd26db881ced300f3e0032d663ef201 to your computer and use it in GitHub Desktop.
Save zxteloiv/edd26db881ced300f3e0032d663ef201 to your computer and use it in GitHub Desktop.
Convert nested number lists into PyTorch tensors, which is useful for batching a bunch of tensors during text data processing.
from collections import defaultdict
import torch
def _nested_number_list_to_tensors(nested: list, padding=0, example=None):
    """Pad an arbitrarily nested, ragged list of numbers and build a tensor.

    Every nesting level is padded up to the longest list found at that level,
    so the result is rectangular and can be handed to ``torch.tensor``.

    Args:
        nested: A (possibly ragged) list of lists ... of numbers. Only ``list``
            instances count as nesting levels; anything else is a leaf value.
        padding: Value written into every padded position.
        example: Optional object whose ``device`` and ``dtype`` attributes are
            forwarded to ``torch.tensor``. When ``None``, both are left as
            ``None`` for ``torch.tensor`` to decide.

    Returns:
        A tensor whose size along each dimension is the maximum list length
        observed at the corresponding nesting depth.
    """
    # Level-order sweep: level_sizes[d] is the longest list seen at depth d.
    level_sizes = []
    frontier = [nested] if isinstance(nested, list) else []
    while frontier:
        level_sizes.append(max(len(item) for item in frontier))
        frontier = [
            child
            for item in frontier
            for child in item
            if isinstance(child, list)
        ]

    # fillers[d] is a fully padded block that stands in for a missing element
    # at depth d; the innermost filler is the bare padding value.
    fillers = [padding]
    for width in reversed(level_sizes):
        fillers.insert(0, [fillers[0]] * width)

    def _fill(node, depth):
        # Leaves are kept as they are; only lists get extended.
        if not isinstance(node, list):
            return node
        missing = level_sizes[depth] - len(node)
        extended = node + [fillers[depth + 1]] * missing
        # Rebuilding each level yields fresh lists, so the shared filler
        # objects never leak into the returned structure.
        return [_fill(child, depth + 1) for child in extended]

    rectangular = _fill(nested, 0)

    if example is None:
        target_device = target_dtype = None
    else:
        target_device = example.device
        target_dtype = example.dtype
    return torch.tensor(rectangular, device=target_device, dtype=target_dtype)
if __name__ == '__main__':
    # Demo inputs: ragged lists of various depths, including empty sub-lists.
    demo_inputs = [
        [[[], []], [[3, 3], [1, 2, 3]], [[7], [1, 3], [9, 9, 9, 9]]],
        [0, 1, 2],
        [[], [1], [2]],
        [[], [1], [2, 3]],
        [[1], [2, 3], [3, 3, 3]],
    ]
    for sample in demo_inputs:
        padded = _nested_number_list_to_tensors(sample)
        # Show the padded tensor together with its resulting shape.
        print(padded, padded.size())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment