@AmosLewis
Created November 20, 2024 04:37
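# Decompose torch.nonzero into primitive tensor ops (flatten, cumsum, scatter-add,
# and stride arithmetic), then check the result against torch.nonzero on a 1-D example.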
import torch


def nonzero(t):
print("t: ", t) # tensor([0, 0, 0, 1, 0, 0])
# Flatten the input tensor
original_shape = t.shape
print("original_shape: ", original_shape) # torch.Size([6])
t_flat = t.flatten()
print("t_flat: ", t_flat) # tensor([0, 0, 0, 1, 0, 0])
nonzero_mask = (t_flat != 0)
nonzero_mask = nonzero_mask.int()
print("nonzero_mask: ", nonzero_mask) # tensor([0, 0, 0, 1, 0, 0], dtype=torch.int32)
    destination_indices = torch.cumsum(nonzero_mask, 0) - 1
    print("destination_indices: ", destination_indices)  # tensor([-1, -1, -1, 0, 0, 0])
    destination_indices_clamp = torch.clamp(destination_indices, min=0)
    print("destination_indices_clamp: ", destination_indices_clamp)  # tensor([0, 0, 0, 0, 0, 0])
    iota = torch.arange(len(t_flat), device=t.device) * nonzero_mask
    print("iota: ", iota)  # tensor([0, 0, 0, 3, 0, 0])
    scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
    print("scatter_self: ", scatter_self)  # tensor([0, 0, 0, 0, 0, 0])
    compacted = scatter_self.scatter_(
        dim=0,
        index=destination_indices_clamp,
        src=iota,
        reduce='add'
    )
    print("compacted: ", compacted)  # tensor([3, 0, 0, 0, 0, 0])
    result_flat = compacted[:torch.sum(nonzero_mask)]
    print("result_flat: ", result_flat)  # tensor([3])
    # Convert flattened indices back to multi-dimensional indices using PyTorch operations
    dims = torch.tensor(original_shape, device=t.device)
    print("dims: ", dims)  # tensor([6])
    strides = torch.cumprod(torch.flip(dims, [0]), 0).flip(0)
    print("strides: ", strides)  # tensor([6])
    strides = torch.cat([strides[1:], torch.tensor([1], device=t.device)])
    print("strides: ", strides)  # tensor([1])
    multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % dims
    print("multi_indices: ", multi_indices)  # tensor([[3]])
    return multi_indices


def test(a):
    a = torch.tensor(a)
    myout = nonzero(a)
    ptout = torch.nonzero(a)
    print("myout: ", myout)  # tensor([[3]])
    print("ptout: ", ptout)  # tensor([[3]])
    myout_reshaped = myout.reshape(ptout.shape)
    print("myout_reshaped: ", myout_reshaped)  # tensor([[3]])
    return myout_reshaped


test([0, 0, 0, 1, 0, 0])
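
# Additional illustrative check (not in the original gist): a 2-D input exercises the
# stride-based unflattening. torch.nonzero on [[0, 1], [2, 0]] is tensor([[0, 1], [1, 0]]),
# and the decomposition above should reproduce it (the hard-coded values in the print
# comments above refer to the 1-D example only).
test([[0, 1], [2, 0]])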