Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created July 23, 2022 21:00
Show Gist options
  • Save ptrblck/ed837c8f34caf8313332363c6602cdee to your computer and use it in GitHub Desktop.
Save ptrblck/ed837c8f34caf8313332363c6602cdee to your computer and use it in GitHub Desktop.
# for https://twitter.com/francoisfleuret/status/1550886362815012865
import torch
def first_marker_index(U, marker=-1.0):
    """Return the column index of the first ``marker`` entry in each row of ``U``.

    Args:
        U: 2-D tensor; rows are searched independently.
        marker: value to look for. Compared with exact ``==``, so it must be
            representable exactly in ``U``'s dtype (e.g. ``-1.0``).

    Returns:
        1-D ``torch.long`` tensor of length ``U.size(0)`` on ``U``'s device.

    Raises:
        ValueError: if some row contains no ``marker`` entry. Without this check
            the zero-initialised output would silently report column 0 for it.
    """
    hits = U == marker
    missing = ~hits.any(dim=1)
    if bool(missing.any()):
        bad_rows = missing.nonzero(as_tuple=True)[0].tolist()
        raise ValueError(f"rows {bad_rows} contain no marker value {marker}")
    rows, cols = hits.nonzero(as_tuple=True)
    first = torch.zeros(U.size(0), dtype=torch.long, device=U.device)
    # include_self=False: ignore the zero fill and take the min over the scattered
    # column indices only (with include_self=True every row would reduce to 0).
    first.scatter_reduce_(0, rows, cols, reduce="amin", include_self=False)
    return first


def consecutive_columns(start, length, num_cols):
    """Build an index matrix whose row i is ``start[i], start[i] + 1, ..., start[i] + length - 1``.

    Args:
        start: 1-D integer tensor of starting columns, one per row.
        length: number of consecutive columns per row.
        num_cols: width of the tensor that will be indexed; every index must be < num_cols.

    Returns:
        ``(len(start), length)`` tensor with ``start``'s dtype and device.

    Raises:
        ValueError: if a run would fall outside ``[0, num_cols)``.
    """
    out_of_range = (start < 0) | (start + length > num_cols)
    if bool(out_of_range.any()):
        bad_rows = out_of_range.nonzero(as_tuple=True)[0].tolist()
        raise ValueError(
            f"rows {bad_rows}: a run of {length} columns does not fit in {num_cols} columns"
        )
    return start.unsqueeze(1) + torch.arange(length, dtype=start.dtype, device=start.device)


# setup
N, Q, R = 5, 20, 10
U = torch.randn(N, Q)
V = torch.arange(N * R).view(N, R).float()
# Add one -1 per row at a column where all R values of V still fit.
# Valid starts are 0..Q-R inclusive, and randint's upper bound is exclusive.
max_start = Q - R + 1
U[torch.arange(N), torch.randint(0, max_start, (N,))] = -1.
# Second pass over N // 2 randomly drawn rows (with replacement, so possibly fewer
# distinct rows) to make sure some rows hold several -1s.
extra_rows = torch.randint(0, N, (N // 2,))
U[extra_rows, torch.randint(0, max_start, extra_rows.size())] = -1.
print(U)
# Column of the first -1 in each row.
idx = first_marker_index(U, marker=-1.)
print(idx)
# Row i of mask holds columns idx[i], ..., idx[i] + R - 1.
mask = consecutive_columns(idx, R, Q)
print(mask)
# Copy V into U: the (N, 1) row index broadcasts against the (N, R) column mask.
U[torch.arange(N).unsqueeze(1), mask] = V
print(U)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment