pymc3jax.ipynb
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\n\nimport theano\nimport theano.tensor as tt\n\nimport jax\nimport jax.numpy as jnp\n\nfrom warnings import warn\nfrom functools import partial, update_wrapper\nfrom collections.abc import Sequence\n\nfrom multipledispatch import dispatch\n\nfrom theano.gof.graph import Node, Constant\nfrom theano.gof.link import (PerformLinker, map_storage, gc_helper, utils,\n add_clear_storage, Container, streamline)\n\nfrom theano.ifelse import IfElse\nfrom theano.tensor.subtensor import (\n get_idx_list,\n Subtensor,\n IncSubtensor,\n # This is essentially `np.take`\n AdvancedSubtensor1,\n AdvancedIncSubtensor1,\n # Boolean mask indexing and setting\n BaseAdvancedSubtensor,\n BaseAdvancedIncSubtensor,\n)\nfrom theano.scan_module.scan_op import Scan\nfrom theano.tensor.basic import (\n TensorFromScalar,\n ScalarFromTensor,\n)\nfrom theano.scalar.basic import (\n ScalarOp,\n Composite,\n Cast,\n Clip,\n)\nfrom theano.tensor.elemwise import Elemwise\n\n# from theano.printing import debugprint as tt_dprint\n\n\nsubtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)\nincsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)\n\n\ndef compose_jax_funcs(out_var, memo, inputs):\n \"\"\"Compose JAX implementations of node operations.\n Parameters\n ----------\n out_var: Variable\n The output variable.\n memo: Mapping\n A map from visited nodes to their JAX functions.\n inputs: List[Variable]\n The inputs--in a `FunctionGraph` sense--to `var`.\n Outputs\n -------\n A `function` object that represents the composed JAX operations and takes\n the same form of inputs as `inputs`.\n \"\"\"\n if out_var in memo:\n return memo[out_var]\n\n if out_var in inputs:\n idx = inputs.index(out_var)\n\n def jax_inputs_func(*inputs):\n # TODO: What about `jax.device_put`?\n jax_type = getattr(jnp, out_var.dtype)\n return jax_type(inputs[idx])\n\n memo[out_var] = jax_inputs_func\n return jax_inputs_func\n\n if out_var.owner is None:\n\n def jax_data_func(*inputs):\n return out_var.data\n\n memo[out_var] = jax_data_func\n return jax_data_func\n\n node = out_var.owner\n jax_return_func = jax_funcify(node.op, node)\n\n input_funcs = []\n for i in node.inputs:\n input_f = compose_jax_funcs(i, memo, inputs)\n input_funcs.append(input_f)\n\n def jax_func(*inputs):\n return jax_return_func(*[fn(*inputs) for fn in input_funcs])\n\n jax_func = update_wrapper(jax_func, jax_return_func)\n\n memo[out_var] = jax_func\n\n return jax_func\n\n\n@dispatch(ScalarOp, Node)\ndef jax_funcify(op, node):\n \"\"\"Create a JAX \"perform\" function for a Theano `Variable` and its `Op`.\"\"\"\n jnp_func = getattr(jnp, op.nfunc_spec[0])\n return jnp_func\n\n\n@jax_funcify.register(Clip, Node)\ndef jax_funcify_clip(op, node):\n return partial(op.impl, None)\n\n\n@jax_funcify.register(Cast, Node)\ndef jax_funcify_cast(op, node):\n def cast(x):\n return jnp.array(x).astype(op.o_type.dtype)\n return cast\n\n\n@jax_funcify.register(TensorFromScalar, Node)\ndef jax_funcify_TensorFromScalar(op, node):\n def tensor_from_scalar(x):\n return jnp.array(x)\n return tensor_from_scalar\n\n\n@jax_funcify.register(ScalarFromTensor, Node)\ndef jax_funcify_ScalarFromTensor(op, node):\n def scalar_from_tensor(x):\n return jnp.array(x).flatten()[0]\n return scalar_from_tensor\n\n\n@jax_funcify.register(Elemwise, Node)\ndef jax_funcify_Elemwise(op, node):\n node_op = node.op.scalar_op\n jax_scalar_func = jax_funcify(node_op, node)\n res = jax_scalar_func\n # TODO: Not sure when/if this is applicable.\n # res 
= jax.vmap(jax_scalar_func)\n return res\n\n\n@jax_funcify.register(Composite, Node)\ndef jax_funcify_Composite(op, node):\n\n fgraph_impls = [compose_jax_funcs(r, {}, op.fgraph.inputs) for r in op.fgraph.outputs]\n\n if len(fgraph_impls) == 1:\n jax_impl, = fgraph_impls\n else:\n def jax_impl(*inputs):\n return [impl(*inputs) for impl in fgraph_impls]\n\n return jax_impl\n\n\n@jax_funcify.register(Scan, Node)\ndef jax_funcify_Scan(op, node):\n def scan(x):\n return NotImplementedError()\n return scan\n\n\n@jax_funcify.register(IfElse, Node)\ndef jax_funcify_IfElse(op, node):\n\n def ifelse(cond, *args):\n if cond:\n return args[:op.n_outs]\n else:\n return args[op.n_outs:]\n\n return ifelse\n\n\n@jax_funcify.register(subtensor_ops, Node)\ndef jax_funcify_Subtensor(op, node):\n\n idx_list = getattr(op, \"idx_list\", None)\n\n def subtensor(x, *ilists):\n\n cdata = get_idx_list((x,) + ilists, idx_list)\n\n if len(cdata) == 1:\n cdata = cdata[0]\n\n return x.__getitem__(cdata)\n # return x.take(ilists, axis=0)\n\n return subtensor\n\n\n@jax_funcify.register(incsubtensor_ops, Node)\ndef jax_funcify_IncSubtensor(op, node):\n\n if getattr(op, \"set_instead_of_inc\", False):\n def incsubtensor(x, y, *ilist):\n return jax.ops.index_update(x, ilist, y)\n else:\n def incsubtensor(x, y, *ilist):\n return jax.ops.index_add(x, ilist, y)\n\n return incsubtensor\n\n\n\nclass JaxLinker(PerformLinker):\n \"\"\"A `Linker` that JIT-compiles NumPy-based operations using JAX.\"\"\"\n\n def make_all(self, input_storage=None, output_storage=None, storage_map=None):\n fgraph = self.fgraph\n nodes = self.schedule(fgraph)\n no_recycling = self.no_recycling\n\n input_storage, output_storage, storage_map = map_storage(\n fgraph, nodes, input_storage, output_storage, storage_map\n )\n\n compute_map = {}\n for k in storage_map:\n compute_map[k] = [k.owner is None]\n\n thunks = []\n try:\n node = nodes[-1]\n jax_funcs = [compose_jax_funcs(out_var, {}, fgraph.inputs) for out_var in node.outputs]\n\n if len(jax_funcs) == 1:\n jax_func, = jax_funcs\n else:\n def jax_func(*inputs):\n return [jf(*inputs) for jf in jax_funcs]\n\n # I suppose we can consider `Constant`s to be \"static\"\n # according to JAX.\n static_argnums = [n for n, i in enumerate(fgraph.inputs)\n if isinstance(i, Constant)]\n jax_impl_jit = jax.jit(jax_func, static_argnums)\n jax_impl_jit.nout = len(node.outputs)\n\n thunk_inputs = [storage_map[v] for v in fgraph.inputs]\n thunk_outputs = [storage_map[v] for v in fgraph.outputs]\n\n def thunk():\n outputs = jax_impl_jit(*[x[0] for x in thunk_inputs])\n\n if not isinstance(outputs, Sequence):\n outputs = [outputs]\n\n for i, (o_node, o_storage, o_val) in enumerate(zip(fgraph.outputs, thunk_outputs, outputs)):\n compute_map[o_node][0] = True\n o_storage[i] = o_val\n return outputs\n\n thunk.inputs = thunk_inputs\n thunk.outputs = thunk_outputs\n thunk.lazy = False\n\n nodes = [node]\n thunks.append(thunk)\n\n except NotImplementedError as e:\n warn(\"JaxLinker could not JAXify graph: {}\".format(e))\n\n for node in nodes:\n thunk = node.op.make_thunk(node, storage_map, compute_map,\n no_recycling, \"py\")\n thunk_inputs = [storage_map[v] for v in node.inputs]\n thunk_outputs = [storage_map[v] for v in node.outputs]\n\n thunk.inputs = thunk_inputs\n thunk.outputs = thunk_outputs\n\n thunks.append(thunk)\n\n computed, last_user = gc_helper(nodes)\n\n if self.allow_gc:\n post_thunk_old_storage = []\n\n for node in nodes:\n post_thunk_old_storage.append(\n [\n storage_map[input]\n for input in node.inputs\n if 
(input in computed)\n and (input not in fgraph.outputs)\n and (node == last_user[input])\n ]\n )\n else:\n post_thunk_old_storage = None\n\n if no_recycling is True:\n no_recycling = list(storage_map.values())\n no_recycling = utils.difference(no_recycling, input_storage)\n else:\n no_recycling = [\n storage_map[r] for r in no_recycling if r not in fgraph.inputs\n ]\n\n fn = streamline(\n fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling\n )\n\n fn.allow_gc = self.allow_gc\n add_clear_storage(fn, computed, storage_map)\n fn.storage_map = storage_map\n\n return (\n fn,\n [\n Container(input, storage)\n for input, storage in zip(fgraph.inputs, input_storage)\n ],\n [\n Container(output, storage, True)\n for output, storage in zip(fgraph.outputs, output_storage)\n ],\n thunks,\n nodes,\n )",
"execution_count": 1,
"outputs": []
},
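{
"metadata": {},
"cell_type": "markdown",
"source": "A note on how the linker works: `compose_jax_funcs` walks the graph backwards from each output and wraps every node's JAX implementation in a closure over its inputs' closures. Below is a minimal standalone sketch of that composition pattern; the `ToyNode` class and `to_jax` helper are hypothetical illustrations, not part of Theano or JAX."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# A standalone sketch of the closure-composition pattern used by\n# `compose_jax_funcs`; `ToyNode` and `to_jax` are hypothetical names.\nclass ToyNode:\n    def __init__(self, fn, *parents):\n        self.fn = fn            # the JAX primitive for this node\n        self.parents = parents  # upstream nodes, or ints naming input positions\n\ndef to_jax(node):\n    if isinstance(node, int):\n        # Leaves select one of the overall function's inputs.\n        return lambda *args: args[node]\n    parent_fns = [to_jax(p) for p in node.parents]\n    # Each node becomes a closure over its parents' closures,\n    # just as `compose_jax_funcs` builds `jax_func`.\n    return lambda *args: node.fn(*[f(*args) for f in parent_fns])\n\n# cosh(x**2 + y / 3.0), mirroring the test graph below\ntoy_graph = ToyNode(jnp.cosh, ToyNode(jnp.add, ToyNode(jnp.square, 0), ToyNode(lambda y: y / 3.0, 1)))\nto_jax(toy_graph)(jnp.ones((2, 2)), jnp.ones((2, 2)))",
"execution_count": null,
"outputs": []
},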
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#\n# A test graph\n#\nx = tt.matrix('x')\ny = tt.matrix('y')\nz = tt.cosh(x**2 + y / 3.0)\nout = tt.set_subtensor(z[0], -10.0)\nout = tt.inc_subtensor(out[0, 1], 2.0)\nout = out[:5, :3]\n\ntest_inputs = [x, y]\ntest_outputs = out\n\ntest_input_vals = [\n np.tile(np.arange(10), (10, 1)),\n np.tile(np.arange(10, 20), (10, 1)),\n]",
"execution_count": 6,
"outputs": []
},
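{
"metadata": {},
"cell_type": "markdown",
"source": "For reference, here is a hand-written JAX version of the same graph: a sketch of what the linker should effectively produce, using the same `jax.ops.index_update`/`index_add` calls as the `IncSubtensor` translations above."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Hand-written JAX equivalent of the Theano test graph above,\n# for eyeballing against the linker-generated function.\n@jax.jit\ndef jax_reference(x, y):\n    z = jnp.cosh(x**2 + y / 3.0)\n    out = jax.ops.index_update(z, 0, -10.0)    # tt.set_subtensor(z[0], -10.0)\n    out = jax.ops.index_add(out, (0, 1), 2.0)  # tt.inc_subtensor(out[0, 1], 2.0)\n    return out[:5, :3]\n\njax_reference(*test_input_vals)",
"execution_count": null,
"outputs": []
},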
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "theano_c_fn = theano.function(test_inputs, test_outputs)\n\n# Register this linker with the config system\n# theano.compile.mode.register_linker(\"jax\", JaxLinker)\n\n# During development, we need to overwrite old classes:\ntheano.compile.mode.predefined_linkers[\"jax\"] = JaxLinker()\n\njax_mode = theano.compile.Mode(linker=\"jax\")\ntheano_jax_fn = theano.function(test_inputs, test_outputs, mode=jax_mode)\n\njax_res = theano_jax_fn(*test_input_vals)\n\n# Confirm that the result is from JAX\nassert isinstance(jax_res, jax.interpreters.xla.DeviceArray)\n\n# Confirm that the `Subtensor` slice operations are correct\nassert jax_res.shape == (5, 3)\n\n# Confirm that the `IncSubtensor` operations are correct\nassert jax_res[0, 0] == -10.0\nassert jax_res[0, 1] == -8.0\n\n# Confirm that the result is correct (according to regular Theano compilation)\npy_mode = theano.compile.Mode(linker=\"py\")\ntheano_py_fn = theano.function(test_inputs, test_outputs, mode=py_mode)\npy_res = theano_py_fn(*test_input_vals)\n\nassert np.allclose(jax_res, py_res)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit theano_jax_fn(*test_input_vals)",
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": "417 µs ± 61.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit theano_py_fn(*test_input_vals)",
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": "2.89 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit theano_c_fn(*test_input_vals)",
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": "67.5 µs ± 4.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "417/67.5",
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 15,
"data": {
"text/plain": "6.177777777777778"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "test_input_vals = [\n np.tile(np.arange(1000), (1000, 1)),\n np.tile(np.arange(1000), (1000, 1)),\n]",
"execution_count": 25,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit theano_c_fn(*test_input_vals)",
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"text": "7.77 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit theano_jax_fn(*test_input_vals)",
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": "6 ms ± 606 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"name": "stdout"
}
]
},
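{
"metadata": {},
"cell_type": "markdown",
"source": "At 1000x1000 the JAX-backed function (~6 ms) now edges out the C backend (~7.8 ms): the fixed compilation and dispatch overhead that dominated the small-input comparison is amortized over the larger arrays. The same ratio calculation as before:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "7.77/6",
"execution_count": null,
"outputs": []
},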
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import pymc3 as pm",
"execution_count": 28,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "with pm.Model() as model:\n x = pm.Normal(\"x\")",
"execution_count": 29,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model.logp_dlogp_function(mode=jax_mode)",
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": "<ipython-input-1-e78ca8561185>:270: UserWarning: JaxLinker could not JAXify graph: Could not find signature for jax_funcify: <Alloc, Apply>\n warn(\"JaxLinker could not JAXify graph: {}\".format(e))\n",
"name": "stderr"
},
{
"output_type": "execute_result",
"execution_count": 32,
"data": {
"text/plain": "<pymc3.model.ValueGradFunction at 0x13af19310>"
},
"metadata": {}
}
]
}
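,
{
"metadata": {},
"cell_type": "markdown",
"source": "The `UserWarning` above is the fallback path at work: `jax_funcify` has no dispatch for Theano's `Alloc` op yet, so the linker reverts to the Python implementations. Below is a minimal, untested sketch of how that gap could be filled, assuming `Alloc`'s broadcast-a-value-to-a-shape semantics map onto `jnp.broadcast_to`."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# A sketch of a JAX translation for the missing `Alloc` op (untested);\n# assumes its broadcast-to-shape semantics map onto `jnp.broadcast_to`.\nfrom theano.tensor.basic import Alloc\n\n@jax_funcify.register(Alloc, Node)\ndef jax_funcify_Alloc(op, node):\n    def alloc(x, *shape):\n        return jnp.broadcast_to(x, shape)\n    return alloc",
"execution_count": null,
"outputs": []
}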
],
"metadata": {
"kernelspec": {
"name": "pymc3py38",
"display_name": "pymc3py38",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.8.2",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "38dc98197eed5594c5518a3971064c92",
"data": {
"description": "pymc3jax.ipynb",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/38dc98197eed5594c5518a3971064c92"
}
},
"nbformat": 4,
"nbformat_minor": 4
}