JAX Compilation Using Theano via a Custom Linker Class
import numpy as np
import theano
import theano.tensor as tt
import jax
import jax.numpy as jnp
from warnings import warn
from functools import partial, update_wrapper
from collections.abc import Sequence
from multipledispatch import dispatch
from theano.gof.graph import Node, Constant
from theano.gof.link import (PerformLinker, map_storage, gc_helper, utils,
                             add_clear_storage, Container, streamline)
from theano.ifelse import IfElse
from theano.tensor.subtensor import (
    get_idx_list,
    Subtensor,
    IncSubtensor,
    # This is essentially `np.take`
    AdvancedSubtensor1,
    AdvancedIncSubtensor1,
    # Boolean mask indexing and setting
    BaseAdvancedSubtensor,
    BaseAdvancedIncSubtensor,
)
from theano.scan_module.scan_op import Scan
from theano.tensor.basic import (
    TensorFromScalar,
    ScalarFromTensor,
)
from theano.scalar.basic import (
    ScalarOp,
    Composite,
    Cast,
    Clip,
)
from theano.tensor.elemwise import Elemwise
# from theano.printing import debugprint as tt_dprint
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
def compose_jax_funcs(out_var, memo, inputs):
    """Compose JAX implementations of node operations.

    Parameters
    ----------
    out_var: Variable
        The output variable.
    memo: Mapping
        A map from visited nodes to their JAX functions.
    inputs: List[Variable]
        The inputs, in a `FunctionGraph` sense, to `out_var`.

    Returns
    -------
    A function that represents the composed JAX operations and takes
    the same form of inputs as `inputs`.

    """
    if out_var in memo:
        return memo[out_var]

    if out_var in inputs:
        idx = inputs.index(out_var)

        def jax_inputs_func(*inputs):
            # TODO: What about `jax.device_put`?
            jax_type = getattr(jnp, out_var.dtype)
            return jax_type(inputs[idx])

        memo[out_var] = jax_inputs_func
        return jax_inputs_func

    if out_var.owner is None:

        def jax_data_func(*inputs):
            return out_var.data

        memo[out_var] = jax_data_func
        return jax_data_func

    node = out_var.owner
    jax_return_func = jax_funcify(node.op, node)

    input_funcs = []
    for i in node.inputs:
        input_f = compose_jax_funcs(i, memo, inputs)
        input_funcs.append(input_f)

    def jax_func(*inputs):
        return jax_return_func(*[fn(*inputs) for fn in input_funcs])

    jax_func = update_wrapper(jax_func, jax_return_func)

    memo[out_var] = jax_func
    return jax_func
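# A minimal usage sketch (illustrative only, not part of the original gist):
# once the `jax_funcify` implementations below are registered, this function
# can turn a small graph like `c = a + b` into a single JAX callable, e.g.
#
#     a, b = tt.scalar("a"), tt.scalar("b")
#     c_jax = compose_jax_funcs(a + b, {}, [a, b])
#     c_jax(1.0, 2.0)  # -> DeviceArray(3., ...)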
@dispatch(ScalarOp, Node)
def jax_funcify(op, node):
    """Create a JAX "perform" function for a Theano `Op` and the `Node` at which it is applied."""
    jnp_func = getattr(jnp, op.nfunc_spec[0])
    return jnp_func
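# For instance, this generic `ScalarOp` path maps Theano's scalar `Add`
# (whose `nfunc_spec` starts with "add") straight to `jnp.add`. Illustrative
# check, not part of the original gist:
#
#     from theano.scalar.basic import add as scalar_add
#     assert getattr(jnp, scalar_add.nfunc_spec[0]) is jnp.add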
@jax_funcify.register(Clip, Node)
def jax_funcify_clip(op, node):
    return partial(op.impl, None)
@jax_funcify.register(Cast, Node)
def jax_funcify_cast(op, node):
    def cast(x):
        return jnp.array(x).astype(op.o_type.dtype)

    return cast
@jax_funcify.register(TensorFromScalar, Node)
def jax_funcify_TensorFromScalar(op, node):
    def tensor_from_scalar(x):
        return jnp.array(x)

    return tensor_from_scalar
@jax_funcify.register(ScalarFromTensor, Node)
def jax_funcify_ScalarFromTensor(op, node):
    def scalar_from_tensor(x):
        return jnp.array(x).flatten()[0]

    return scalar_from_tensor
@jax_funcify.register(Elemwise, Node)
def jax_funcify_Elemwise(op, node):
    node_op = node.op.scalar_op
    jax_scalar_func = jax_funcify(node_op, node)
    res = jax_scalar_func
    # TODO: Not sure when/if this is applicable.
    # res = jax.vmap(jax_scalar_func)
    return res
@jax_funcify.register(Composite, Node)
def jax_funcify_Composite(op, node):
    fgraph_impls = [compose_jax_funcs(r, {}, op.fgraph.inputs) for r in op.fgraph.outputs]

    if len(fgraph_impls) == 1:
        jax_impl, = fgraph_impls
    else:

        def jax_impl(*inputs):
            return [impl(*inputs) for impl in fgraph_impls]

    return jax_impl
@jax_funcify.register(Scan, Node)
def jax_funcify_Scan(op, node):
    # Raising here (instead of returning a thunk) lets `JaxLinker.make_all`
    # catch the error and fall back to the Python implementation for graphs
    # containing a `Scan`.
    raise NotImplementedError("JAX conversion of `Scan` is not implemented")
@jax_funcify.register(IfElse, Node)
def jax_funcify_IfElse(op, node):
    def ifelse(cond, *args):
        if cond:
            return args[:op.n_outs]
        else:
            return args[op.n_outs:]

    return ifelse
@jax_funcify.register(subtensor_ops, Node)
def jax_funcify_Subtensor(op, node):
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        cdata = get_idx_list((x,) + ilists, idx_list)

        if len(cdata) == 1:
            cdata = cdata[0]

        return x.__getitem__(cdata)
        # return x.take(ilists, axis=0)

    return subtensor
@jax_funcify.register(incsubtensor_ops, Node)
def jax_funcify_IncSubtensor(op, node):
    if getattr(op, "set_instead_of_inc", False):

        def incsubtensor(x, y, *ilist):
            return jax.ops.index_update(x, ilist, y)

    else:

        def incsubtensor(x, y, *ilist):
            return jax.ops.index_add(x, ilist, y)

    return incsubtensor
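# Note (illustrative only, not part of the original gist): `jax.ops.index_update`
# mirrors Theano's `set_subtensor` (overwrite) and `jax.ops.index_add` mirrors
# `inc_subtensor` (accumulate); both return new arrays, e.g.
#
#     x = jnp.ones(3)
#     jax.ops.index_update(x, 0, 5.0)  # -> [5., 1., 1.]
#     jax.ops.index_add(x, 0, 5.0)     # -> [6., 1., 1.]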
class JaxLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX."""

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        thunks = []
        try:
            node = nodes[-1]
            jax_funcs = [compose_jax_funcs(out_var, {}, fgraph.inputs) for out_var in node.outputs]

            if len(jax_funcs) == 1:
                jax_func, = jax_funcs
            else:

                def jax_func(*inputs):
                    return [jf(*inputs) for jf in jax_funcs]

            # I suppose we can consider `Constant`s to be "static"
            # according to JAX.
            static_argnums = [n for n, i in enumerate(fgraph.inputs)
                              if isinstance(i, Constant)]

            jax_impl_jit = jax.jit(jax_func, static_argnums)
            jax_impl_jit.nout = len(node.outputs)

            thunk_inputs = [storage_map[v] for v in fgraph.inputs]
            thunk_outputs = [storage_map[v] for v in fgraph.outputs]

            def thunk():
                outputs = jax_impl_jit(*[x[0] for x in thunk_inputs])

                if not isinstance(outputs, Sequence):
                    outputs = [outputs]

                for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
                    compute_map[o_var][0] = True
                    # Each storage entry is a single-element list; write the
                    # result into its only slot.
                    o_storage[0] = o_val

                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            nodes = [node]
            thunks.append(thunk)
        except NotImplementedError as e:
            warn("JaxLinker could not JAXify graph: {}".format(e))

            for node in nodes:
                thunk = node.op.make_thunk(node, storage_map, compute_map,
                                           no_recycling, "py")
                thunk_inputs = [storage_map[v] for v in node.inputs]
                thunk_outputs = [storage_map[v] for v in node.outputs]

                thunk.inputs = thunk_inputs
                thunk.outputs = thunk_outputs

                thunks.append(thunk)
        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            post_thunk_old_storage = []

            for node in nodes:
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]
        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )
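# Aside (illustrative sketch, not part of the linker itself): `static_argnums`
# tells `jax.jit` to treat the given positional arguments as compile-time
# constants, which is why graph `Constant` inputs are collected above, e.g.
#
#     @partial(jax.jit, static_argnums=(1,))
#     def scale(x, factor):
#         return x * factor
#
#     scale(jnp.ones(3), 2.0)  # re-traces for each distinct `factor` value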
# Register this linker with the config system
# theano.compile.mode.register_linker("jax", JaxLinker)
# During development, we need to overwrite old classes:
theano.compile.mode.predefined_linkers["jax"] = JaxLinker()
#
# A test graph
#
x = tt.matrix('x')
y = tt.matrix('y')
z = tt.cosh(x**2 + y / 3.0)
out = tt.set_subtensor(z[0], -10.0)
out = tt.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3]
test_inputs = [x, y]
test_outputs = out
jax_mode = theano.compile.Mode(linker="jax")
theano_jax_fn = theano.function(test_inputs, test_outputs, mode=jax_mode)
test_input_vals = [
    np.tile(np.arange(10), (10, 1)),
    np.tile(np.arange(10, 20), (10, 1)),
]
jax_res = theano_jax_fn(*test_input_vals)
# Confirm that the result is from JAX
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
# Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3)
# Confirm that the `IncSubtensor` operations are correct
# (row 0 was set to -10.0, then 2.0 was added at index [0, 1])
assert jax_res[0, 0] == -10.0
assert jax_res[0, 1] == -8.0
# Confirm that the result is correct (according to regular Theano compilation)
py_mode = theano.compile.Mode(linker="py")
theano_py_fn = theano.function(test_inputs, test_outputs, mode=py_mode)
py_res = theano_py_fn(*test_input_vals)
assert np.allclose(jax_res, py_res)