This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from contextlib import contextmanager | |
from typing import NamedTuple, Callable, Optional, Any | |
import numpy as np | |
Array = Any | |
class Node(NamedTuple): | |
vjp: Optional[Callable] | |
parents: List[Node] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# referenced @chhillee's https://github.com/pytorch/functorch/blob/main/functorch/_src/nnc_compile.py | |
from typing import Callable, Dict, Any, List | |
from functools import partial | |
import numpy as np | |
import torch | |
import torch._C._te as te | |
from jax import core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.dlpack | |
import jax | |
import jax.dlpack | |
# A generic mechanism for turning a JAX function into a PyTorch function. | |
def j2t(x_jax): | |
x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax)) | |
return x_torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, TypeVar | |
from collections import defaultdict | |
def ensure_tuple(x): | |
return x if isinstance(x, tuple) else (x,) | |
def safe_zip(*args): | |
x, *xs = args | |
assert all(len(x_) == len(x) for x_ in xs) | |
return list(zip(*args)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash | |
set -e | |
current_branch=$(git branch --show-current) | |
base=${1:-master} | |
alt=${2:-${current_branch}} | |
bench=${3:-benchmarks/api_benchmark.py} | |
rest="${@:4}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax import core | |
# A primitive is just a name to which we associate rules. | |
sincos_p = core.Primitive('sincos') | |
# A primitive's "bind" is how it gets applied, in a way that interacts with the | |
# trace/transform machinery. As a convention we wrap them in Python functions | |
# like this: | |
def sincos(x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import jax | |
import jax.numpy as np | |
from jax.scipy.special import logsumexp | |
from jax import lax, random | |
from jax import jit, grad | |
def log_normalizer(params, seq): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from jax import core | |
from jax.util import safe_map, safe_zip | |
import jax.linear_util as lu | |
map = safe_map | |
zip = safe_zip | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax.interpreters import ad | |
from jax.interpreters import partial_eval as pe | |
from jax import custom_transforms | |
from jax import core | |
from jax import grad | |
@custom_transforms | |
def f(x, y): | |
return x**2 + 3 * y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from math import factorial | |
import jax.numpy as np | |
import matplotlib.pyplot as plt | |
from jax import jvp, vmap | |
def f(x): | |
return 1./5 * x**3 + 3 * x**2 - x + 1. |
NewerOlder