Skip to content

Instantly share code, notes, and snippets.

@CharlesJQuarra
Created August 16, 2018 15:56
Show Gist options
  • Save CharlesJQuarra/40751a6301084db2bf35e35c0ccb3369 to your computer and use it in GitHub Desktop.
Save CharlesJQuarra/40751a6301084db2bf35e35c0ccb3369 to your computer and use it in GitHub Desktop.
broadcastable version of `torch.cat`
import torch
"""
Broadcastable version of ``torch.cat``.

Behavior:
fat_cat([torch.randn(1,7,20), torch.randn(5,1,13)], dim=-1).size() == torch.Size([5, 7, 33])
"""
def axis_repeat(t, dim, times):
    """Tile tensor ``t`` ``times`` times along its size-1 dimension ``dim``.

    Args:
        t: input tensor; ``t.size()[dim]`` must be 1.
        dim: dimension to repeat along; may be negative.
        times: number of copies (non-negative int). ``0`` yields a tensor
            whose ``dim`` has size 0.

    Returns:
        A new tensor equal to ``t`` in every dimension except ``dim``,
        which has size ``times``.

    Raises:
        ValueError: if dimension ``dim`` of ``t`` does not have size 1.
        IndexError: if ``dim`` is out of range for ``t``.
    """
    if t.size()[dim] != 1:
        raise ValueError("dimension {0} of tensor of shape {1} is non-singleton, cannot repeat".format(dim, t.size()))
    # Tensor.repeat copies the data, as the former torch.cat(times * [t], dim)
    # did, but it also accepts times == 0 (torch.cat rejects an empty list).
    repeats = [1] * t.dim()
    repeats[dim] = times
    return t.repeat(*repeats)
def fat_cat(tensor_list, dim=0):
    """Concatenate tensors along ``dim``, broadcasting every other dimension.

    Works like ``torch.cat``, except that in each dimension other than
    ``dim`` the inputs only have to agree where their size is not 1. A
    size-1 dimension is expanded to the size the other inputs have there.
    Inputs with fewer dimensions are treated as if size-1 dimensions were
    prepended (as in broadcasting), so a non-negative ``dim`` counts from
    the front of the highest-rank input.

    Example::

        fat_cat([torch.randn(1, 7, 20), torch.randn(5, 1, 13)], dim=-1).size()
        # -> torch.Size([5, 7, 33])

    Args:
        tensor_list: iterable of tensors to concatenate.
        dim: dimension along which to concatenate; may be negative.

    Returns:
        A new tensor holding the broadcast inputs concatenated along ``dim``.

    Raises:
        IndexError: if ``dim`` is out of range for the highest-rank input.
        ValueError: if two inputs have different sizes, neither of them 1,
            in some dimension other than ``dim``.
    """
    tensors = list(tensor_list)  # materialize so generators can be indexed
    nb_dims = max((t.dim() for t in tensors), default=0)
    if nb_dims == 0:
        # No inputs, or only 0-dim inputs: nothing to broadcast, so let
        # torch.cat raise its usual error for these cases.
        return torch.cat(tensors, dim)
    if not -nb_dims <= dim < nb_dims:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{0}, {1}], "
            "but got {2})".format(-nb_dims, nb_dims - 1, dim))
    cat_dim = dim % nb_dims

    # Align ranks the way broadcasting does: missing leading dims count as 1.
    shapes = [[1] * (nb_dims - t.dim()) + list(t.size()) for t in tensors]

    # Broadcast size of each dimension other than cat_dim. Any size != 1
    # (including 0) is authoritative and must agree across inputs.
    broadcast = [1] * nb_dims
    for d in range(nb_dims):
        if d == cat_dim:
            continue
        owner = None  # index of the first tensor with a non-1 size in dim d
        for i, shape in enumerate(shapes):
            if shape[d] == 1:
                continue
            if owner is None:
                owner = i
                broadcast[d] = shape[d]
            elif shape[d] != broadcast[d]:
                raise ValueError("dimension {0} of {1}th tensor of shape {2} does not match non-singleton dimension of {3}th tensor of shape {4}".format(d, i, tensors[i].size(), owner, tensors[owner].size()))

    # expand() only creates views (it also adds any missing leading dims);
    # torch.cat then copies each input exactly once into the result.
    expanded = []
    for t, shape in zip(tensors, shapes):
        target = list(broadcast)
        target[cat_dim] = shape[cat_dim]  # each input keeps its own length here
        expanded.append(t.expand(*target))
    return torch.cat(expanded, cat_dim)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment