JAX Compilation Using Theano via a Custom Linker Class

import numpy as np

import theano
import theano.tensor as tt

import jax
import jax.numpy as jnp

from warnings import warn
from functools import partial, update_wrapper
from collections.abc import Sequence

from multipledispatch import dispatch

from theano.gof.graph import Node, Constant
from theano.gof.link import (PerformLinker, map_storage, gc_helper, utils,
                             add_clear_storage, Container, streamline)
from theano.ifelse import IfElse
from theano.tensor.subtensor import (
    get_idx_list,
    Subtensor,
    IncSubtensor,
    # This is essentially `np.take`
    AdvancedSubtensor1,
    AdvancedIncSubtensor1,
    # Boolean mask indexing and setting
    BaseAdvancedSubtensor,
    BaseAdvancedIncSubtensor,
)
from theano.scan_module.scan_op import Scan
from theano.tensor.basic import (
    TensorFromScalar,
    ScalarFromTensor,
)
from theano.scalar.basic import (
    ScalarOp,
    Composite,
    Cast,
    Clip,
)
from theano.tensor.elemwise import Elemwise

# from theano.printing import debugprint as tt_dprint


subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)


def compose_jax_funcs(out_var, memo, inputs):
    """Compose JAX implementations of node operations.

    Parameters
    ----------
    out_var: Variable
        The output variable.
    memo: Mapping
        A map from visited nodes to their JAX functions.
    inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_var`.

    Returns
    -------
    A function that represents the composed JAX operations and takes the same
    form of inputs as `inputs`.

    """
    if out_var in memo:
        return memo[out_var]

    if out_var in inputs:
        idx = inputs.index(out_var)

        def jax_inputs_func(*inputs):
            # TODO: What about `jax.device_put`?
            jax_type = getattr(jnp, out_var.dtype)
            return jax_type(inputs[idx])

        memo[out_var] = jax_inputs_func
        return jax_inputs_func

    if out_var.owner is None:

        def jax_data_func(*inputs):
            return out_var.data

        memo[out_var] = jax_data_func
        return jax_data_func

    node = out_var.owner
    jax_return_func = jax_funcify(node.op, node)

    input_funcs = []
    for i in node.inputs:
        input_f = compose_jax_funcs(i, memo, inputs)
        input_funcs.append(input_f)

    def jax_func(*inputs):
        return jax_return_func(*[fn(*inputs) for fn in input_funcs])

    jax_func = update_wrapper(jax_func, jax_return_func)
    memo[out_var] = jax_func

    return jax_func
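
# For example--as a hypothetical, minimal sketch--composing the JAX function
# for a one-node graph like `z = x + y` would go roughly as follows:
#
#     x, y = tt.dscalar("x"), tt.dscalar("y")
#     z = x + y
#     z_fn = compose_jax_funcs(z, {}, [x, y])
#     z_fn(1.0, 2.0)  # DeviceArray(3., dtype=float64)
#
# The recursion bottoms out at the two input variables, and the `Elemwise`
# addition is mapped to `jnp.add` by the dispatch rules below.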


@dispatch(ScalarOp, Node)
def jax_funcify(op, node):
    """Create a JAX "perform" function for a Theano `Op` and its `Apply` node."""
    jnp_func = getattr(jnp, op.nfunc_spec[0])
    return jnp_func
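
# E.g. `theano.scalar.basic.Add` has `nfunc_spec == ("add", 2, 1)`, so the
# fallback above resolves to `jnp.add`.  (This assumes the `ScalarOp` defines
# an `nfunc_spec` naming a NumPy function that `jax.numpy` mirrors; `Op`s
# without one need their own `jax_funcify` registrations, like those below.)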


@jax_funcify.register(Clip, Node)
def jax_funcify_clip(op, node):
    def clip(x, min, max):
        # `jnp.clip` is traceable under `jax.jit`, unlike the Python
        # conditionals in the `Op`'s `impl`.
        return jnp.clip(x, min, max)

    return clip


@jax_funcify.register(Cast, Node)
def jax_funcify_cast(op, node):
    def cast(x):
        return jnp.array(x).astype(op.o_type.dtype)

    return cast


@jax_funcify.register(TensorFromScalar, Node)
def jax_funcify_TensorFromScalar(op, node):
    def tensor_from_scalar(x):
        return jnp.array(x)

    return tensor_from_scalar


@jax_funcify.register(ScalarFromTensor, Node)
def jax_funcify_ScalarFromTensor(op, node):
    def scalar_from_tensor(x):
        return jnp.array(x).flatten()[0]

    return scalar_from_tensor


@jax_funcify.register(Elemwise, Node)
def jax_funcify_Elemwise(op, node):
    node_op = node.op.scalar_op
    jax_scalar_func = jax_funcify(node_op, node)
    res = jax_scalar_func
    # TODO: Not sure when/if this is applicable.
    # res = jax.vmap(jax_scalar_func)
    return res


@jax_funcify.register(Composite, Node)
def jax_funcify_Composite(op, node):
    fgraph_impls = [compose_jax_funcs(r, {}, op.fgraph.inputs)
                    for r in op.fgraph.outputs]

    if len(fgraph_impls) == 1:
        jax_impl, = fgraph_impls
    else:

        def jax_impl(*inputs):
            return [impl(*inputs) for impl in fgraph_impls]

    return jax_impl
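
# `Composite` `Op`s carry their own inner `FunctionGraph` (e.g. as produced by
# Theano's elemwise fusion optimizations), so we simply recurse over it with
# `compose_jax_funcs`, just as the linker does for the outer graph.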


@jax_funcify.register(Scan, Node)
def jax_funcify_Scan(op, node):
    # Raising here--i.e. during composition--lets `JaxLinker.make_all` catch
    # the error and fall back to the standard Python thunks.
    raise NotImplementedError("No JAX implementation for `Scan`")
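
# A future `Scan` implementation would presumably target `jax.lax.scan`, whose
# step function has the signature `f(carry, x) -> (carry, y)`.  A minimal,
# hypothetical sketch of that primitive computing a cumulative sum:
#
#     def step(carry, x):
#         carry = carry + x
#         return carry, carry
#
#     last_sum, sums = jax.lax.scan(step, 0.0, jnp.arange(5.0))
#
# Translating a Theano `Scan`'s sequences, outputs-info, and taps into that
# form is the (nontrivial) missing piece.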


@jax_funcify.register(IfElse, Node)
def jax_funcify_IfElse(op, node):
    def ifelse(cond, *args):
        if cond:
            return args[:op.n_outs]
        else:
            return args[op.n_outs:]

    return ifelse
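
# The Python-level `if cond:` above only works when `cond` is concrete; a
# condition traced by `jax.jit` would need the traceable branching primitive
# instead.  A rough sketch for the single-output case (an assumption; `IfElse`
# `Op`s can have several outputs):
#
#     def ifelse_traced(cond, true_val, false_val):
#         return jax.lax.cond(cond, lambda _: true_val, lambda _: false_val, None)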


@jax_funcify.register(subtensor_ops, Node)
def jax_funcify_Subtensor(op, node):
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        cdata = get_idx_list((x,) + ilists, idx_list)

        if len(cdata) == 1:
            cdata = cdata[0]

        return x.__getitem__(cdata)
        # return x.take(ilists, axis=0)

    return subtensor
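
# `get_idx_list` reconstructs the concrete index tuple (slices, integers,
# and/or index arrays) from the `Op`'s static `idx_list` template plus the
# runtime index inputs, so a plain `__getitem__` covers both basic and
# advanced indexing on JAX arrays.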


@jax_funcify.register(incsubtensor_ops, Node)
def jax_funcify_IncSubtensor(op, node):
    if getattr(op, "set_instead_of_inc", False):

        def incsubtensor(x, y, *ilist):
            return jax.ops.index_update(x, ilist, y)

    else:

        def incsubtensor(x, y, *ilist):
            return jax.ops.index_add(x, ilist, y)

    return incsubtensor
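
# Unlike Theano's (potentially in-place) `IncSubtensor`, `jax.ops.index_update`
# and `jax.ops.index_add` are purely functional: they return a new array and
# leave `x` untouched, e.g.
#
#     x = jnp.zeros(3)
#     y = jax.ops.index_add(x, (0,), 1.0)  # `x` is unchanged; `y[0] == 1.0`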


class JaxLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX."""

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        thunks = []

        try:
            node = nodes[-1]
            jax_funcs = [compose_jax_funcs(out_var, {}, fgraph.inputs)
                         for out_var in node.outputs]

            if len(jax_funcs) == 1:
                jax_func, = jax_funcs
            else:

                def jax_func(*inputs):
                    return [jf(*inputs) for jf in jax_funcs]

            # I suppose we can consider `Constant`s to be "static" according
            # to JAX.
            static_argnums = [n for n, i in enumerate(fgraph.inputs)
                              if isinstance(i, Constant)]

            jax_impl_jit = jax.jit(jax_func, static_argnums=static_argnums)
            jax_impl_jit.nout = len(node.outputs)

            thunk_inputs = [storage_map[v] for v in fgraph.inputs]
            thunk_outputs = [storage_map[v] for v in fgraph.outputs]

            def thunk():
                outputs = jax_impl_jit(*[x[0] for x in thunk_inputs])

                if not isinstance(outputs, Sequence):
                    outputs = [outputs]

                for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
                    compute_map[o_var][0] = True
                    # Each storage entry is a single-element list, so the
                    # value always goes in cell 0.
                    o_storage[0] = o_val

                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            nodes = [node]
            thunks.append(thunk)

        except NotImplementedError as e:
            warn("JaxLinker could not JAXify graph: {}".format(e))

            for node in nodes:
                thunk = node.op.make_thunk(node, storage_map, compute_map,
                                           no_recycling, "py")
                thunk_inputs = [storage_map[v] for v in node.inputs]
                thunk_outputs = [storage_map[v] for v in node.outputs]

                thunk.inputs = thunk_inputs
                thunk.outputs = thunk_outputs

                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            post_thunk_old_storage = []

            for node in nodes:
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )


# Register this linker with the config system:
# theano.compile.mode.register_linker("jax", JaxLinker)
# During development, we need to overwrite old classes:
theano.compile.mode.predefined_linkers["jax"] = JaxLinker()


#
# A test graph
#
x = tt.matrix('x')
y = tt.matrix('y')
z = tt.cosh(x**2 + y / 3.0)
out = tt.set_subtensor(z[0], -10.0)
out = tt.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3]

test_inputs = [x, y]
test_outputs = out

jax_mode = theano.compile.Mode(linker="jax")
theano_jax_fn = theano.function(test_inputs, test_outputs, mode=jax_mode)

test_input_vals = [
    np.tile(np.arange(10), (10, 1)),
    np.tile(np.arange(10, 20), (10, 1)),
]

jax_res = theano_jax_fn(*test_input_vals)

# Confirm that the result is from JAX
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)

# Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3)

# Confirm that the `IncSubtensor` operations are correct
assert jax_res[0, 0] == -10.0
assert jax_res[0, 1] == -8.0

# Confirm that the result is correct (according to regular Theano compilation)
py_mode = theano.compile.Mode(linker="py")
theano_py_fn = theano.function(test_inputs, test_outputs, mode=py_mode)
py_res = theano_py_fn(*test_input_vals)

assert np.allclose(jax_res, py_res)
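
# A second call with the same input shapes and dtypes should hit `jax.jit`'s
# compilation cache (i.e. no retracing) and return the same values.
jax_res_2 = theano_jax_fn(*test_input_vals)
assert np.allclose(jax_res, jax_res_2)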