Skip to content

Instantly share code, notes, and snippets.

@sclark39
Last active September 28, 2021 19:09
Show Gist options
  • Save sclark39/7e8e434e92e85ffb9d60906aa9d6a538 to your computer and use it in GitHub Desktop.
Save sclark39/7e8e434e92e85ffb9d60906aa9d6a538 to your computer and use it in GitHub Desktop.
Function that takes a named tensor and a desired sequence of dimension names, then squeezes, unsqueezes, and permutes the tensor to match.
def NTremix(tensor, target_dims, align='left', squeeze=True):
    """Squeeze, unsqueeze and permute a named tensor to match ``target_dims``.

    Args:
        tensor: A named tensor. All of its dimensions must be named for the
            final ``Tensor.align_to`` call to succeed.
        target_dims: Iterable of dimension names describing the desired
            layout. It may contain one ``...`` (Ellipsis) marking where any
            dimensions of ``tensor`` that are not listed should be placed.
        align: Only used when ``target_dims`` contains no ``...``.
            ``'left'`` puts the listed dims first and leftover dims after
            them; ``'right'`` puts leftover dims first. Any other value adds
            no ``...``, so every (post-squeeze) dimension of ``tensor`` must
            then appear in ``target_dims``.
        squeeze: If true, first drop every size-1 dimension. Size-1 dims not
            listed in ``target_dims`` are therefore discarded; listed ones are
            re-added below.

    Returns:
        The tensor with exactly the dims of ``target_dims`` (plus any
        leftover dims where the ``...`` goes), ordered as requested. Names in
        ``target_dims`` that the tensor lacked are added with size 1.

    Raises:
        Propagates the error raised by ``Tensor.align_to`` when the tensor
        has dimensions that cannot be placed (e.g. unlisted dims and no
        ``...``).
    """
    # Materialise once: the names are scanned below and splatted into align_to.
    target_dims = list(target_dims)

    # Remove extra (size-1) dimensions.
    if squeeze:
        tensor = tensor.squeeze()

    # Collect target names the tensor lacks, in target order, so the
    # intermediate layout is deterministic (a set's order is hash-dependent).
    present = set(tensor.names)
    missing = []
    for dim_name in target_dims:
        if dim_name is Ellipsis or dim_name in present or dim_name in missing:
            continue
        missing.append(dim_name)

    if missing:
        # Append one size-1 dim per missing name. Names are dropped around the
        # reshape and re-applied positionally. Unlike unflattening the first
        # dim, this also works for a 0-dim tensor (every dim squeezed away).
        new_names = tuple(tensor.names) + tuple(missing)
        new_shape = tuple(tensor.shape) + (1,) * len(missing)
        tensor = tensor.rename(None).reshape(new_shape).rename(*new_names)

    # Decide where unlisted dims go when the caller gave no explicit '...'.
    if Ellipsis not in target_dims:
        if align == 'left':
            target_dims = target_dims + [...]
        elif align == 'right':
            target_dims = [...] + target_dims
    return tensor.align_to(*target_dims)
def NTstack(tensors, name, dim=0):
    """Stack named tensors along a new dimension called ``name``.

    All tensors must carry the same set of dimension names. Tensors whose
    names are in a different order than the first tensor are permuted to
    the first tensor's order before stacking, so dims line up by name, not
    by position. The input tensors are not modified.

    Args:
        tensors: Non-empty iterable of named tensors.
        name: Name of the new stacked dimension.
        dim: Position of the new dimension in the output; negative values
            count from the end, as in ``torch.stack``.

    Returns:
        A named tensor whose names are the first tensor's names with
        ``name`` inserted at position ``dim``.

    Raises:
        ValueError: if ``tensors`` is empty or the tensors' dimension names
            differ.
    """
    tensors = list(tensors)
    if not tensors:
        raise ValueError('NTstack: expected at least one tensor')

    ref_names = tensors[0].names
    ref_set = set(ref_names)

    unnamed = []
    for t in tensors:
        if t.names != ref_names:
            if set(t.names) != ref_set or len(t.names) != len(ref_names):
                raise ValueError(
                    f'NTstack: dimension names {t.names} do not match {ref_names}'
                )
            # Same names, different order: align so the positional stack
            # below combines dims that share a name.
            t = t.align_to(*ref_names)
        # rename(None) returns an unnamed view; the caller's tensor keeps its
        # names (torch.stack is applied to unnamed tensors here).
        unnamed.append(t.rename(None))

    out = torch.stack(unnamed, dim=dim)

    # Normalise a negative dim against the OUTPUT rank so list.insert puts the
    # new name where torch.stack put the new dimension.
    if dim < 0:
        dim += out.dim()
    out_names = list(ref_names)
    out_names.insert(dim, name)
    return out.rename(*out_names)
def NTsum(tensor, name, keepdim=False):
    """Sum ``tensor`` over the dimension called ``name``.

    The name is looked up in ``tensor.names`` and the reduction is performed
    over that positional index. When ``keepdim`` is true the reduced dim is
    kept with size 1.

    Raises:
        ValueError: if ``name`` is not one of ``tensor.names``.
    """
    axis = tensor.names.index(name)
    return torch.sum(tensor, dim=axis, keepdim=keepdim)
class RemixTensor(nn.Module):
    """``nn.Module`` wrapper that applies ``NTremix`` in ``forward``.

    The constructor arguments are stored as attributes and passed unchanged
    to ``NTremix`` on every forward call.
    """

    def __init__(self, target_dims, align='left', squeeze=True):
        super(RemixTensor, self).__init__()
        # Remix configuration, forwarded verbatim to NTremix.
        self.squeeze = squeeze
        self.align = align
        self.target_dims = target_dims

    def forward(self, tensor):
        """Return ``tensor`` remixed to ``self.target_dims`` via ``NTremix``."""
        return NTremix(
            tensor,
            target_dims=self.target_dims,
            align=self.align,
            squeeze=self.squeeze,
        )
if __name__ == '__main__':
    # Demo of NTremix; guarded so importing this module has no side effects.
    a = torch.rand(1, 3, 5, 7, names=('batch', 'channel', 'x', 'y'))
    print(a.shape)  # torch.Size([1, 3, 5, 7])

    # 'batch' (size 1) is squeezed away; 'wave' and 'z' are added as size-1 dims.
    a = NTremix(a, ('channel', 'wave', 'x', 'y', 'z'))
    print(a.shape)  # torch.Size([3, 1, 5, 7, 1])
    print(a.names)  # ('channel', 'wave', 'x', 'y', 'z')

    # Size-1 'wave' and 'z' are squeezed away; '...' places the unlisted 'y'.
    b = NTremix(a, ('channel', ..., 'x'))
    print(b.shape)  # torch.Size([3, 7, 5])
    print(b.names)  # ('channel', 'y', 'x')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment