Skip to content

Instantly share code, notes, and snippets.

View mattjj's full-sized avatar

Matthew Johnson mattjj

View GitHub Profile
import torch
import torch.utils.dlpack
import jax
import jax.dlpack
# A generic mechanism for turning a JAX function into a PyTorch function.
def j2t(x_jax):
    """Convert a JAX array to a PyTorch tensor via DLPack.

    DLPack hands the underlying buffer across frameworks, so per the PyTorch
    docs the returned tensor shares memory with the input rather than copying.

    Args:
      x_jax: a JAX array.

    Returns:
      A ``torch.Tensor`` with the same shape, dtype and contents as ``x_jax``.
    """
    if hasattr(x_jax, "__dlpack__"):
        # Newer JAX arrays implement the DLPack protocol directly, and
        # ``jax.dlpack.to_dlpack`` is deprecated there, so hand the array over as-is.
        return torch.utils.dlpack.from_dlpack(x_jax)
    # Older JAX: export an explicit DLPack capsule first.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
@mattjj
mattjj / nkp.py
Last active December 27, 2023 22:52
import numpy as np
def gram_matrix(Xs):
    """Return the matrix of pairwise inner products of the flattened arrays in ``Xs``.

    Every element of ``Xs`` is raveled into one row of a stacked matrix ``M``
    (so all elements must have the same number of entries), and the result is
    ``M M^T``: entry ``(i, j)`` is the dot product of the i-th and j-th arrays.
    """
    rows = np.vstack(list(map(np.ravel, Xs)))
    return rows.dot(rows.T)
def eig(X):
    """Eigendecomposition with eigenpairs ordered by increasing |eigenvalue|.

    Wraps ``np.linalg.eig`` and reorders its output so eigenvalues appear in
    ascending order of magnitude, with the eigenvector columns permuted to
    match. A stack of matrices with shape ``(..., n, n)`` is also accepted;
    each matrix in the stack is sorted independently.

    Returns:
      (vals, vecs): ``vals`` has shape ``(..., n)`` and ``vecs[..., :, k]`` is
      the eigenvector belonging to ``vals[..., k]``.
    """
    vals, vecs = np.linalg.eig(X)
    idx = np.argsort(np.abs(vals), axis=-1)
    # Gather along the last axis for each matrix separately; plain ``vals[idx]``
    # would index the leading (batch) axis when X is a stack of matrices.
    vals = np.take_along_axis(vals, idx, axis=-1)
    vecs = np.take_along_axis(vecs, idx[..., np.newaxis, :], axis=-1)
    return vals, vecs
@mattjj
mattjj / cholupdate.py
Last active December 13, 2023 04:00
Cholesky updates and downdates
from __future__ import division
import numpy as np
from numpy.random import randn
from scipy.linalg.blas import drot, drotg
# references for updates:
# - Golub and van Loan (4th ed.) Section 6.5.4
# - http://mathoverflow.net/questions/30162/is-there-a-way-to-simplify-block-cholesky-decomposition-if-you-already-have-deco
#
# references for downdates:
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Callable, List, NamedTuple, Optional

import numpy as np
Array = Any  # generic alias for array-like values; imposes no type constraint


class Node(NamedTuple):
    """Immutable record pairing a ``vjp`` callable with a list of parent nodes.

    NOTE(review): presumably one node of a reverse-mode autodiff graph, with
    ``vjp`` being its vector-Jacobian-product function; confirm against the
    code that constructs nodes.
    """
    vjp: Optional[Callable]  # may be None
    parents: List[Node]  # requires ``List`` from typing to be importable for get_type_hints
@mattjj
mattjj / .ctags
Created April 21, 2013 18:47
my ~/.ctags file
--recurse=yes
--tag-relative=yes
--exclude=*.git*
--exclude=*.pyc
--exclude=*.pyo
--exclude=.DS_Store
--exclude=*.md
--exclude=*.mkd
--langdef=Clojure
@mattjj
mattjj / websocketcat-receive.js
Created October 27, 2012 20:47
A tiny Node.js server that routes stdin to a WebSocket
// Connect back to the host that served this page and keep the most recent
// 'data' payload, decoded from JSON, in the top-level variable `received`.
var socket = io.connect(window.location.host);
var received;

socket.on('data', function onData(message) {
  // JSON.parse throws on malformed input; in that case `received` keeps
  // its previous value.
  var decoded = JSON.parse(message);
  received = decoded;
});
# referenced @chhillee's https://github.com/pytorch/functorch/blob/main/functorch/_src/nnc_compile.py
from typing import Callable, Dict, Any, List
from functools import partial
import numpy as np
import torch
import torch._C._te as te
from jax import core
#!/bin/bash
# Positional arguments (all optional):
#   $1   base   (default: master)
#   $2   alt    (default: the currently checked-out branch, or HEAD's commit
#               hash when detached)
#   $3   bench  (default: benchmarks/api_benchmark.py)
#   $4+  rest   collected into a single space-separated string
set -e

current_branch=$(git branch --show-current)
if [[ -z "$current_branch" ]]; then
  # `git branch --show-current` prints nothing on a detached HEAD (common in
  # CI checkouts); fall back to the commit hash so the value is never empty.
  current_branch=$(git rev-parse HEAD)
fi

base=${1:-master}
alt=${2:-${current_branch}}
bench=${3:-benchmarks/api_benchmark.py}
# "${*:4}" joins the remaining arguments with spaces, which is what "${@:4}"
# produced in this scalar assignment; spelled explicitly (ShellCheck SC2124).
rest="${*:4}"
@mattjj
mattjj / jax_taylor.py
Last active September 30, 2021 11:58
from functools import partial
from math import factorial
import jax.numpy as np
import matplotlib.pyplot as plt
from jax import jvp, vmap
def f(x):
    """Evaluate the cubic polynomial x**3/5 + 3*x**2 - x + 1.

    Only arithmetic operators are used, so array inputs are handled elementwise.
    """
    cubic_term = 1. / 5 * x ** 3
    quadratic_term = 3 * x ** 2
    return cubic_term + quadratic_term - x + 1.
from typing import Callable, TypeVar
from collections import defaultdict
def ensure_tuple(x):
    """Return ``x`` itself if it is already a tuple (or tuple subclass), else ``(x,)``."""
    if not isinstance(x, tuple):
        x = (x,)
    return x
def safe_zip(*args):
    """Zip sized iterables together, insisting that they all have the same length.

    Unlike the builtin ``zip``, which silently truncates to the shortest input,
    a length mismatch is an error. The result is materialized as a list.

    Args:
      *args: sized iterables (anything supporting ``len``).

    Returns:
      ``list(zip(*args))``, which is ``[]`` when called with no arguments.

    Raises:
      ValueError: if the arguments do not all have the same length.
    """
    if not args:
        return []
    lengths = [len(a) for a in args]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if any(n != lengths[0] for n in lengths[1:]):
        raise ValueError(f"safe_zip got arguments of unequal lengths: {lengths}")
    return list(zip(*args))