@ricardoV94
Last active May 26, 2023 15:07
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTensor graph rewrites from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manipulating nodes directly"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This section walks through the low-level details of PyTensor graph manipulation. \n",
"Users are not supposed to work with, or even be aware of, these details, but they may be helpful for developers.\n",
"We start with very **bad practices** and move on towards the **right** way of doing rewrites.\n",
"\n",
"* [Graph structures](https://pytensor.readthedocs.io/en/latest/extending/graphstructures.html)\n",
"is a required precursor to this guide\n",
"* [Graph rewriting](https://pytensor.readthedocs.io/en/latest/extending/graph_rewriting.html) provides the user-level summary of what is covered in here. Feel free to revisit once you're done here.\n",
"\n",
"As described in [Graph structures](https://pytensor.readthedocs.io/en/latest/extending/graphstructures.html), PyTensor graphs are composed of sequences of `Apply` nodes, which link `Variable`s\n",
"that form the inputs and outputs of a computational `Op`eration.\n",
"\n",
"The list of inputs of an `Apply` node can be changed in place to modify the computational path that leads to it.\n",
"Consider the following simple example:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%env PYTENSOR_FLAGS=cxx=\"\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mul\n",
" ├─ Log\n",
" │ └─ Add\n",
" │ ├─ TensorConstant{1}\n",
" │ └─ x\n",
" └─ TensorConstant{2}\n"
]
}
],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt\n",
"\n",
"x = pt.scalar(\"x\")\n",
"y = pt.log(1 + x)\n",
"out = y * 2\n",
"pytensor.dprint(out, id_type=\"\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A standard rewrite replaces `pt.log(1 + x)` with the more stable form `pt.log1p(x)`.\n",
"We can do this by changing the inputs of the `out` node in place."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mul\n",
" ├─ Log1p\n",
" │ └─ x\n",
" └─ TensorConstant{2}\n"
]
}
],
"source": [
"out.owner.inputs[0] = pt.log1p(x)\n",
"pytensor.dprint(out, id_type=\"\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are two problems with this direct approach:\n",
"1. We are modifying variables in place\n",
"2. We have to know which nodes have as input the variable we want to replace\n",
"\n",
"Point 1. is important because some rewrites are \"destructive\" and the user may want to reuse the same graph in multiple functions.\n",
"\n",
"Point 2. is important because it forces us to shift the focus of attention from the operation we want to rewrite to the variables where the operation is used. It also risks unnecessary duplication of variables, if we perform the same replacement independently for each use. This could make graph rewriting considerably slower!\n",
"\n",
"PyTensor makes use of `FunctionGraph`s to solve these two issues.\n",
"By default, a `FunctionGraph` will clone all the variables between the inputs and outputs,\n",
"so that the corresponding graph can be rewritten without touching the original variables.\n",
"In addition, it will create a `clients` dictionary that maps all the variables to the nodes where they are used.\n",
"\n",
"\n",
"Let's see how we can use a FunctionGraph to achieve the same rewrite:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before:\n",
"\n",
"Mul [id A]\n",
" ├─ Log [id B]\n",
" │ └─ Add [id C]\n",
" │ ├─ TensorConstant{1} [id D]\n",
" │ └─ x [id E]\n",
" └─ TensorConstant{2} [id F]\n",
"True_div [id G]\n",
" ├─ TensorConstant{2} [id H]\n",
" └─ Log [id B]\n",
" └─ ···\n",
"\n",
"After:\n",
"\n",
"Mul [id A]\n",
" ├─ Log1p [id B]\n",
" │ └─ x [id C]\n",
" └─ TensorConstant{2} [id D]\n",
"True_div [id E]\n",
" ├─ TensorConstant{2} [id F]\n",
" └─ Log1p [id B]\n",
" └─ ···\n"
]
}
],
"source": [
"from pytensor.graph import FunctionGraph\n",
"\n",
"x = pt.scalar(\"x\")\n",
"y = pt.log(1 + x)\n",
"out1 = y * 2\n",
"out2 = 2 / y\n",
"\n",
"# Create an empty dictionary which FunctionGraph will populate\n",
"# with the mappings from old variables to cloned ones\n",
"memo = {}\n",
"fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n",
"fg_x = memo[x]\n",
"fg_y = memo[y]\n",
"print(\"Before:\\n\")\n",
"pytensor.dprint(fg.outputs)\n",
"\n",
"# Create expression of interest with cloned variables\n",
"fg_y_repl = pt.log1p(fg_x)\n",
"\n",
"# Update all uses of old variable to new one\n",
"# Each entry in the clients dictionary, \n",
"# contains a node and the input index where the variable is used\n",
"# Note: Some variables could be used multiple times in a single node\n",
"for client, idx in fg.clients[fg_y]:\n",
" client.inputs[idx] = fg_y_repl\n",
" \n",
"print(\"\\nAfter:\\n\")\n",
"pytensor.dprint(fg.outputs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that both uses of `log(1 + x)` were replaced by the new `log1p(x)`.\n",
"\n",
"It would probably be a good idea to update the clients dictionary\n",
"if we wanted to perform another rewrite.\n",
"\n",
"There are a couple of other variables in the FunctionGraph that we would also want to update,\n",
"but there is no point in doing all this bookkeeping manually. \n",
"FunctionGraph offers a `replace` method that takes care of all this for the user."
]
},
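{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of that bookkeeping, the sketch below inspects the `clients` dictionary before and after a call to `replace`.\n",
"It rebuilds the same graph so that it can be run on its own. It uses `dict.get` on the assumption that, depending on the PyTensor version, a variable with no remaining clients may be dropped from the dictionary.\n",
"\n",
"```python\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph import FunctionGraph\n",
"\n",
"x = pt.scalar(\"x\")\n",
"y = pt.log(1 + x)\n",
"out1 = y * 2\n",
"out2 = 2 / y\n",
"\n",
"memo = {}\n",
"fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n",
"fg_x, fg_y = memo[x], memo[y]\n",
"\n",
"# Each client is a (node, input_index) pair\n",
"print(\"Clients of y before:\", fg.clients.get(fg_y, []))\n",
"\n",
"fg.replace(fg_y, pt.log1p(fg_x))\n",
"\n",
"# Assumption: the old variable has no clients left after `replace`,\n",
"# and may even be missing from the dictionary, hence `.get`\n",
"print(\"Clients of y after:\", fg.clients.get(fg_y, []))\n",
"print(\"Clients of x after:\", fg.clients.get(fg_x, []))\n",
"```\n",
"\n",
"If the bookkeeping works as described, the second lookup on `fg_y` should come back empty, while `fg_x` should list the new `Log1p` node among its clients."
]
},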
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before:\n",
"\n",
"Mul [id A]\n",
" ├─ Log [id B]\n",
" │ └─ Add [id C]\n",
" │ ├─ TensorConstant{1} [id D]\n",
" │ └─ x [id E]\n",
" └─ TensorConstant{2} [id F]\n",
"True_div [id G]\n",
" ├─ TensorConstant{2} [id H]\n",
" └─ Log [id B]\n",
" └─ ···\n",
"\n",
"After:\n",
"\n",
"Mul [id A]\n",
" ├─ Log1p [id B]\n",
" │ └─ x [id C]\n",
" └─ TensorConstant{2} [id D]\n",
"True_div [id E]\n",
" ├─ TensorConstant{2} [id F]\n",
" └─ Log1p [id B]\n",
" └─ ···\n"
]
}
],
"source": [
"# We didn't modify the variables in place so we can just reuse them!\n",
"memo = {}\n",
"fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)\n",
"fg_x = memo[x]\n",
"fg_y = memo[y]\n",
"print(\"Before:\\n\")\n",
"pytensor.dprint(fg.outputs)\n",
"\n",
"# Create expression of interest with cloned variables\n",
"fg_y_repl = pt.log1p(fg_x)\n",
"fg.replace(fg_y, fg_y_repl)\n",
" \n",
"print(\"\\nAfter:\\n\")\n",
"pytensor.dprint(fg.outputs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There is still one big limitation with this approach.\n",
"We have to know in advance \"where\" the variable we want to replace is present.\n",
"It also doesn't scale to multiple instances of the same pattern.\n",
"\n",
"A more sensible approach would be to iterate over the nodes in the FunctionGraph\n",
"and apply the rewrite wherever `log(1 + x)` may be present.\n",
"\n",
"To keep code organized we will create a function \n",
"that takes as input a node and returns a valid replacement."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.graph import Constant\n",
"\n",
"def local_log1p(node):\n",
"    # Check that this node is a Log op\n",
"    if node.op != pt.log:\n",
"        return None\n",
"\n",
"    # Check that the input is the output of another node (it could be an input variable)\n",
"    add_node = node.inputs[0].owner\n",
"    if add_node is None:\n",
"        return None\n",
"\n",
"    # Check that the input to this node is an Add op\n",
"    # with 2 inputs (Add can have more inputs)\n",
"    if add_node.op != pt.add or len(add_node.inputs) != 2:\n",
"        return None\n",
"\n",
"    # Check whether we have add(1, y) or add(x, 1)\n",
"    [x, y] = add_node.inputs\n",
"    if isinstance(x, Constant) and x.data == 1:\n",
"        return [pt.log1p(y)]\n",
"    if isinstance(y, Constant) and y.data == 1:\n",
"        return [pt.log1p(x)]\n",
"\n",
"    return None"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mul [id A] 1\n",
" ├─ Log1p [id B] 0\n",
" │ └─ x [id C]\n",
" └─ TensorConstant{2} [id D]\n",
"True_div [id E] 2\n",
" ├─ TensorConstant{2} [id F]\n",
" └─ Log1p [id B] 0\n",
" └─ ···\n"
]
}
],
"source": [
"# We no longer need the memo, because our rewrite works with the node information\n",
"fg = FunctionGraph([x], [out1, out2], clone=True)\n",
"\n",
"# Toposort gives a list of all nodes in a graph in topological order\n",
"# The strategy of iteration can be important when we are dealing with multiple rewrites\n",
"for node in fg.toposort():\n",
"    repl = local_log1p(node)\n",
"    if repl is None:\n",
"        continue\n",
"    # We should get one replacement for each output of the node\n",
"    assert len(repl) == len(node.outputs)\n",
"    # We could use `fg.replace_all` to avoid this loop\n",
"    for old, new in zip(node.outputs, repl):\n",
"        fg.replace(old, new)\n",
"\n",
"pytensor.dprint(fg);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is starting to look much more scalable!\n",
"\n",
"We are still reinventing many wheels that already exist in PyTensor, but we're getting there.\n",
"Before we move up the ladder of abstraction, let's discuss two gotchas:\n",
"\n",
"1. The replacement variables should have types that are compatible with the original ones.\n",
"2. We have to be careful about introducing circular dependencies.\n",
"\n",
"For 1. let's look at a simple graph simplification, \n",
"where we replace a costly operation that is ultimately multiplied by zero."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable Alloc.0) into Type Vector(float64, shape=(?,)). You can try to manually convert Alloc.0 into a Vector(float64, shape=(?,)).\n"
]
}
],
"source": [
"x = pt.vector(\"x\", dtype=\"float32\")\n",
"zero = pt.zeros(())\n",
"zero.name = \"zero\"\n",
"y = pt.exp(x) * zero\n",
"\n",
"fg = FunctionGraph([x], [y], clone=False)\n",
"try:\n",
" fg.replace(y, pt.zeros(()))\n",
"except TypeError as exc:\n",
" print(f\"TypeError: {exc}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first achievement of a new PyTensor developer is unlocked by stumbling upon an error like that!\n",
"\n",
"It's important to keep in mind the Tensor part of PyTensor.\n",
"\n",
"The problem here is that we are trying to \n",
"replace the `y` variable, which is a float64 vector \n",
"(the float32 `exp(x)` is upcast when multiplied by the float64 `zero`), \n",
"by a float64 scalar: the number of dimensions does not match!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mul <Vector(float64, shape=(?,))>\n",
" ├─ Exp <Vector(float32, shape=(?,))>\n",
" │ └─ x <Vector(float32, shape=(?,))>\n",
" └─ ExpandDims{axis=0} <Vector(float64, shape=(1,))>\n",
" └─ Alloc <Scalar(float64, shape=())> 'zero'\n",
" └─ TensorConstant{0.0} <Scalar(float64, shape=())>\n"
]
}
],
"source": [
"pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Alloc <Vector(float64, shape=(?,))> 'vector_zero'\n",
" ├─ TensorConstant{0.0} <Scalar(float64, shape=())>\n",
" └─ Subtensor{i} <Scalar(int64, shape=())>\n",
" ├─ Shape <Vector(int64, shape=(1,))>\n",
" │ └─ x <Vector(float32, shape=(?,))>\n",
" └─ ScalarConstant{0} <int64>\n"
]
}
],
"source": [
"vector_zero = pt.zeros(x.shape)\n",
"vector_zero.name = \"vector_zero\"\n",
"fg.replace(y, vector_zero)\n",
"pytensor.dprint(fg.outputs, id_type=\"\", print_type=True);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now to the second (less common) gotcha. Introducing circular dependencies:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A] 'z'\n",
" ├─ Add [id B] 'y'\n",
" │ ├─ Add [id A] 'z'\n",
" │ │ └─ ···\n",
" │ └─ TensorConstant{1} [id C]\n",
" └─ TensorConstant{1} [id D]\n"
]
}
],
"source": [
"x = pt.scalar(\"x\")\n",
"y = x + 1\n",
"y.name = \"y\"\n",
"z = y + 1\n",
"z.name = \"z\"\n",
"\n",
"fg = FunctionGraph([x], [z], clone=False)\n",
"fg.replace(x, z)\n",
"pytensor.dprint(fg.outputs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oops! There is not much to say about this one, other than don't do it!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using graph rewriters"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.graph.rewriting.basic import NodeRewriter\n",
"\n",
"class LocalLog1pNodeRewriter(NodeRewriter):\n",
"\n",
"    def tracks(self):\n",
"        return [pt.log]\n",
"\n",
"    def transform(self, fgraph, node):\n",
"        return local_log1p(node)\n",
"\n",
"    def __str__(self):\n",
"        return \"local_log1p\"\n",
"\n",
"\n",
"local_log1p_node_rewriter = LocalLog1pNodeRewriter()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `NodeRewriter` is required to implement only the `transform` method.\n",
"As before, this method expects a node and should return a valid replacement for each output or `None`.\n",
"\n",
"We also receive the `FunctionGraph` object, \n",
"as some node rewriters may want to use global information to decide whether to return a replacement or not.\n",
"\n",
"For example some rewrites that skip intermediate computations \n",
"may not be useful if those intermediate computations are used by other variables.\n",
"\n",
"The optional `tracks` method is very useful for filtering out \"useless\" rewrites.\n",
"When a `NodeRewriter` only applies to a specific rare `Op`, \n",
"it can be ignored completely when that `Op` is not present in the graph.\n",
"\n",
"On its own, a `NodeRewriter` isn't any better than what we had before. \n",
"Where it becomes useful is when included inside a `GraphRewriter`,\n",
"which will apply it to a whole `FunctionGraph`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exp [id A]\n",
" └─ Log1p [id B]\n",
" └─ x [id C]\n"
]
}
],
"source": [
"from pytensor.graph.rewriting.basic import in2out\n",
"\n",
"x = pt.scalar(\"x\")\n",
"y = pt.log(1 + x)\n",
"out = pt.exp(y)\n",
"\n",
"fg = FunctionGraph([x], [out])\n",
"in2out(local_log1p_node_rewriter, name=\"local_log1p\").rewrite(fg)\n",
"\n",
"pytensor.dprint(fg.outputs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we used `in2out` which creates a `GraphRewriter` \n",
"(specifically a `WalkingGraphRewriter`) \n",
"which walks from the inputs to the outputs of a FunctionGraph\n",
"trying to apply whatever node rewriters are \"registered\" in it.\n",
"\n",
"Wrapping simple functions in `NodeRewriter`s is so common that PyTensor \n",
"offers a decorator for it.\n",
"\n",
"Let's create a new rewrite that removes useless `abs(exp(x)) -> exp(x)`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.graph.rewriting.basic import node_rewriter\n",
"\n",
"@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n",
"    # Because of the tracks we don't need to check\n",
"    # that `node` has an `Abs` Op.\n",
"    # We still need to check whether its input is the output of an `Exp` Op\n",
"    exp_node = node.inputs[0].owner\n",
"    if exp_node is None or exp_node.op != pt.exp:\n",
"        return None\n",
"    return exp_node.outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another very useful helper is the `PatternNodeRewriter`,\n",
"which allows you to specify a rewrite via \"template matching\"."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.graph.rewriting.basic import PatternNodeRewriter\n",
"\n",
"local_useless_abs_square = PatternNodeRewriter(\n",
" (pt.abs, (pt.pow, \"x\", 2)),\n",
" (pt.pow, \"x\", 2),\n",
" name=\"local_useless_abs_square\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is very useful for simple Elemwise rewrites, but becomes a bit cumbersome with Ops that must be parametrized\n",
"every time they are used."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pow [id A]\n",
" ├─ Log1p [id B]\n",
" │ └─ Exp [id C]\n",
" │ └─ x [id D]\n",
" └─ TensorConstant{2} [id E]\n"
]
}
],
"source": [
"x = pt.scalar(\"x\")\n",
"y = pt.exp(x)\n",
"z = pt.abs(y)\n",
"w = pt.log(1.0 + z)\n",
"out = pt.abs(w ** 2)\n",
"\n",
"fg = FunctionGraph([x], [out])\n",
"in2out_rewrite = in2out(\n",
" local_log1p_node_rewriter, \n",
" local_useless_abs_exp, \n",
" local_useless_abs_square,\n",
" name=\"custom_rewrites\"\n",
")\n",
"in2out_rewrite.rewrite(fg)\n",
"\n",
"pytensor.dprint(fg.outputs);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides `WalkingGraphRewriter`s, \n",
"there are `SequentialGraphRewriter`s which apply a set of GraphRewriters sequentially\n",
"and `EquilibriumGraphRewriter`s which apply a set of GraphRewriters (and NodeRewriters) repeatedly until the graph stops changing.\n"
]
},
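{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the last kind, the snippet below applies a single pattern rewrite with an `EquilibriumGraphRewriter` until the graph stops changing.\n",
"The `abs(abs(x)) -> abs(x)` rewrite and its name `local_abs_abs` are made up for illustration, and the `max_use_ratio=10` value is an arbitrary choice (the argument caps how often a rewriter may be applied relative to the graph size).\n",
"\n",
"```python\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph import FunctionGraph\n",
"from pytensor.graph.rewriting.basic import (\n",
"    EquilibriumGraphRewriter,\n",
"    PatternNodeRewriter,\n",
")\n",
"\n",
"# Hypothetical rewrite, for illustration only: abs(abs(x)) -> abs(x)\n",
"local_abs_abs = PatternNodeRewriter(\n",
"    (pt.abs, (pt.abs, \"x\")),\n",
"    (pt.abs, \"x\"),\n",
"    name=\"local_abs_abs\",\n",
")\n",
"\n",
"x = pt.scalar(\"x\")\n",
"out = pt.abs(pt.abs(pt.abs(pt.abs(x))))\n",
"\n",
"fg = FunctionGraph([x], [out])\n",
"\n",
"# Keep applying the rewrite until no further change happens\n",
"# (max_use_ratio=10 is an arbitrary value chosen for this sketch)\n",
"EquilibriumGraphRewriter([local_abs_abs], max_use_ratio=10).rewrite(fg)\n",
"\n",
"pytensor.dprint(fg.outputs);\n",
"```\n",
"\n",
"Because each application can expose a new match, running to equilibrium should leave a single `Abs` in the printed graph, without having to reason about the order in which nodes are visited."
]
},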
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Registering graph rewriters in a database"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, at the top of the rewrite mountain, there are `RewriteDatabase`s!\n",
"These allow \"querying\" for subsets of rewrites registered in a database.\n",
"\n",
"Most users trigger such a query when they change the `mode` of a PyTensor function:\n",
"`mode=\"FAST_COMPILE\"`, `mode=\"FAST_RUN\"`, and `mode=\"JAX\"` each lead to a different rewrite database query \n",
"being applied to the function before compilation.\n",
"\n",
"The most relevant `RewriteDatabase` is called `optdb` and contains all the standard rewrites in PyTensor.\n",
"You can manually register your `GraphRewriter` in it. \n",
"\n",
"More often than not, you will want to register your rewrite in a pre-existing sub-database, like \n",
"`canonicalize`, `stabilize`, or `specialize`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.compile.mode import optdb"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"optdb[\"canonicalize\"].register(\n",
" \"local_log1p_node_rewriter\",\n",
" local_log1p_node_rewriter,\n",
" \"fast_compile\",\n",
" \"fast_run\",\n",
" \"custom\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n",
"\n",
"Abs [id A] 4\n",
" └─ Pow [id B] 3\n",
" ├─ Log1p [id C] 2\n",
" │ └─ Abs [id D] 1\n",
" │ └─ Exp [id E] 0\n",
" │ └─ x [id F]\n",
" └─ TensorConstant{2} [id G]\n"
]
}
],
"source": [
"with pytensor.config.change_flags(optimizer_verbose = True):\n",
" fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n",
" \n",
"print(\"\")\n",
"pytensor.dprint(fn);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There's also a decorator that automatically registers a `NodeRewriter` in one of these standard databases.\n",
"(It's placed in a weird location)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.tensor.rewriting.basic import register_canonicalize\n",
"\n",
"@register_canonicalize(\"custom\")\n",
"@node_rewriter(tracks=[pt.abs])\n",
"def local_useless_abs_exp(fgraph, node):\n",
"    # Because of the tracks we don't need to check\n",
"    # that `node` has an `Abs` Op.\n",
"    # We still need to check whether its input is the output of an `Exp` Op\n",
"    exp_node = node.inputs[0].owner\n",
"    if exp_node is None or exp_node.op != pt.exp:\n",
"        return None\n",
"    return exp_node.outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also call the decorator directly, as a regular function, on an existing rewriter:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"local_useless_abs_square"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"register_canonicalize(local_useless_abs_square, \"custom\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewriting: rewrite local_useless_abs_square replaces Abs.0 of Abs(Pow.0) with Pow.0 of Pow(Log.0, TensorConstant{2})\n",
"rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)\n",
"rewriting: rewrite local_useless_abs_exp replaces Abs.0 of Abs(Exp.0) with Exp.0 of Exp(x)\n",
"\n",
"Pow [id A] 2\n",
" ├─ Log1p [id B] 1\n",
" │ └─ Exp [id C] 0\n",
" │ └─ x [id D]\n",
" └─ TensorConstant{2} [id E]\n"
]
}
],
"source": [
"with pytensor.config.change_flags(optimizer_verbose = True):\n",
" fn = pytensor.function([x], out, mode=\"FAST_COMPILE\")\n",
" \n",
"print(\"\")\n",
"pytensor.dprint(fn);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And if you wanted to exclude your custom rewrites you can do it like this:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewriting: rewrite local_upcast_elemwise_constant_inputs replaces Add.0 of Add(TensorConstant{1.0}, Abs.0) with Add.0 of Add(DimShuffle{order=[]}.0, Abs.0)\n",
"rewriting: rewrite local_dimshuffle_lift replaces DimShuffle{order=[]}.0 of DimShuffle{order=[]}(Cast{float64}.0) with Cast{float64}.0 of Cast{float64}(TensorConstant{1.0})\n",
"rewriting: rewrite constant_folding replaces Cast{float64}.0 of Cast{float64}(TensorConstant{1.0}) with TensorConstant{1.0} of None\n",
"\n",
"Abs [id A] 5\n",
" └─ Pow [id B] 4\n",
" ├─ Log [id C] 3\n",
" │ └─ Add [id D] 2\n",
" │ ├─ TensorConstant{1.0} [id E]\n",
" │ └─ Abs [id F] 1\n",
" │ └─ Exp [id G] 0\n",
" │ └─ x [id H]\n",
" └─ TensorConstant{2} [id I]\n"
]
}
],
"source": [
"from pytensor.compile.mode import get_mode\n",
"\n",
"with pytensor.config.change_flags(optimizer_verbose = True):\n",
" fn = pytensor.function([x], out, mode=get_mode(\"FAST_COMPILE\").excluding(\"custom\"))\n",
" \n",
"print(\"\")\n",
"pytensor.dprint(fn);"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pytensor",
"language": "python",
"name": "pytensor"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 4
}