
@jamesr66a
Created February 3, 2021 21:50
commit ade784875322bf8fbb620b02739337a664191389
Author: James Reed <jamesreed@fb.com>
Date: Wed Feb 3 13:49:58 2021 -0800
writing transformations
diff --git a/docs/source/fx.rst b/docs/source/fx.rst
index 1d638cc15f..aebdc507a8 100644
--- a/docs/source/fx.rst
+++ b/docs/source/fx.rst
@@ -12,7 +12,172 @@ Overview
Writing Transformations
-----------------------
-TODO
+What is an FX transform? Essentially, it's a function that looks like this:
+
+::
+
+    import torch
+    import torch.fx as fx
+
+    def transform(m: torch.nn.Module) -> torch.nn.Module:
+        fx_model: fx.GraphModule = fx.symbolic_trace(m)
+        new_model = ...
+        return new_model
+
+Your transform will take in a :class:`torch.nn.Module`, convert it into a
+:class:`GraphModule` with :func:`symbolic_trace`, and return a new
+``nn.Module``. You should treat the ``nn.Module`` that your FX transform
+returns like any other ``nn.Module``: you can pass it to another
+FX transform, you can pass it to TorchScript, or you can
+run it. Ensuring that both the input and the output of your FX transform are
+an ``nn.Module`` is what makes transforms composable.
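
As a minimal sketch of that composability (assuming a recent PyTorch build), the two hypothetical transforms ``identity_transform`` and ``double_output`` below each take and return an ``nn.Module``, so they can be chained, run eagerly, and handed to ``torch.jit.script``. The module ``Net`` is likewise only for illustration.

```python
import torch
import torch.fx as fx


class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


def identity_transform(m: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical no-op transform: trace and hand back the GraphModule.
    return fx.symbolic_trace(m)


def double_output(m: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical transform: multiply whatever the module returns by 2.
    gm = fx.symbolic_trace(m)
    output_node = next(n for n in gm.graph.nodes if n.op == 'output')
    result = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        doubled = gm.graph.call_function(torch.mul, args=(result, 2))
    # Make the graph return the doubled value instead of the original one.
    output_node.args = (doubled,)
    gm.graph.lint()
    gm.recompile()
    return gm


# Each transform consumes and produces an nn.Module, so they chain freely.
composed = double_output(identity_transform(Net()))
# The result runs eagerly and can also be compiled with TorchScript.
scripted = torch.jit.script(composed)
```

Because ``double_output`` re-traces its input, it does not care whether it receives a plain ``nn.Module`` or a ``GraphModule`` produced by another transform.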
+
+Given that you’ve passed in an ``nn.Module`` that has been traced into a
+graph, there are now two primary approaches you can take to building a new
+graph.
+
+Graph Manipulation
+^^^^^^^^^^^^^^^^^^
+
+One approach to building this new graph is to directly transform your old
+one: take the graph obtained from symbolic tracing and modify it in place.
+For example, suppose we want to replace ``torch.add`` with ``torch.mul``.
+
+::
+
+    import torch
+    import torch.fx as fx
+
+    # Sample module
+    class M(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.add(x, y)
+
+    def transform(m: torch.nn.Module) -> torch.nn.Module:
+        fx_model: fx.GraphModule = fx.symbolic_trace(m)
+        # FX represents its graph as an ordered list of nodes, so we can iterate through them.
+        for node in fx_model.graph.nodes:
+            # Checks if we're calling a function (e.g. torch.add)
+            if node.op == 'call_function':
+                # The target attribute is the function that call_function calls.
+                if node.target == torch.add:
+                    node.target = torch.mul
+
+        fx_model.graph.lint()  # Does some checks to make sure the graph is well-formed.
+        fx_model.recompile()  # Regenerates the Python code that corresponds to the graph.
+        return fx_model
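
One detail worth checking in your own models: when ``forward`` uses the ``+`` operator instead of ``torch.add``, symbolic tracing records the target as Python's ``operator.add``, so a check for ``torch.add`` alone will not match it. The sketch below is an assumed variant (``add_to_mul`` and ``ADD_TO_MUL`` are hypothetical names) that rewrites both spellings.

```python
import operator
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x, y):
        # One add via torch.add, one via `+` (traced as operator.add).
        return torch.add(x, y) + y


# Maps each add spelling we expect in the graph to its multiply counterpart.
ADD_TO_MUL = {torch.add: torch.mul, operator.add: operator.mul}


def add_to_mul(m: torch.nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(m)
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target in ADD_TO_MUL:
            node.target = ADD_TO_MUL[node.target]
    gm.graph.lint()
    gm.recompile()
    return gm


gm = add_to_mul(M())
x, y = torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])
```

After the rewrite, ``gm(x, y)`` computes ``(x * y) * y`` rather than ``(x + y) + y``.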
+
+We can also do more involved graph rewrites, such as deleting nodes or
+inserting new nodes after an existing node. To aid in these transformations,
+FX provides utility methods for transforming the graph on :class:`Graph`. The
+example below uses these APIs to insert a ``relu`` after ``node``.
+
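+::
+
+    with fx_model.graph.inserting_after(node):  # Specifies the insertion point
+        new_node = fx_model.graph.call_function(torch.relu, args=(node,))  # Builds a new relu node
+    # Redirect every user of `node` to `new_node`. This also rewires new_node's
+    # own input to itself, so point it back at `node` afterwards.
+    node.replace_all_uses_with(new_node)
+    new_node.args = (node,)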
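
Put together, a minimal runnable sketch of that snippet might look like the following. The module ``AddThenScale``, the helper name ``relu_after_add``, and the choice to insert a ``relu`` after every ``torch.add`` are illustrative assumptions.

```python
import torch
import torch.fx as fx


class AddThenScale(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y) * 3


def relu_after_add(m: torch.nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(m)
    # Snapshot the node list, since we insert new nodes while iterating.
    for node in list(gm.graph.nodes):
        if node.op == 'call_function' and node.target == torch.add:
            with gm.graph.inserting_after(node):
                relu = gm.graph.call_function(torch.relu, args=(node,))
            # Point every user of `node` at the relu; this also rewrites the
            # relu's own input to itself, so restore it afterwards.
            node.replace_all_uses_with(relu)
            relu.args = (node,)
    gm.graph.lint()  # Would fail if the relu were left consuming itself.
    gm.recompile()
    return gm


gm = relu_after_add(AddThenScale())
x, y = torch.tensor([-3.0, 1.0]), torch.tensor([1.0, 1.0])
```

Here ``gm(x, y)`` computes ``relu(x + y) * 3``, so the negative sum is clamped to zero before scaling.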
+
+This approach is also a good fit for graph optimizations such as
+`conv/batch norm
+fusion <https://github.com/pytorch/pytorch/blob/ec86cec20a8a2312a2295d7bc8be6e88256a2de4/torch/fx/experimental/fuser.py>`__.
+
+For simple transformations that only consist of substitutions, you can also
+make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/master/torch/fx/subgraph_rewriter.py>`__.
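
For instance, the ``torch.add``-to-``torch.mul`` substitution can be phrased as a pattern function and a replacement function with the same signature. This sketch assumes ``replace_pattern`` from ``torch.fx.subgraph_rewriter``, which takes a ``GraphModule``, a pattern callable, and a replacement callable, rewrites matches in place, and returns the list of matches; the module ``M`` is hypothetical.

```python
import torch
import torch.fx as fx
from torch.fx import subgraph_rewriter


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


# The pattern to search for and its replacement take the same inputs.
def pattern(x, y):
    return torch.add(x, y)


def replacement(x, y):
    return torch.mul(x, y)


gm = fx.symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
gm.recompile()  # Harmless if replace_pattern already regenerated the code.
```

Only the matched ``add`` is swapped out; the surrounding ``relu`` node is left untouched.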
+
+In general, writing your transformation through graph manipulation is a good
+fit if you need to make a few small changes or if you need to match multiple
+nodes at once. However, if you need to entirely rewrite your graph, you may
+want to look at constructing your graph with Proxies (i.e. retracing).
+
+Examples
+~~~~~~~~
+
+- `Replace one op <https://github.com/pytorch/pytorch/blob/master/torch/fx/examples/replace_op.py>`__
+- `Conv/Batch Norm fusion <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py>`__
+- `Quantization <https://github.com/pytorch/pytorch/tree/master/torch/quantization/fx>`__
+
+Proxy/Retracing
+^^^^^^^^^^^^^^^
+
+Although most transformations can be implemented through graph
+manipulation, transformations that involve many graph rewrites
+are often more easily expressed through retracing. For example, let’s
+imagine that we wanted to write a pass that decomposes
+PyTorch functions into simpler operations. It would transform every ``F.relu(x)``
+into ``(x > 0)*x``. One possibility would be to perform the requisite
+graph rewriting to insert the comparison and multiplication after the
+``F.relu``, and then clean up the original ``F.relu``. However, graph
+manipulation can be awkward, and it’s often easier to implicitly
+generate the graph by retracing.
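
For comparison, here is roughly what the graph-manipulation route looks like for this relu decomposition. It is a sketch assuming that ``F.relu`` appears as a ``call_function`` node whose first argument is the input; the helper ``decompose_relu_in_place`` and module ``M`` are hypothetical. Note how much bookkeeping it needs: creating each node, positioning it, rerouting users, and erasing the old node.

```python
import operator
import torch
import torch.fx as fx
import torch.nn.functional as F


class M(torch.nn.Module):
    def forward(self, x):
        return F.relu(x) + 1.0


def decompose_relu_in_place(m: torch.nn.Module) -> fx.GraphModule:
    # Hypothetical helper: rewrite F.relu(x) into (x > 0) * x by editing nodes.
    gm = fx.symbolic_trace(m)
    graph = gm.graph
    # Copy the node list first, because we add and erase nodes while looping.
    for node in list(graph.nodes):
        if node.op == 'call_function' and node.target == F.relu:
            x = node.args[0]
            # Insert each new node after the previous one to keep them in order.
            with graph.inserting_after(node):
                gt = graph.call_function(operator.gt, args=(x, 0))
            with graph.inserting_after(gt):
                mul = graph.call_function(operator.mul, args=(gt, x))
            # Reroute consumers of the relu to the new result, then drop the relu.
            node.replace_all_uses_with(mul)
            graph.erase_node(node)
    graph.lint()
    gm.recompile()
    return gm


gm = decompose_relu_in_place(M())
```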
+
+To use this method, we write the graph that we want inserted as regular
+PyTorch code and pass in Proxy objects. These Proxy objects
+will capture the operations that are performed on them and append them to
+the graph.
+
+::
+
+    import torch
+    import torch.fx as fx
+    import torch.nn.functional as F
+
+    # Note that this decomposition rule can be read as regular Python
+    def relu_decomposition(x):
+        return (x > 0) * x
+
+    decomposition_rules = {}
+    decomposition_rules[F.relu] = relu_decomposition
+
+    def decompose(model: torch.nn.Module) -> torch.nn.Module:
+        model = fx.symbolic_trace(model)
+        new_graph = fx.Graph()
+        env = {}
+        for node in model.graph.nodes:
+            if node.op == 'call_function' and node.target in decomposition_rules:
+                # By wrapping the arguments with proxies, we can dispatch to
+                # the appropriate decomposition rule and add it to the graph by
+                # symbolically tracing it.
+                proxy_args = [fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args]
+                new_node = decomposition_rules[node.target](*proxy_args).node
+                env[node.name] = new_node
+            else:
+                new_node = new_graph.node_copy(node, lambda x: env[x.name])
+                env[node.name] = new_node
+        return fx.GraphModule(model, new_graph)
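
The mechanism ``decompose`` relies on can be seen in isolation: calling an ordinary Python function on a ``Proxy`` records its operations into the proxy's graph instead of executing them. A minimal sketch, assuming (as ``decompose`` also does) that ``fx.Proxy(node)`` without an explicit tracer appends new nodes to ``node.graph``:

```python
import torch
import torch.fx as fx


def relu_decomposition(x):
    return (x > 0) * x


graph = fx.Graph()
x = graph.placeholder('x')
# Every operation applied to the Proxy becomes a new node in `graph`.
out = relu_decomposition(fx.Proxy(x))
graph.output(out.node)
graph.lint()

# An empty nn.Module works as the root because the graph has no parameters.
gm = fx.GraphModule(torch.nn.Module(), graph)
```

The resulting graph contains just a placeholder, a comparison, a multiplication, and the output, and it computes the same values as ``torch.relu``.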
+
+In addition to avoiding explicit graph manipulation, using Proxies also allows you to
+specify your rewrite rules as native Python code. For transformations
+that require a large number of rewrite rules (such as vmap or grad),
+this can often improve readability and maintainability of the rules.
+
+TODO: Example transformations (need to be included first)
+
+The Interpreter Pattern
+^^^^^^^^^^^^^^^^^^^^^^^
+
+In addition to FX passes that take in a module and return a module,
+there may be other things you wish to do with the FX graph. For example,
+let’s say that you’d like to obtain
+the shape information of tensors in your graph. In this case, instead of
+looping over the FX graph and modifying it, you can write an interpreter
+on top of the FX graph! As the FX IR is quite simple, it’s easy to
+reimplement an interpreter that also captures your desired attributes.
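
As a rough illustration of how small such an interpreter can be, the sketch below walks the nodes in order, executes each one on real inputs, and stores each tensor result's shape in ``node.meta``. It assumes the graph contains only FX's six opcodes (``placeholder``, ``get_attr``, ``call_function``, ``call_method``, ``call_module``, ``output``); the helper names and the ``'shape'`` metadata key are hypothetical.

```python
import torch
import torch.fx as fx
from torch.fx.node import map_arg


def fetch_attr(root: torch.nn.Module, target: str):
    # Resolve dotted names such as 'linear.weight' against the root module.
    obj = root
    for atom in target.split('.'):
        obj = getattr(obj, atom)
    return obj


def record_shapes(gm: fx.GraphModule, *args):
    # Execute gm node by node, storing each tensor result's shape in node.meta.
    env = {}
    arg_iter = iter(args)
    modules = dict(gm.named_modules())

    def load(a):
        # Replace every Node inside an argument structure with its computed value.
        return map_arg(a, lambda n: env[n.name])

    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            result = next(arg_iter)
        elif node.op == 'get_attr':
            result = fetch_attr(gm, node.target)
        elif node.op == 'call_function':
            result = node.target(*load(node.args), **load(node.kwargs))
        elif node.op == 'call_method':
            self_obj, *rest = load(node.args)
            result = getattr(self_obj, node.target)(*rest, **load(node.kwargs))
        elif node.op == 'call_module':
            result = modules[node.target](*load(node.args), **load(node.kwargs))
        elif node.op == 'output':
            return load(node.args[0])
        else:
            raise RuntimeError(f'unexpected opcode {node.op}')
        if isinstance(result, torch.Tensor):
            node.meta['shape'] = tuple(result.shape)
        env[node.name] = result


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 3)

    def forward(self, x):
        return self.linear(x).relu().sum(dim=0)


gm = fx.symbolic_trace(Net())
out = record_shapes(gm, torch.randn(5, 4))
shapes = [n.meta['shape'] for n in gm.graph.nodes if 'shape' in n.meta]
```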
+
+Because this pattern is so useful, FX also provides an abstraction of it:
+the `Interpreter
+<https://github.com/pytorch/pytorch/blob/master/torch/fx/interpreter.py>`__ class.
+You can see an example of using it for `shape propagation
+<https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py>`__,
+which reinterprets the FX graph with example inputs while annotating the
+graph with the shapes.
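
A minimal sketch of the same idea on top of ``torch.fx.Interpreter``, assuming its ``run_node`` hook (which executes one node and returns its value) can be overridden to observe every intermediate result. The class name ``ShapeRecorder`` and the ``'shape'`` metadata key are hypothetical.

```python
import torch
import torch.fx as fx


class ShapeRecorder(fx.Interpreter):
    def run_node(self, n: fx.Node):
        # Let the base Interpreter dispatch on n.op and compute the value.
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            n.meta['shape'] = tuple(result.shape)
        return result


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 3)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = fx.symbolic_trace(Net())
out = ShapeRecorder(gm).run(torch.randn(2, 4))
shapes = [n.meta.get('shape') for n in gm.graph.nodes]
```

Compared with a hand-written loop, the subclass only overrides the one hook it cares about and inherits the opcode dispatch and argument loading.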
+
+Reinterpreting the FX graph is generally most useful when you want
+runtime information that FX typically doesn’t capture, because symbolic
+tracing never runs the model on real data. This can be used to capture shape
+information for downstream passes, but it can also be used to capture other
+information about execution.
+
+TODO: Add roofline analysis pass once it gets merged.
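
As one example of capturing non-shape execution information, here is a rough per-node wall-clock profiler. It is a sketch assuming ``time.perf_counter`` around ``run_node`` is precise enough for your purposes; it ignores asynchronous GPU execution, so CUDA work would need explicit synchronization. ``NodeTimer`` and ``Net`` are hypothetical names.

```python
import time
import torch
import torch.fx as fx


class NodeTimer(fx.Interpreter):
    def __init__(self, module: fx.GraphModule):
        super().__init__(module)
        self.timings = {}  # node name -> seconds spent in run_node

    def run_node(self, n: fx.Node):
        start = time.perf_counter()
        result = super().run_node(n)
        self.timings[n.name] = time.perf_counter() - start
        return result


class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.t()).sum()


gm = fx.symbolic_trace(Net())
timer = NodeTimer(gm)
out = timer.run(torch.randn(64, 64))
slowest = max(timer.timings, key=timer.timings.get)
```

``slowest`` names the node that took longest on this input, which is the kind of per-node runtime data symbolic tracing alone cannot provide.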
+
+Examples
+~~~~~~~~
+
+- `Shape Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/shape_prop.py>`__
+- `Roofline Analyzer <https://github.com/pytorch/pytorch/blob/a9f88511b8155ba9620730fb175dee8c54e346d5/torch/fx/experimental/cost_model.py>`__
Debugging