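# Manual reimplementation of torch.nonzero from flatten / cumsum / scatter_add
# primitives. The `!torch.vtensor<...>` annotations in the comments suggest the
# decomposition was sketched with a Torch-MLIR lowering in mind.
#
# Pipeline: build a 0/1 mask of nonzero positions, cumsum the mask to assign
# each nonzero element a destination slot, scatter_add the flat indices into
# those slots, slice off the first sum(mask) entries, then convert the flat
# indices back to per-dimension indices with row-major strides.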
import torch


def nonzero(t):
    print("t: ", t)  # tensor([0, 0, 1, 1, 0, 0])
    # Flatten the input tensor
    t_flat = t.flatten()  # equivalently: torch.flatten(t, 0, -1)
    print(
        "t_flat: ", t_flat
    )  # tensor([0, 0, 1, 1, 0, 0]), torch.Size([6]), #!torch.vtensor<[?],si64>
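    # Flattening first keeps the compaction rank-agnostic; the original shape
    # is only needed at the end to rebuild per-dimension indices.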
    nonzero_mask = t_flat != 0
    nonzero_mask = nonzero_mask.int()
    print(
        "nonzero_mask: ", nonzero_mask
    )  # tensor([0, 0, 1, 1, 0, 0], dtype=torch.int32)
    destination_indices = torch.cumsum(nonzero_mask.long(), 0) - 1
    print(
        "destination_indices: ", destination_indices
    )  # tensor([-1, -1, 0, 1, 1, 1])
    destination_indices_clamp = torch.clamp(destination_indices, min=0)
    print(
        "destination_indices_clamp: ", destination_indices_clamp
    )  # tensor([0, 0, 0, 1, 1, 1])
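    # The -1 slots (positions before the first nonzero) would be invalid
    # scatter indices, so they are clamped to 0. This is safe because the
    # corresponding `iota` entries below are zeroed by the mask, so they add
    # nothing at destination 0.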
    iota = torch.arange(t_flat.size(0)) * nonzero_mask
    print("iota: ", iota)  # tensor([0, 0, 2, 3, 0, 0])
    scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
    print("scatter_self: ", scatter_self)  # tensor([0, 0, 0, 0, 0, 0])
    # In-place alternative (Tensor.scatter_ with reduce='add'):
    # compacted = scatter_self.scatter_(
    #     dim=0,
    #     index=destination_indices_clamp,
    #     src=iota,
    #     reduce='add'
    # )
    compacted = torch.scatter_add(
        scatter_self, dim=0, index=destination_indices_clamp, src=iota
    )
    print("compacted: ", compacted)  # tensor([2, 3, 0, 0, 0, 0])
    result_flat = compacted[: torch.sum(nonzero_mask)]
    print("result_flat: ", result_flat)  # tensor([2, 3])
    print("result_flat.shape: ", result_flat.shape)  # torch.Size([2])
    # Convert flattened indices back to multi-dimensional indices using PyTorch operations
    original_shape = t.shape
    print(
        "original_shape: ", original_shape
    )  # torch.Size([6]), #!torch.vtensor<[1],si64>
    input_shape_tensor = torch.tensor(original_shape)
    print("dims: ", input_shape_tensor)  # tensor([6])
    strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(
        0
    )  # !torch.vtensor<[1],si64>
    print("strides: ", strides)  # tensor([6])
    # Drop the leading total-size entry and append 1, so strides[i] = prod(shape[i+1:]).
    strides = torch.cat([strides[1:], torch.tensor([1])])
    print("strides: ", strides)  # tensor([1]) !torch.vtensor<[1],si64>
    a = result_flat.unsqueeze(1)  # tensor([[2], [3]]), torch.Size([2, 1])
    b = strides.unsqueeze(0)  # tensor([[1]]), torch.Size([1, 1])
    c = a // b
    # c: tensor([[2], [3]]), torch.Size([2, 1])
    multi_indices = c % input_shape_tensor
    print("multi_indices: ", multi_indices)  # tensor([[2], [3]]), torch.Size([2, 1])
    return multi_indices


def test(a):
    a = torch.as_tensor(a)  # as_tensor avoids the copy/UserWarning when a tensor is passed in
    myout = nonzero(a)
    ptout = torch.nonzero(a)
    print("myout: ", myout)  # tensor([[2], [3]])
    print("myout.shape: ", myout.shape)  # torch.Size([2, 1])
    print("ptout: ", ptout)  # tensor([[2], [3]])
    print("ptout.shape: ", ptout.shape)  # torch.Size([2, 1])
    myout_reshaped = myout.reshape(ptout.shape)
    print("myout_reshaped: ", myout_reshaped)  # tensor([[2], [3]])
    return myout_reshaped


test(torch.tensor([0, 0, 1, 1, 0, 0]))
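# Extra 2-D check (not in the original gist, added for illustration): with the
# strides computed as above, the flat-index -> multi-index conversion also
# recovers (row, col) pairs. The inline result comments in test() refer to the
# 1-D call.
test([[0, 1], [1, 0]])  # myout and ptout: tensor([[0, 1], [1, 0]])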
# ----------------------------------------------------------------------
# Earlier variant: same algorithm, but using in-place Tensor.scatter_ with
# reduce='add' and a single-nonzero test input [0, 0, 0, 1, 0, 0].
# ----------------------------------------------------------------------
# t = torch.tensor([0, 0, 1, 1, 0, 0])
# t.size(0)
# def nonzero(t):
#     print("t: ", t)  # tensor([0, 0, 0, 1, 0, 0])
#     # Flatten the input tensor
#     t_flat = t.flatten()  # torch.flatten(t, 0, 0)
#     print("t_flat: ", t_flat)  # tensor([0, 0, 0, 1, 0, 0]), torch.Size([6]), #!torch.vtensor<[?],si64>
#     nonzero_mask = (t_flat != 0)
#     nonzero_mask = nonzero_mask.int()
#     print("nonzero_mask: ", nonzero_mask)  # tensor([0, 0, 0, 1, 0, 0], dtype=torch.int32)
#     destination_indices = torch.cumsum(nonzero_mask, 0) - 1
#     print("destination_indices: ", destination_indices)  # tensor([-1, -1, -1, 0, 0, 0])
#     destination_indices_clamp = torch.clamp(destination_indices, min=0)
#     print("destination_indices_clamp: ", destination_indices_clamp)  # tensor([0, 0, 0, 0, 0, 0])
#     iota = torch.arange(len(t_flat)) * nonzero_mask
#     print("iota: ", iota)  # tensor([0, 0, 0, 3, 0, 0])
#     scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
#     print("scatter_self: ", scatter_self)  # tensor([0, 0, 0, 0, 0, 0])
#     compacted = scatter_self.scatter_(
#         dim=0,
#         index=destination_indices_clamp,
#         src=iota,
#         reduce='add'
#     )
#     print("compacted: ", compacted)  # tensor([3, 0, 0, 0, 0, 0])
#     result_flat = compacted[:torch.sum(nonzero_mask)]
#     print("result_flat: ", result_flat)  # tensor([3])
#     print("result_flat.shape: ", result_flat.shape)  # torch.Size([1])
#     # Convert flattened indices back to multi-dimensional indices using PyTorch operations
#     original_shape = t.shape
#     print("original_shape: ", original_shape)  # torch.Size([6]), #!torch.vtensor<[1],si64>
#     input_shape_tensor = torch.tensor(original_shape)
#     print("dims: ", input_shape_tensor)  # tensor([6])
#     strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)  # !torch.vtensor<[1],si64>
#     print("strides: ", strides)  # tensor([6])
#     strides = torch.cat([strides[1:], torch.tensor([1])])
#     print("strides: ", strides)  # tensor([1]) !torch.vtensor<[1],si64>
#     multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % input_shape_tensor
#     print("multi_indices: ", multi_indices)  # tensor([[3]])
#     return multi_indices
# def test(a):
#     a = torch.tensor(a)
#     myout = nonzero(a)
#     ptout = torch.nonzero(a)
#     print("myout: ", myout)  # tensor([[3]])
#     print("ptout: ", ptout)  # tensor([[3]])
#     myout_reshaped = myout.reshape(ptout.shape)
#     print("myout_reshaped: ", myout_reshaped)  # tensor([[3]])
#     return myout_reshaped
# test([0, 0, 0, 1, 0, 0])