Last active
July 31, 2020 12:19
-
-
Save twiecki/38dc98197eed5594c5518a3971064c92 to your computer and use it in GitHub Desktop.
pymc3jax.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": [
 "import numpy as np\n",
 "\n",
 "import theano\n",
 "import theano.tensor as tt\n",
 "\n",
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "from warnings import warn\n",
 "from functools import partial, update_wrapper\n",
 "from collections.abc import Sequence\n",
 "\n",
 "from multipledispatch import dispatch\n",
 "\n",
 "from theano.gof.graph import Node, Constant\n",
 "from theano.gof.link import (PerformLinker, map_storage, gc_helper, utils,\n",
 "                             add_clear_storage, Container, streamline)\n",
 "\n",
 "from theano.ifelse import IfElse\n",
 "from theano.tensor.subtensor import (\n",
 "    get_idx_list,\n",
 "    Subtensor,\n",
 "    IncSubtensor,\n",
 "    # This is essentially `np.take`\n",
 "    AdvancedSubtensor1,\n",
 "    AdvancedIncSubtensor1,\n",
 "    # Boolean mask indexing and setting\n",
 "    BaseAdvancedSubtensor,\n",
 "    BaseAdvancedIncSubtensor,\n",
 ")\n",
 "from theano.scan_module.scan_op import Scan\n",
 "from theano.tensor.basic import (\n",
 "    TensorFromScalar,\n",
 "    ScalarFromTensor,\n",
 ")\n",
 "from theano.scalar.basic import (\n",
 "    ScalarOp,\n",
 "    Composite,\n",
 "    Cast,\n",
 "    Clip,\n",
 ")\n",
 "from theano.tensor.elemwise import Elemwise\n",
 "\n",
 "# from theano.printing import debugprint as tt_dprint\n",
 "\n",
 "\n",
 "subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)\n",
 "incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)\n",
 "\n",
 "\n",
 "def compose_jax_funcs(out_var, memo, inputs):\n",
 "    \"\"\"Compose JAX implementations of node operations.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    out_var: Variable\n",
 "        The output variable.\n",
 "    memo: Mapping\n",
 "        A map from visited variables to their JAX functions; updated in place.\n",
 "    inputs: List[Variable]\n",
 "        The inputs--in a `FunctionGraph` sense--to `out_var`.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    A `function` object that represents the composed JAX operations and takes\n",
 "    the same form of inputs as `inputs`.\n",
 "\n",
 "    Raises\n",
 "    ------\n",
 "    NotImplementedError\n",
 "        If an `Op` in the graph has no `jax_funcify` implementation.\n",
 "    \"\"\"\n",
 "    if out_var in memo:\n",
 "        return memo[out_var]\n",
 "\n",
 "    if out_var in inputs:\n",
 "        idx = inputs.index(out_var)\n",
 "\n",
 "        def jax_inputs_func(*inputs):\n",
 "            # TODO: What about `jax.device_put`?\n",
 "            jax_type = getattr(jnp, out_var.dtype)\n",
 "            return jax_type(inputs[idx])\n",
 "\n",
 "        memo[out_var] = jax_inputs_func\n",
 "        return jax_inputs_func\n",
 "\n",
 "    if out_var.owner is None:\n",
 "        # Neither a graph input nor computed by a node (e.g. a `Constant`):\n",
 "        # use the value stored on the variable.\n",
 "\n",
 "        def jax_data_func(*inputs):\n",
 "            return out_var.data\n",
 "\n",
 "        memo[out_var] = jax_data_func\n",
 "        return jax_data_func\n",
 "\n",
 "    node = out_var.owner\n",
 "    jax_return_func = jax_funcify(node.op, node)\n",
 "\n",
 "    input_funcs = [compose_jax_funcs(i, memo, inputs) for i in node.inputs]\n",
 "\n",
 "    if len(node.outputs) > 1:\n",
 "        # Multi-output implementations (e.g. `Composite`, `IfElse`) return one\n",
 "        # value per output; select the one that corresponds to `out_var`.\n",
 "        out_idx = node.outputs.index(out_var)\n",
 "\n",
 "        def jax_func(*inputs):\n",
 "            values = jax_return_func(*[fn(*inputs) for fn in input_funcs])\n",
 "            return values[out_idx]\n",
 "\n",
 "    else:\n",
 "\n",
 "        def jax_func(*inputs):\n",
 "            return jax_return_func(*[fn(*inputs) for fn in input_funcs])\n",
 "\n",
 "    jax_func = update_wrapper(jax_func, jax_return_func)\n",
 "\n",
 "    memo[out_var] = jax_func\n",
 "\n",
 "    return jax_func\n",
 "\n",
 "\n",
 "@dispatch(ScalarOp, Node)\n",
 "def jax_funcify(op, node):\n",
 "    \"\"\"Create a JAX \"perform\" function for a Theano `Variable` and its `Op`.\n",
 "\n",
 "    This is a `multipledispatch` function; the other `Op` types are handled by\n",
 "    the implementations registered below. When no implementation matches, the\n",
 "    dispatcher raises `NotImplementedError`, which `JaxLinker` catches.\n",
 "    \"\"\"\n",
 "    nfunc_spec = getattr(op, \"nfunc_spec\", None)\n",
 "    if nfunc_spec is None or not hasattr(jnp, nfunc_spec[0]):\n",
 "        # Use the same error type as a missing dispatch signature so that\n",
 "        # `JaxLinker` falls back instead of failing with another exception.\n",
 "        raise NotImplementedError(\n",
 "            \"No `jax.numpy` equivalent for scalar Op {}\".format(op)\n",
 "        )\n",
 "    return getattr(jnp, nfunc_spec[0])\n",
 "\n",
 "\n",
 "@jax_funcify.register(Clip, Node)\n",
 "def jax_funcify_clip(op, node):\n",
 "    # `jnp.clip` works on whole arrays (and on values traced by `jax.jit`).\n",
 "    def clip(x, min, max):\n",
 "        return jnp.clip(x, min, max)\n",
 "\n",
 "    return clip\n",
 "\n",
 "\n",
 "@jax_funcify.register(Cast, Node)\n",
 "def jax_funcify_cast(op, node):\n",
 "    def cast(x):\n",
 "        return jnp.array(x).astype(op.o_type.dtype)\n",
 "\n",
 "    return cast\n",
 "\n",
 "\n",
 "@jax_funcify.register(TensorFromScalar, Node)\n",
 "def jax_funcify_TensorFromScalar(op, node):\n",
 "    def tensor_from_scalar(x):\n",
 "        return jnp.array(x)\n",
 "\n",
 "    return tensor_from_scalar\n",
 "\n",
 "\n",
 "@jax_funcify.register(ScalarFromTensor, Node)\n",
 "def jax_funcify_ScalarFromTensor(op, node):\n",
 "    def scalar_from_tensor(x):\n",
 "        return jnp.array(x).flatten()[0]\n",
 "\n",
 "    return scalar_from_tensor\n",
 "\n",
 "\n",
 "@jax_funcify.register(Elemwise, Node)\n",
 "def jax_funcify_Elemwise(op, node):\n",
 "    # The scalar implementation is used directly on whole arrays.\n",
 "    # TODO: Not sure when/if this is applicable.\n",
 "    # res = jax.vmap(jax_scalar_func)\n",
 "    return jax_funcify(node.op.scalar_op, node)\n",
 "\n",
 "\n",
 "@jax_funcify.register(Composite, Node)\n",
 "def jax_funcify_Composite(op, node):\n",
 "    # Convert the inner scalar graph; one memo is shared so sub-expressions\n",
 "    # common to several outputs are converted only once.\n",
 "    memo = {}\n",
 "    fgraph_impls = [\n",
 "        compose_jax_funcs(r, memo, op.fgraph.inputs) for r in op.fgraph.outputs\n",
 "    ]\n",
 "\n",
 "    if len(fgraph_impls) == 1:\n",
 "        jax_impl, = fgraph_impls\n",
 "    else:\n",
 "        def jax_impl(*inputs):\n",
 "            return [impl(*inputs) for impl in fgraph_impls]\n",
 "\n",
 "    return jax_impl\n",
 "\n",
 "\n",
 "@jax_funcify.register(Scan, Node)\n",
 "def jax_funcify_Scan(op, node):\n",
 "    # Raise during conversion (not when the returned function is traced) so\n",
 "    # that `JaxLinker.make_all` catches it and falls back to Theano's thunks.\n",
 "    raise NotImplementedError(\"JAX conversion of `Scan` is not implemented yet\")\n",
 "\n",
 "\n",
 "@jax_funcify.register(IfElse, Node)\n",
 "def jax_funcify_IfElse(op, node):\n",
 "    n_outs = op.n_outs\n",
 "\n",
 "    def ifelse(cond, *args):\n",
 "        # TODO: a Python `if` needs a concrete `cond`; a condition traced by\n",
 "        # `jax.jit` presumably needs `jax.lax.cond` instead.\n",
 "        res = args[:n_outs] if cond else args[n_outs:]\n",
 "        # A single-output node expects the value itself, not a 1-tuple.\n",
 "        return res[0] if n_outs == 1 else res\n",
 "\n",
 "    return ifelse\n",
 "\n",
 "\n",
 "@jax_funcify.register(subtensor_ops, Node)\n",
 "def jax_funcify_Subtensor(op, node):\n",
 "\n",
 "    idx_list = getattr(op, \"idx_list\", None)\n",
 "\n",
 "    def subtensor(x, *ilists):\n",
 "        if idx_list is not None:\n",
 "            # Rebuild the index from the `Op`'s `idx_list` and the index inputs.\n",
 "            cdata = get_idx_list((x,) + ilists, idx_list)\n",
 "        else:\n",
 "            # No `idx_list` (presumably the advanced-indexing `Op`s): the\n",
 "            # index inputs are used as the indices themselves.\n",
 "            cdata = ilists\n",
 "\n",
 "        if len(cdata) == 1:\n",
 "            cdata = cdata[0]\n",
 "\n",
 "        return x[cdata]\n",
 "\n",
 "    return subtensor\n",
 "\n",
 "\n",
 "@jax_funcify.register(incsubtensor_ops, Node)\n",
 "def jax_funcify_IncSubtensor(op, node):\n",
 "\n",
 "    idx_list = getattr(op, \"idx_list\", None)\n",
 "\n",
 "    if getattr(op, \"set_instead_of_inc\", False):\n",
 "        jax_fn = jax.ops.index_update\n",
 "    else:\n",
 "        jax_fn = jax.ops.index_add\n",
 "\n",
 "    def incsubtensor(x, y, *ilist):\n",
 "        # Build the index like `jax_funcify_Subtensor`, so slices stored in\n",
 "        # `idx_list` are honoured instead of treating each index input as a\n",
 "        # separate position.\n",
 "        if idx_list is not None:\n",
 "            cdata = get_idx_list((x,) + ilist, idx_list)\n",
 "        else:\n",
 "            cdata = ilist\n",
 "\n",
 "        if len(cdata) == 1:\n",
 "            cdata = cdata[0]\n",
 "\n",
 "        return jax_fn(x, cdata, y)\n",
 "\n",
 "    return incsubtensor\n",
 "\n",
 "\n",
 "class JaxLinker(PerformLinker):\n",
 "    \"\"\"A `Linker` that JIT-compiles NumPy-based operations using JAX.\n",
 "\n",
 "    All graph outputs are converted into one `jax.jit`-compiled thunk. If the\n",
 "    conversion raises `NotImplementedError`, a warning is issued and every node\n",
 "    gets its own Theano Python thunk instead.\n",
 "    \"\"\"\n",
 "\n",
 "    def make_all(self, input_storage=None, output_storage=None, storage_map=None):\n",
 "        fgraph = self.fgraph\n",
 "        nodes = self.schedule(fgraph)\n",
 "        no_recycling = self.no_recycling\n",
 "\n",
 "        input_storage, output_storage, storage_map = map_storage(\n",
 "            fgraph, nodes, input_storage, output_storage, storage_map\n",
 "        )\n",
 "\n",
 "        compute_map = {}\n",
 "        for k in storage_map:\n",
 "            compute_map[k] = [k.owner is None]\n",
 "\n",
 "        thunks = []\n",
 "        try:\n",
 "            node = nodes[-1]\n",
 "            # Compose every graph output (not only the last node's outputs) and\n",
 "            # share one memo so common sub-graphs are converted only once.\n",
 "            memo = {}\n",
 "            jax_funcs = [\n",
 "                compose_jax_funcs(out_var, memo, fgraph.inputs)\n",
 "                for out_var in fgraph.outputs\n",
 "            ]\n",
 "\n",
 "            if len(jax_funcs) == 1:\n",
 "                jax_func, = jax_funcs\n",
 "            else:\n",
 "                def jax_func(*inputs):\n",
 "                    return [jf(*inputs) for jf in jax_funcs]\n",
 "\n",
 "            # I suppose we can consider `Constant`s to be \"static\"\n",
 "            # according to JAX.\n",
 "            static_argnums = [n for n, i in enumerate(fgraph.inputs)\n",
 "                              if isinstance(i, Constant)]\n",
 "            jax_impl_jit = jax.jit(jax_func, static_argnums=static_argnums)\n",
 "            jax_impl_jit.nout = len(fgraph.outputs)\n",
 "\n",
 "            thunk_inputs = [storage_map[v] for v in fgraph.inputs]\n",
 "            thunk_outputs = [storage_map[v] for v in fgraph.outputs]\n",
 "\n",
 "            def thunk():\n",
 "                outputs = jax_impl_jit(*[x[0] for x in thunk_inputs])\n",
 "\n",
 "                if not isinstance(outputs, Sequence):\n",
 "                    outputs = [outputs]\n",
 "\n",
 "                for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):\n",
 "                    compute_map[o_var][0] = True\n",
 "                    # Storage cells are read through slot `[0]` above; every\n",
 "                    # output is written to that same slot of its own cell.\n",
 "                    o_storage[0] = o_val\n",
 "                return outputs\n",
 "\n",
 "            thunk.inputs = thunk_inputs\n",
 "            thunk.outputs = thunk_outputs\n",
 "            thunk.lazy = False\n",
 "\n",
 "            nodes = [node]\n",
 "            thunks.append(thunk)\n",
 "\n",
 "        except NotImplementedError as e:\n",
 "            warn(\"JaxLinker could not JAXify graph: {}\".format(e))\n",
 "\n",
 "            # Fall back to one Python thunk per scheduled node.\n",
 "            for node in nodes:\n",
 "                thunk = node.op.make_thunk(node, storage_map, compute_map,\n",
 "                                           no_recycling, \"py\")\n",
 "                thunk_inputs = [storage_map[v] for v in node.inputs]\n",
 "                thunk_outputs = [storage_map[v] for v in node.outputs]\n",
 "\n",
 "                thunk.inputs = thunk_inputs\n",
 "                thunk.outputs = thunk_outputs\n",
 "\n",
 "                thunks.append(thunk)\n",
 "\n",
 "        computed, last_user = gc_helper(nodes)\n",
 "\n",
 "        if self.allow_gc:\n",
 "            post_thunk_old_storage = []\n",
 "\n",
 "            for node in nodes:\n",
 "                post_thunk_old_storage.append(\n",
 "                    [\n",
 "                        storage_map[input]\n",
 "                        for input in node.inputs\n",
 "                        if (input in computed)\n",
 "                        and (input not in fgraph.outputs)\n",
 "                        and (node == last_user[input])\n",
 "                    ]\n",
 "                )\n",
 "        else:\n",
 "            post_thunk_old_storage = None\n",
 "\n",
 "        if no_recycling is True:\n",
 "            no_recycling = list(storage_map.values())\n",
 "            no_recycling = utils.difference(no_recycling, input_storage)\n",
 "        else:\n",
 "            no_recycling = [\n",
 "                storage_map[r] for r in no_recycling if r not in fgraph.inputs\n",
 "            ]\n",
 "\n",
 "        fn = streamline(\n",
 "            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling\n",
 "        )\n",
 "\n",
 "        fn.allow_gc = self.allow_gc\n",
 "        add_clear_storage(fn, computed, storage_map)\n",
 "        fn.storage_map = storage_map\n",
 "\n",
 "        return (\n",
 "            fn,\n",
 "            [\n",
 "                Container(input, storage)\n",
 "                for input, storage in zip(fgraph.inputs, input_storage)\n",
 "            ],\n",
 "            [\n",
 "                Container(output, storage, True)\n",
 "                for output, storage in zip(fgraph.outputs, output_storage)\n",
 "            ],\n",
 "            thunks,\n",
 "            nodes,\n",
 "        )"
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "#\n# A test graph with element-wise math, `set_subtensor`/`inc_subtensor` and slicing\n#\nx = tt.matrix('x')\ny = tt.matrix('y')\nz = tt.cosh(x**2 + y / 3.0)\n# Overwrite row 0 with -10.0 ...\nout = tt.set_subtensor(z[0], -10.0)\n# ... then add 2.0 to element [0, 1], so it ends up as -8.0\nout = tt.inc_subtensor(out[0, 1], 2.0)\n# Keep only the top-left 5x3 block\nout = out[:5, :3]\n\ntest_inputs = [x, y]\ntest_outputs = out\n\n# 10x10 integer inputs: every row of `x` is 0..9 and every row of `y` is 10..19\ntest_input_vals = [\n np.tile(np.arange(10), (10, 1)),\n np.tile(np.arange(10, 20), (10, 1)),\n]", | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Baseline compiled with Theano's default mode; it is timed against the JAX version below.\ntheano_c_fn = theano.function(test_inputs, test_outputs)\n\n# Register this linker with the config system\n# theano.compile.mode.register_linker(\"jax\", JaxLinker)\n\n# During development, we need to overwrite old classes:\ntheano.compile.mode.predefined_linkers[\"jax\"] = JaxLinker()\n\njax_mode = theano.compile.Mode(linker=\"jax\")\ntheano_jax_fn = theano.function(test_inputs, test_outputs, mode=jax_mode)\n\n# The first call traces and JIT-compiles the graph for these input shapes\njax_res = theano_jax_fn(*test_input_vals)\n\n# Confirm that the result is from JAX\nassert isinstance(jax_res, jax.interpreters.xla.DeviceArray)\n\n# Confirm that the `Subtensor` slice operations are correct\nassert jax_res.shape == (5, 3)\n\n# Confirm that the `IncSubtensor` operations are correct:\n# row 0 was set to -10.0, then element [0, 1] was incremented by 2.0\nassert jax_res[0, 0] == -10.0\nassert jax_res[0, 1] == -8.0\n\n# Confirm that the result is correct (according to regular Theano compilation)\npy_mode = theano.compile.Mode(linker=\"py\")\ntheano_py_fn = theano.function(test_inputs, test_outputs, mode=py_mode)\npy_res = theano_py_fn(*test_input_vals)\n\nassert np.allclose(jax_res, py_res)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# JAX dispatches work asynchronously, so wait for the result; otherwise `%timeit`\n# may only measure the dispatch, not the computation.\n%timeit theano_jax_fn(*test_input_vals).block_until_ready()", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "417 µs ± 61.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Same graph compiled with Theano's \"py\" linker\n%timeit theano_py_fn(*test_input_vals)", | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "2.89 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Same graph compiled with Theano's default mode\n%timeit theano_c_fn(*test_input_vals)", | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "67.5 µs ± 4.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Ratio of JAX-linker to default-mode runtime (best of 7 runs of 1000 calls each),\n# measured here instead of hard-coding numbers copied from the outputs above.\nimport timeit\n\njax_time = min(timeit.repeat(lambda: theano_jax_fn(*test_input_vals).block_until_ready(), number=1000, repeat=7))\nc_time = min(timeit.repeat(lambda: theano_c_fn(*test_input_vals), number=1000, repeat=7))\njax_time / c_time", | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 15, | |
"data": { | |
"text/plain": "6.177777777777778" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Larger 1000x1000 inputs for a second timing comparison. This rebinds\n# `test_input_vals`; unlike the first set, `x` and `y` now hold identical values.\ntest_input_vals = [\n np.tile(np.arange(1000), (1000, 1)),\n np.tile(np.arange(1000), (1000, 1)),\n]", | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Default-mode baseline on the 1000x1000 inputs\n%timeit theano_c_fn(*test_input_vals)", | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "7.77 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# The 1000x1000 inputs have a new shape, so the first call re-traces and\n# re-compiles the jitted graph; do that once outside the timing loop.\ntheano_jax_fn(*test_input_vals).block_until_ready()\n\n# Block on the result so the asynchronously dispatched computation is included.\n%timeit theano_jax_fn(*test_input_vals).block_until_ready()", | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "6 ms ± 606 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# NOTE(review): move this import to the import cell at the top of the notebook\nimport pymc3 as pm", | |
"execution_count": 28, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# NOTE: this rebinds the global `x`, shadowing the Theano matrix `x` from the\n# test-graph cell above.\nwith pm.Model() as model:\n x = pm.Normal(\"x\")", | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# In the run recorded below, the JAX linker fell back to Theano's own thunks\n# because there is no `jax_funcify` implementation for `Alloc` (see the warning).\nmodel.logp_dlogp_function(mode=jax_mode)", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "<ipython-input-1-e78ca8561185>:270: UserWarning: JaxLinker could not JAXify graph: Could not find signature for jax_funcify: <Alloc, Apply>\n warn(\"JaxLinker could not JAXify graph: {}\".format(e))\n", | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 32, | |
"data": { | |
"text/plain": "<pymc3.model.ValueGradFunction at 0x13af19310>" | |
}, | |
"metadata": {} | |
} | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "pymc3py38", | |
"display_name": "pymc3py38", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.8.2", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "38dc98197eed5594c5518a3971064c92", | |
"data": { | |
"description": "pymc3jax.ipynb", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/38dc98197eed5594c5518a3971064c92" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment