Skip to content

Instantly share code, notes, and snippets.

@thomasahle
thomasahle / topk.py
Created August 5, 2022 19:16
Simple Differentiable TopK for PyTorch
import torch
from functorch import vmap, grad
from torch.autograd import Function
# Module-level alias for the logistic sigmoid.
sigmoid = torch.sigmoid


def sigmoid_grad(x):
    """Return the elementwise derivative of the logistic sigmoid at ``x``.

    Uses the closed form ``sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))``.
    For a 2-D tensor this yields the same values as the previous
    ``vmap(vmap(grad(torch.sigmoid)))``. Unlike that version, it:

    * accepts tensors of any rank (0-d, 1-d, 3-d, ...), not only 2-D;
    * evaluates in one vectorized pass instead of a per-element autograd call;
    * does not depend on the deprecated ``functorch`` package.

    Args:
        x: Tensor of pre-activation values.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    s = sigmoid(x)
    return s * (1 - s)
class TopK(Function):
@staticmethod
def forward(ctx, xs, k):