Skip to content

Instantly share code, notes, and snippets.

View thomasahle's full-sized avatar
♟️

Thomas Dybdahl Ahle thomasahle

♟️
View GitHub Profile
@thomasahle
thomasahle / topk.py
Created August 5, 2022 19:16
Simple Differentiable TopK for PyTorch
import torch
from functorch import vmap, grad
from torch.autograd import Function
sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))
class TopK(Function):
@staticmethod
def forward(ctx, xs, k):
import dspy
from pydantic import BaseModel
from typing import List
class State(BaseModel):
name: str
abbreviation: str
capital: str
class States(BaseModel):
@thomasahle
thomasahle / sinkhorn.py
Created March 15, 2021 15:19
Code from Differentiable Top-k Operator with Optimal Transport
def sinkhorn_forward(C, mu, nu, epsilon, max_iter):
bs, n, k_ = C.size()
v = torch.ones([bs, 1, k_])/(k_)
G = torch.exp(-C/epsilon)
if torch.cuda.is_available():
v = v.cuda()
for i in range(max_iter):
u = mu/(G*v).sum(-1, keepdim=True)
################################################################################
# NFA Implementation using greenery
################################################################################
import greenery
from greenery import rxelems as rx
from collections import defaultdict
class State:
def __init__(self, is_accept=False):
@thomasahle
thomasahle / soft_topk_bce.py
Created December 6, 2023 20:58
Soft TopK with BCE loss
import torch
from torch.autograd import Function
import torch.nn.functional as F
@torch.no_grad()
def _find_ts(xs, ks, binary_iter=16, newton_iter=1):
n = xs.shape[-1]
assert torch.all((0 < ks) & (ks < n)), "We don't support k=0 or k=n"
# Lo should be small enough that all sigmoids are in the 0 area.
@thomasahle
thomasahle / twelve.py
Created November 28, 2023 23:47
Probability of winning "game of twelve" with optimal play
import itertools
dp = [0] * 2**12
dp[0] = 1
for state in range(1, 2**12):
for d1, d2 in itertools.product(range(6), repeat=2):
o = 0
s = d1 + d2 + 1
if s < 12 and state & (1 << s):
o = max(o, dp[state & ~(1 << s)])
@thomasahle
thomasahle / scale_and_square.py
Created November 21, 2023 19:13
Implementation of Higham's Algorithm 10.20 (the scaling and squaring algorithm). This algorithm evaluates the matrix exponential X = e^A of A ∈ C^{n×n} using the scaling and squaring method.
import numpy as np
import scipy.linalg as sla
theta_m = {3: 1.5e-2, 5: 5.4e-1, 7: 9.5e-1, 9: 2.1e0, 13: 5.4e0}
pade_coefficients = {
3: [120, 60, 12, 1],
5: [30240, 15120, 3360, 420, 30, 1],
7: [17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1],
9: [17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1],
@thomasahle
thomasahle / nonlocals.py
Created November 18, 2023 21:36
Test of nonlocals functions
def nonlocals():
    """Collect local variables from frames further up the call stack.

    Collection starts at the frame of the function that called our caller
    (two frames above this one) and walks outward through ``f_back`` links.
    Each frame's ``f_locals`` is merged into one dict. When a name appears in
    several frames, the value from the frame nearest the caller wins.

    The caller's own locals are not included. Neither is the outermost frame
    on the stack: the walk stops at the frame whose ``f_back`` is ``None``.

    Returns:
        A dict mapping variable names to values. Returns ``{}`` if fewer than
        three frames are on the stack.
    """
    import inspect

    # Walk frames directly instead of calling inspect.stack(). That call
    # builds a FrameInfo (including source-context lines) for every frame on
    # the stack, only for two of them to be indexed.
    frame = inspect.currentframe()
    f = None
    try:
        # Skip this function's own frame and the caller's frame. Collection
        # starts at the caller's caller (index 2 of inspect.stack()).
        f = frame
        for _ in range(2):
            f = f.f_back if f is not None else None
        if f is None:
            return {}

        res = {}
        while f.f_back is not None:
            for k, v in f.f_locals.items():
                # Keep the first (nearest) binding seen for each name.
                res.setdefault(k, v)
            f = f.f_back
        return res
    finally:
        # Frame objects referenced from a frame's own locals form reference
        # cycles. Drop them explicitly, as the inspect docs recommend.
        del frame, f
@thomasahle
thomasahle / ndiv.tex
Created November 12, 2023 18:34
LaTeX code for "does not divide" symbols
\usepackage{graphicx}
\newcommand{\shortslash}{\raisebox{0.2ex}{\scalebox{0.65}{/}}}
\newcommand{\notdivides}{\!\mathrel{\backslash\kern-0.4em\shortslash}\!}
\newcommand{\notdividesTim}{\!\!\mathrel{\rotatebox[origin=c]{20}{$\nmid$}}\!\!}
\def\notdividesHeinrich{\mathpalette\notdiv\relax}
\let\divides=\backslash
\def\notdiv#1#2{\setbox0=\hbox{$#1\divides$}%
\vcenter{\hbox to\wd0{\hss$\scriptscriptstyle/\hss$}}\kern-\wd0
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
import random