Skip to content

Instantly share code, notes, and snippets.

@twiecki
Last active January 1, 2021 11:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save twiecki/e758db2c3d2df5f3368fc49e6087e58f to your computer and use it in GitHub Desktop.
Save twiecki/e758db2c3d2df5f3368fc49e6087e58f to your computer and use it in GitHub Desktop.
Skeleton to write graph optimizer
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\nimport theano.tensor as tt\nfrom theano import config\nfrom theano.compile import optdb\nfrom theano.gof.fg import FunctionGraph\nfrom theano.gof.graph import inputs as tt_inputs\nfrom theano.gof.opt import EquilibriumOptimizer, PatternSub\nfrom theano.gof.optdb import Query\nfrom theano.printing import debugprint as tt_dprint\nfrom theano.tensor.opt import get_clients\n\n# We don't need to waste time compiling graphs to C\nconfig.cxx = \"\"\n\n\n# a / b -> a * 1/b, for a != 1 and b != 1\ndiv_to_mul_pattern = PatternSub(\n (tt.true_div, \"a\", \"b\"),\n (tt.mul, \"a\", (tt.inv, \"b\")),\n allow_multiple_clients=True,\n name=\"div_to_mul\",\n tracks=[tt.true_div],\n get_nodes=get_clients,\n)\n\n# a - b -> a + (-b)\nsub_to_add_pattern = PatternSub(\n (tt.sub, \"a\", \"b\"),\n (tt.add, \"a\", (tt.neg, \"b\")),\n allow_multiple_clients=True,\n name=\"sub_to_add\",\n tracks=[tt.sub],\n get_nodes=get_clients,\n)\n\n# a * (x + y) -> a * x + a * y\ndistribute_mul_pattern = PatternSub(\n (tt.mul, \"a\", (tt.add, \"x\", \"y\")),\n (tt.add, (tt.mul, \"a\", \"x\"), (tt.mul, \"a\", \"y\")),\n allow_multiple_clients=True,\n name=\"distribute_mul\",\n tracks=[tt.mul],\n get_nodes=get_clients,\n)\n\nfrom theano.scalar import float64, add, mul, true_div\n \nclass RemoveNormalizingConstants(gof.GlobalOptimizer):\n def add_requirements(self, fgraph):\n fgraph.attach_feature(toolbox.ReplaceValidate())\n\n def apply(self, fgraph):\n for node in fgraph.toposort():\n #print(node)\n if node.op == tt.add:\n x, y = node.inputs\n z = node.outputs[0]\n # Find if value occurs in either branch\n #import pdb; pdb.set_trace()\n if x.name == \"value\" or y.name == \"value\":\n print(\"value found in subgraph\")\n # Mark subgraph as not to be deleted\n \n\nexpand_opt = EquilibriumOptimizer(\n [div_to_mul_pattern, distribute_mul_pattern, sub_to_add_pattern, RemoveNormalizingConstants()],\n ignore_newtrees=False,\n tracks_on_change_inputs=True,\n max_use_ratio=config.optdb__max_use_ratio,\n)\n\n\ndef optimize_graph(fgraph, include=[\"canonicalize\"], custom_opt=None, **kwargs):\n if not isinstance(fgraph, FunctionGraph):\n inputs = tt_inputs([fgraph])\n fgraph = FunctionGraph(inputs, [fgraph], clone=False)\n\n canonicalize_opt = optdb.query(Query(include=include, **kwargs))\n _ = canonicalize_opt.optimize(fgraph)\n\n if custom_opt:\n custom_opt.optimize(fgraph)\n\n return fgraph\n\n\ntau = tt.dscalar(\"tau\")\nvalue = tt.dscalar(\"value\")\nmu = tt.dscalar(\"mu\")\n\nlogp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0\n\nlogp_fg = optimize_graph(logp, custom_opt=expand_opt)\n\ntt_dprint(logp_fg)\n\n# TODO: Remove additive terms that do not contain the desired terms (e.g. `mu`\n# and `tau` when is only a function of `mu`, `tau`)\n\n# This is what we want from the optimization\n# logp_goal = -tau * (value - mu) ** 2",
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": "value found in subgraph\nElemwise{add,no_inplace} [id A] '' 8\n |Elemwise{mul,no_inplace} [id B] '' 7\n | |TensorConstant{0.5} [id C]\n | |Elemwise{mul,no_inplace} [id D] '' 6\n | |TensorConstant{-1.0} [id E]\n | |tau [id F]\n | |Elemwise{pow,no_inplace} [id G] '' 5\n | |Elemwise{add,no_inplace} [id H] '' 4\n | | |value [id I]\n | | |Elemwise{neg,no_inplace} [id J] '' 3\n | | |mu [id K]\n | |TensorConstant{2} [id L]\n |Elemwise{mul,no_inplace} [id M] '' 2\n |TensorConstant{0.5} [id C]\n |Elemwise{log,no_inplace} [id N] '' 1\n |Elemwise{mul,no_inplace} [id O] '' 0\n |TensorConstant{0.15915494309189535} [id P]\n |tau [id F]\n",
"name": "stdout"
}
]
}
],
"metadata": {
"kernelspec": {
"name": "pymc3theano",
"display_name": "pymc3theano",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.8.5",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "e758db2c3d2df5f3368fc49e6087e58f",
"data": {
"description": "Skeleton to write graph optimizer",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/e758db2c3d2df5f3368fc49e6087e58f"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment