Skip to content

Instantly share code, notes, and snippets.

View mattjj's full-sized avatar

Matthew Johnson mattjj

View GitHub Profile
from __future__ import annotations
from contextlib import contextmanager
from typing import NamedTuple, Callable, Optional, Any, List
import numpy as np
# Loose alias for array-valued annotations. It is plain `Any`, so type
# checkers accept any value wherever `Array` is used.
Array = Any
class Node(NamedTuple):
    """Graph node holding an optional ``vjp`` callable and its parent nodes.

    Fields:
        vjp: A callable, or ``None``. Presumably a vector-Jacobian product
            rule (inferred from the name only) -- TODO confirm against callers.
        parents: The ``Node`` objects this node was built from.
    """
    vjp: Optional[Callable]
    # `List` must be importable at module level: with
    # `from __future__ import annotations` this annotation is stored as the
    # string "List[Node]" and is only resolved by `typing.get_type_hints`
    # (or a static checker), which fails if `List` is not in scope.
    parents: List[Node]
# referenced @chhillee's
from typing import Callable, Dict, Any, List
from functools import partial
import numpy as np
import torch
import torch._C._te as te
from jax import core
import torch
import torch.utils.dlpack
import jax
import jax.dlpack
# A generic mechanism for turning a JAX function into a PyTorch function.
def j2t(x_jax):
    """Convert a JAX array into a PyTorch tensor via DLPack.

    The tensor is built from a DLPack export of ``x_jax`` rather than by
    round-tripping through NumPy.

    Args:
        x_jax: A JAX array.

    Returns:
        The ``torch.Tensor`` returned by ``torch.utils.dlpack.from_dlpack``.
    """
    if hasattr(x_jax, "__dlpack__"):
        # Preferred path: torch consumes any object implementing the DLPack
        # protocol directly. `jax.dlpack.to_dlpack` is deprecated in recent
        # JAX releases, so avoid it whenever the protocol is available.
        return torch.utils.dlpack.from_dlpack(x_jax)
    # Fallback for older JAX versions: export an explicit DLPack capsule.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
from typing import Callable, TypeVar
from collections import defaultdict
def ensure_tuple(x):
    """Return ``x`` unchanged if it is already a tuple, else wrap it in a 1-tuple."""
    if isinstance(x, tuple):
        # Tuples (including NamedTuple instances) pass through as-is.
        return x
    return (x,)
def safe_zip(*args):
    """Zip sequences together, requiring that they all have the same length.

    Unlike the builtin ``zip``, which silently truncates to the shortest
    input, this refuses mismatched lengths. Every argument must support
    ``len()``.

    Returns:
        A list of tuples, one per position; ``[]`` when called with no
        arguments.

    Raises:
        ValueError: if the arguments do not all have the same length.
    """
    if not args:
        return []
    expected = len(args[0])
    if any(len(seq) != expected for seq in args[1:]):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently fall back to truncating zip.
        raise ValueError(
            f"safe_zip: length mismatch: {[len(seq) for seq in args]}")
    return list(zip(*args))
# Abort as soon as a command exits with a non-zero status (subject to the
# usual `set -e` exceptions, e.g. commands used as `if` conditions).
set -e
# Name of the checked-out branch; empty when HEAD is detached.
current_branch=$(git branch --show-current)
from jax import core
# A primitive is just a name to which we associate rules.
# Create the primitive object named 'sincos'. Rules must be attached to
# `sincos_p` separately; none are set up on this line.
sincos_p = core.Primitive('sincos')
# A primitive's "bind" is how it gets applied, in a way that interacts with the
# trace/transform machinery. As a convention we wrap them in Python functions
# like this:
def sincos(x):
from functools import partial
import jax
import jax.numpy as np
from jax.scipy.special import logsumexp
from jax import lax, random
from jax import jit, grad
def log_normalizer(params, seq):
from functools import partial
from jax import core
from jax.util import safe_map, safe_zip
import jax.linear_util as lu
# Deliberately shadow the builtins `map` and `zip` with jax.util's
# `safe_map` / `safe_zip` for the rest of this module.
# NOTE(review): presumably these verify that all inputs have equal length
# instead of silently truncating -- confirm against jax.util.
map = safe_map
zip = safe_zip
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax import custom_transforms
from jax import core
from jax import grad
def f(x, y):
    """Return ``x`` squared plus three times ``y``."""
    x_squared = x ** 2
    return x_squared + 3 * y
mattjj /
Last active September 30, 2021 11:58
from functools import partial
from math import factorial
import jax.numpy as np
import matplotlib.pyplot as plt
from jax import jvp, vmap
def f(x):
    """Evaluate the cubic x**3/5 + 3*x**2 - x + 1 at ``x``."""
    # Terms are combined left-to-right exactly as in the one-line form, so
    # floating-point results are unchanged.
    cubic_term = 1. / 5 * x ** 3
    quadratic_term = 3 * x ** 2
    return cubic_term + quadratic_term - x + 1.