
@Ryu1845
Forked from ptrblck/numpy split vs PyTorch split
Last active March 18, 2023 19:45
numpy split vs PyTorch split
import torch
import numpy as np
# numpy
a = np.random.rand(10, 20)
tmp0 = np.split(a, indices_or_sections=5, axis=0)  # split into 5 equal sections
for t in tmp0:
    print(t.shape)
# (2, 20)
# (2, 20)
# (2, 20)
# (2, 20)
# (2, 20)

# np.split(a, indices_or_sections=7, axis=0)  # ValueError: 10 rows do not divide into 7 equal sections

tmp1 = np.split(a, [5, 7], axis=0)  # split at indices ([:5], [5:7], [7:])
for t in tmp1:
    print(t.shape)
# (5, 20)
# (2, 20)
# (3, 20)
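A side note on the unequal-division error above: NumPy also provides `np.array_split`, which accepts a section count that does not divide the axis evenly and simply makes the leading sections one row larger. A minimal sketch:

```python
import numpy as np

# np.split raises on unequal division; np.array_split does not.
# 10 rows into 7 sections yields sizes 2, 2, 2, 1, 1, 1, 1.
a = np.random.rand(10, 20)
parts = np.array_split(a, 7, axis=0)
print([p.shape for p in parts])
# [(2, 20), (2, 20), (2, 20), (1, 20), (1, 20), (1, 20), (1, 20)]
```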
# PyTorch
x = torch.randn(10, 20)
tmp2 = torch.split(x, split_size_or_sections=4, dim=0)  # each split has size 4
for t in tmp2:
    print(t.shape)  # the last split may be smaller
# torch.Size([4, 20])
# torch.Size([4, 20])
# torch.Size([2, 20])

tmp3 = torch.split(x, split_size_or_sections=[5, 2, 3], dim=0)  # explicit section sizes
for t in tmp3:
    print(t.shape)
# torch.Size([5, 20])
# torch.Size([2, 20])
# torch.Size([3, 20])

# torch.split(x, split_size_or_sections=[5, 4], dim=0)  # RuntimeError, since 5 + 4 != x.size(0)
# Should it return tensors of size [5, 20], [4, 20], and [1, 20]?
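Not part of the original comparison, but if you want NumPy's index-based semantics in PyTorch directly, `torch.tensor_split` (added in PyTorch 1.8, so this sketch assumes a reasonably recent version) splits at indices the same way `np.split` does:

```python
import torch

# torch.tensor_split mirrors np.split's index-based semantics:
# the list [5, 7] means rows [:5], [5:7], [7:].
x = torch.randn(10, 20)
sections = torch.tensor_split(x, [5, 7], dim=0)
print([tuple(s.shape) for s in sections])
# [(5, 20), (2, 20), (3, 20)]
```

Passed an integer instead of a list, `torch.tensor_split` behaves like `np.array_split` and tolerates unequal division, unlike `torch.split`, whose integer argument is a chunk *size* rather than a section *count*.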