
# Some cliff's notes on installing empy on a new machine. Haven't tested it yet.
pip3 install userpath
# install y
git clone ~/ml/y
# add y to PATH
userpath append ~/ml/y/bin
# restart the shell so the PATH change takes effect
exec $SHELL
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

add_p = Primitive("add")
mul_p = Primitive("mul")
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
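A hedged sketch of how such named `Primitive` records could drive dispatch in a mini-interpreter (in the style of JAX's autodidax): `impl_rules` and `bind` are not part of the snippet above, they are assumptions added for illustration.

```python
from typing import NamedTuple

class Primitive(NamedTuple):
    name: str

add_p = Primitive("add")
mul_p = Primitive("mul")

# Implementation table: maps each primitive to a concrete Python rule.
# (This table and bind() are illustrative, not from the gist.)
impl_rules = {
    add_p: lambda x, y: x + y,
    mul_p: lambda x, y: x * y,
}

def bind(prim, *args):
    # Look up the primitive's rule and apply it to the arguments.
    return impl_rules[prim](*args)

print(bind(add_p, 2, 3))  # 5
print(bind(mul_p, 2, 3))  # 6
```

Because `NamedTuple` instances are hashable, the primitives themselves can serve as dictionary keys, which is what makes this table-driven dispatch work.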
shawwn / gist:d97b6b948d58111b51226f2012cbee30
Created Nov 11, 2021
How to dump XLA HLO pipelines in tensorflow
`XLA_FLAGS="--xla_dump_to=/tmp/xladump --xla_dump_hlo_pass_re=.*"`
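Since the value contains a space and a regex, quoting is the part that trips people up. A minimal invocation sketch (the program you run under these flags is a placeholder):

```shell
# Create the dump directory and set the flags. Note the quotes: the
# value contains a space, so it must be passed as one string.
mkdir -p /tmp/xladump
export XLA_FLAGS="--xla_dump_to=/tmp/xladump --xla_dump_hlo_pass_re=.*"
# Run your TF/JAX program here. Plain python ignores the variable; it is
# echoed only so this sketch runs end to end:
python3 -c 'import os; print(os.environ["XLA_FLAGS"])'
# After a real XLA run, per-pass HLO dumps land in /tmp/xladump.
```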
shawwn / What happens when you allocate a JAX tensor on a
Last active Nov 7, 2021
JAX C++ stack trace walkthrough for TpuExecutor_Allocate
shawwn / libtpujesus.c
Created Nov 2, 2021
An example of building a custom "stub" library, with the ultimate goal of implementing your own "TPU" device for JAX.
/* libtpujesus.c
Copyright 2021 Shawn Presser
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
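The stub-library idea described above can be sketched in a few lines: export an entry point, log when the loader calls it. The symbol name below is hypothetical, not the real libtpu ABI.

```c
/* Minimal sketch of a "stub" shared library. TpuStub_Initialize is an
 * illustrative symbol name, NOT the actual libtpu entry point; the real
 * gist intercepts the TPU driver API that JAX dlopen()s.
 *
 * Build as a shared object with, e.g.:
 *   cc -shared -fPIC -o libtpujesus.so libtpujesus.c
 */
#include <stdio.h>

static int stub_calls = 0;  /* counts invocations, for demonstration */

void TpuStub_Initialize(void) {
    stub_calls++;
    fprintf(stderr, "libtpujesus: TpuStub_Initialize called\n");
}
```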
shawwn / gist:2591c555ab918020d6be2ee121000c23
/Users/spresser/ml/jax/jax/_src/lib/ UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
2021-11-02 01:28:09.661877: I external/org_tensorflow/tensorflow/core/tpu/] InitializeTpuStructFns...
2021-11-02 01:28:09.661949: I external/org_tensorflow/tensorflow/core/tpu/] TFTPU_SET_FN(ops_api_fn, ...) starting...
2021-11-02 01:28:09.661994: I external/org_tensorflow/tensorflow/core/tpu/] TFTPU_SET_FN(ops_api_fn, ...) finished
2021-11-02 01:28:09.661997: I external/org_tensorflow/tensorflow/core/tpu/] TFTPU_SET_FN(executor_fn, ...) starting...
2021-11-02 01:28:09.662064: I external/org_tensorflow/tensorflow/core/tpu/] TFTPU_SET_FN(executor_fn, ...) finished
2021-11-02 01:28:09.662067:
shawwn / jax_bazel_caching.diff
Created Nov 2, 2021
A diff to enable bazel caching. I have no idea why it works, or what the implications are of the various commands. In fact, I have no idea where I even got these commands from. Buyer beware.
diff --git a/build/ b/build/
index 2a812632..a78b1d03 100755
--- a/build/
+++ b/build/
@@ -507,6 +507,9 @@ def main():
config_args += ["--config=rocm"]
config_args += ["--config=nonccl"]
+ # don't rebuild everything every time.
+ config_args.extend("--action_env=PATH --remote_accept_cached=true --spawn_strategy=standalone --remote_local_fallback=false --remote_timeout=600".split())
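The same flags could equivalently live in a `.bazelrc` instead of being injected into the build script (a hedged sketch; whether JAX's build script picks up a user `.bazelrc` is an assumption):

```
build --action_env=PATH
build --remote_accept_cached=true
build --spawn_strategy=standalone
build --remote_local_fallback=false
build --remote_timeout=600
```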
# The public plotly graphs to include in the report. These can also be generated with `py.plot(figure, filename)`
graphs = [
def report_block_template(report_type, graph_url, caption=''):
if report_type == 'interactive':
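The template function above is truncated. A hedged sketch of how such a function might continue, following the common plotly-report pattern of embedding interactive graphs as iframes and static ones as linked images (the HTML bodies and the `'static'` branch are assumptions, not the gist's actual code):

```python
def report_block_template(report_type, graph_url, caption=''):
    # Interactive reports embed the live plotly graph; static reports
    # link a rendered PNG back to the graph page.
    if report_type == 'interactive':
        graph_block = (
            '<iframe frameborder="0" seamless="seamless" '
            'src="{graph_url}.embed" width="100%" height="400"></iframe>'
        )
    elif report_type == 'static':
        graph_block = (
            '<a href="{graph_url}">'
            '<img src="{graph_url}.png" width="100%"></a>'
        )
    else:
        raise ValueError('invalid report_type')
    return (graph_block + '<p>{caption}</p>').format(
        graph_url=graph_url, caption=caption)

print(report_block_template('static', 'https://plot.ly/~user/1', 'Example'))
```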
def ConvMixr(h, d, k, p, n):
    def A(x):
        return Sequential(x, GELU(), BatchNorm2d(h))
    class R(Sequential):
        def forward(self, x):
            return self[0](x) + x
    return Sequential(
shawwn / pypi-json tensorflow.json
This file has been truncated.
{
"info": {
"author": "Google Inc.",
"author_email": "",
"bugtrack_url": null,
"classifiers": [
"Development Status :: 5 - Production/Stable",
"Environment :: GPU :: NVIDIA CUDA :: 11.0",
"Intended Audience :: Developers",