Created
February 3, 2021 21:50
commit ade784875322bf8fbb620b02739337a664191389 | |
Author: James Reed <jamesreed@fb.com> | |
Date: Wed Feb 3 13:49:58 2021 -0800 | |
writing transformations | |
diff --git a/docs/source/fx.rst b/docs/source/fx.rst | |
index 1d638cc15f..aebdc507a8 100644 | |
--- a/docs/source/fx.rst | |
+++ b/docs/source/fx.rst | |
@@ -12,7 +12,172 @@ Overview | |
Writing Transformations | |
----------------------- | |
-TODO | |
+What is an FX transform? Essentially, it's a function that looks like this: | |
+ | |
+:: | |
+ | |
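+    import torch.fx as fx | |
+    import torch.nn as nn | |
+    from torch.fx import GraphModule | |
+ | |
+    def transform(m: nn.Module) -> nn.Module: | |
+        fx_model: GraphModule = fx.symbolic_trace(m) | |
+        new_model = ... | |
+        return new_model | |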
+ | |
+Your transform takes in a :class:`torch.nn.Module`, converts it into a | |
+:class:`GraphModule` with :func:`symbolic_trace`, and returns a new | |
+``nn.Module``. You should think of the ``nn.Module`` that your FX transform | |
+returns as identical to a regular ``nn.Module`` -- you can pass it to another | |
+FX transform, you can pass it to TorchScript, or you can | |
+run it. Because both the input and the output of your FX transform are | |
+``nn.Module`` instances, transforms compose with one another. | |
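As a minimal sketch of that composability, the snippet below defines a do-nothing transform and feeds its result back into itself and into TorchScript. The names ``identity_transform`` and ``AddOne`` are hypothetical and exist only for illustration.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def identity_transform(m: nn.Module) -> nn.Module:
    """Hypothetical transform: trace the module and return it unchanged."""
    fx_model: fx.GraphModule = fx.symbolic_trace(m)
    # A real transform would edit fx_model.graph at this point.
    fx_model.graph.lint()   # sanity-check the graph
    fx_model.recompile()    # regenerate forward() from the graph
    return fx_model


class AddOne(nn.Module):
    def forward(self, x):
        return x + 1


# The output of one transform is a valid input to the next one.
transformed = identity_transform(identity_transform(AddOne()))

# The result is still an nn.Module, so it can be scripted or simply called.
scripted = torch.jit.script(transformed)
x = torch.ones(3)
print(transformed(x), scripted(x))
```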
+ | |
+Given that you’ve passed in an ``nn.Module`` that has been traced into a | |
+graph, there are now two primary approaches you can take to building a new | |
+graph. | |
+ | |
+Graph Manipulation | |
+^^^^^^^^^^^^^^^^^^ | |
+ | |
+One approach to building this new graph is to transform your old | |
+one: take the graph obtained from symbolic tracing and modify it | |
+directly. For example, let’s say we want to | |
+replace ``torch.add`` with ``torch.mul``. | |
+ | |
+:: | |
+ | |
+    import torch | |
+    import torch.fx as fx | |
+    import torch.nn as nn | |
+    from torch.fx import GraphModule | |
+ | |
+    # Sample module | |
+    class M(torch.nn.Module): | |
+        def forward(self, x, y): | |
+            return torch.add(x, y) | |
+ | |
+    def transform(m: nn.Module) -> nn.Module: | |
+        fx_model: GraphModule = fx.symbolic_trace(m) | |
+        # FX represents its graph as an ordered list of nodes, so we can iterate through them. | |
+        for node in fx_model.graph.nodes: | |
+            # Checks if we're calling a function (e.g. torch.add) | |
+            if node.op == 'call_function': | |
+                # The target attribute is the function that call_function calls. | |
+                if node.target == torch.add: | |
+                    node.target = torch.mul | |
+ | |
+        fx_model.graph.lint()  # Does some checks to make sure the graph is well-formed. | |
+        fx_model.recompile()  # Regenerates the Python code that corresponds to the graph. | |
+        return fx_model | |
+ | |
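To check that such a rewrite took effect, one option is to run the transformed module and compare it against the expected product. The sketch below restates the add-to-mul transform so that it runs on its own; ``mul_model`` and the input values are illustrative assumptions.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class M(nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


def transform(m: nn.Module) -> nn.Module:
    fx_model = fx.symbolic_trace(m)
    for node in fx_model.graph.nodes:
        # Swap the target of every torch.add call for torch.mul.
        if node.op == "call_function" and node.target == torch.add:
            node.target = torch.mul
    fx_model.graph.lint()
    fx_model.recompile()
    return fx_model


mul_model = transform(M())
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

print(mul_model.code)      # generated forward() now calls torch.mul
print(mul_model(x, y))     # element-wise product of x and y
```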
+We can also do more involved graph rewrites, such as deleting nodes or | |
+appending new nodes after an existing node. To aid in these transformations, | |
+FX has utility functions for transforming the graph that can be found in | |
+:class:`Graph`. An example of using these APIs to append a relu can be found below. | |
+ | |
+:: | |
+ | |
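+    with fx_model.graph.inserting_after(node):  # Specifies the insertion point | |
+        new_node = fx_model.graph.call_function(torch.relu, args=(node,))  # Builds a new relu node | |
+        node.replace_all_uses_with(new_node)  # Redirects every user of node, including new_node itself | |
+        new_node.args = (node,)  # Restores node as the input of the new relu | |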
+ | |
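A self-contained sketch of appending a relu after every ``torch.add`` might look like the following. One detail worth calling out: ``replace_all_uses_with`` also rewrites the freshly created relu node (it is a user of ``node`` too), so the sketch resets its input afterwards. The module ``M`` and the input values are illustrative assumptions.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class M(nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


gm = fx.symbolic_trace(M())

# Iterate over a snapshot because the loop body adds nodes to the graph.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == torch.add:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.relu, args=(node,))
        # Point every user of `node` at the relu instead ...
        node.replace_all_uses_with(new_node)
        # ... then undo the self-reference this created on the relu itself.
        new_node.args = (node,)

gm.graph.lint()
gm.recompile()

x = torch.tensor([-3.0, 1.0])
y = torch.tensor([1.0, 1.0])
print(gm(x, y))  # relu applied to x + y
```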
+This approach is also a good fit for graph optimizations such as | |
+`conv/batch norm | |
+fusion <https://github.com/pytorch/pytorch/blob/ec86cec20a8a2312a2295d7bc8be6e88256a2de4/torch/fx/experimental/fuser.py>`__. | |
+ | |
+For simple transformations that only consist of substitutions, you can also | |
+make use of the `subgraph rewriter <https://github.com/pytorch/pytorch/blob/master/torch/fx/subgraph_rewriter.py>`__. | |
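As a rough illustration of the subgraph rewriter, the sketch below uses ``torch.fx.subgraph_rewriter.replace_pattern`` to swap ``torch.add`` for ``torch.mul``. The ``pattern`` and ``replacement`` functions and the module ``M`` are hypothetical examples, not part of the library.

```python
import torch
import torch.fx as fx
import torch.nn as nn
from torch.fx import subgraph_rewriter


class M(nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


def pattern(x, y):
    # Subgraph to search for.
    return torch.add(x, y)


def replacement(x, y):
    # Subgraph to put in its place.
    return torch.mul(x, y)


gm = fx.symbolic_trace(M())
# Rewrites gm in place and returns the list of matches it replaced.
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)

x = torch.tensor([1.0, -2.0])
y = torch.tensor([3.0, 4.0])
print(len(matches), gm(x, y))
```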
+ | |
+In general, writing your transformation through graph manipulation is a good | |
+fit if you need to make a few small changes or if you need to match multiple | |
+nodes at once. However, if you need to entirely rewrite your graph, you may | |
+want to look at constructing your graph with Proxies (i.e. retracing). | |
+ | |
+Examples | |
+~~~~~~~~ | |
+ | |
+- `Replace one | |
+ op <https://github.com/pytorch/pytorch/blob/master/torch/fx/examples/replace_op.py>`__ | |
+- `Conv/Batch Norm | |
+ fusion <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py>`__ | |
+- `Quantization <https://github.com/pytorch/pytorch/tree/master/torch/quantization/fx>`__ | |
+ | |
+Proxy/Retracing | |
+^^^^^^^^^^^^^^^ | |
+ | |
+Although most transformations can be implemented as graph | |
+transformations, transformations that involve a lot of graph rewrites | |
+are often more easily represented through retracing. For example, let’s | |
+imagine that we wanted to write a pass that decomposed | |
+PyTorch functions. It would transform every ``F.relu(x)`` | |
+into ``(x > 0)*x``. One possibility would be to perform the requisite | |
+graph rewriting to insert the comparison and multiplication after the | |
+``F.relu``, and then clean up the original ``F.relu``. However, graph | |
+manipulation can be awkward, and it’s often easier to implicitly | |
+generate the graph by retracing. | |
+ | |
+To use this method, we write the graph that we want inserted as regular | |
+PyTorch code and pass in Proxy objects. These Proxy objects | |
+will capture the operations that are performed on them and append them to | |
+the graph. | |
+ | |
+:: | |
+ | |
+    import torch | |
+    import torch.fx as fx | |
+    import torch.nn.functional as F | |
+ | |
+    # Note that this decomposition rule can be read as regular Python | |
+    def relu_decomposition(x): | |
+        return (x > 0) * x | |
+ | |
+    decomposition_rules = {} | |
+    decomposition_rules[F.relu] = relu_decomposition | |
+ | |
+    def decompose(model: torch.nn.Module) -> torch.nn.Module: | |
+        model = fx.symbolic_trace(model) | |
+        new_graph = fx.Graph() | |
+        # Maps node names in the old graph to the corresponding nodes in the new graph. | |
+        env = {} | |
+        for node in model.graph.nodes: | |
+            if node.op == 'call_function' and node.target in decomposition_rules: | |
+                # By wrapping the arguments with proxies, we can dispatch to | |
+                # the appropriate decomposition rule and add it to the graph by | |
+                # symbolically tracing it. | |
+                proxy_args = [fx.Proxy(env[x.name]) if isinstance(x, fx.Node) else x for x in node.args] | |
+                new_node = decomposition_rules[node.target](*proxy_args).node | |
+                env[node.name] = new_node | |
+            else: | |
+                new_node = new_graph.node_copy(node, lambda x: env[x.name]) | |
+                env[node.name] = new_node | |
+        return fx.GraphModule(model, new_graph) | |
+ | |
+In addition to avoiding explicit graph manipulation, using Proxies also allows you to | |
+specify your rewrite rules as native Python code. For transformations | |
+that require a large number of rewrite rules (such as vmap or grad), | |
+this can often improve the readability and maintainability of the rules. | |
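To see the retracing approach end to end, the sketch below restates the ``decompose`` pass so that it runs on its own and applies it to a small module. The module ``ReluPlusOne`` and the sample input are illustrative assumptions, and the sketch relies on ``fx.Proxy`` appending to the graph that owns the wrapped node when no tracer is passed, which may not hold on older PyTorch releases.

```python
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


def relu_decomposition(x):
    return (x > 0) * x


decomposition_rules = {F.relu: relu_decomposition}


def decompose(model: nn.Module) -> nn.Module:
    model = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    env = {}  # old node name -> node in new_graph
    for node in model.graph.nodes:
        if node.op == "call_function" and node.target in decomposition_rules:
            # Running the rule on Proxies records its operations in new_graph.
            proxy_args = [fx.Proxy(env[a.name]) if isinstance(a, fx.Node) else a
                          for a in node.args]
            env[node.name] = decomposition_rules[node.target](*proxy_args).node
        else:
            env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
    return fx.GraphModule(model, new_graph)


class ReluPlusOne(nn.Module):  # hypothetical sample module
    def forward(self, x):
        return F.relu(x) + 1


decomposed = decompose(ReluPlusOne())
targets = [n.target for n in decomposed.graph.nodes if n.op == "call_function"]
x = torch.tensor([-1.0, 0.5, 2.0])
print(decomposed.code)
print(decomposed(x))
```

Printing ``decomposed.code`` is a quick way to confirm that the comparison and multiplication replaced the original ``F.relu`` call.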
+ | |
+TODO: Example transformations (need to be included first) | |
+ | |
+The Interpreter Pattern | |
+^^^^^^^^^^^^^^^^^^^^^^^ | |
+ | |
+In addition to FX passes that take in a module and return a module, | |
+there may be other things you wish to do with the FX graph. For example, | |
+let’s say that you’d like to obtain | |
+the shape information of tensors in your graph. In this case, instead of | |
+looping over the FX graph and modifying it, you can write an interpreter | |
+on top of the FX graph! As the FX IR is quite simple, it’s easy to | |
+reimplement an interpreter that also captures your desired attributes. | |
+ | |
+As this pattern is quite useful, FX also provides an abstraction of it | |
+-- the `Interpreter | |
+<https://github.com/pytorch/pytorch/blob/master/torch/fx/interpreter.py>`__. | |
+You can see an example using this for `shape propagation | |
+<https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py>`__, | |
+which reinterprets the FX graph with example inputs while annotating the | |
+graph with the shapes. | |
+ | |
+Reinterpreting the FX graph is generally most useful when you want | |
+runtime information that FX typically doesn’t capture (due to being a | |
+symbolic trace). This can be used for capturing shape information for | |
+downstream passes, but it can also be used to capture other information | |
+about execution. | |
+TODO: Add roofline analysis pass once it gets merged. | |
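A minimal sketch of the interpreter pattern, assuming ``torch.fx.Interpreter`` is available: the subclass below overrides ``run_node`` to record the shape of every tensor produced while the graph executes on an example input. ``ShapeRecorder`` and ``TinyNet`` are hypothetical names used only for illustration; this is not the library's own shape propagation pass.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class ShapeRecorder(fx.Interpreter):
    """Hypothetical interpreter that records each node's output shape."""

    def __init__(self, gm: fx.GraphModule):
        super().__init__(gm)
        self.shapes = {}

    def run_node(self, n: fx.Node):
        # Execute the node as usual, then inspect the concrete result.
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = tuple(result.shape)
        return result


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = fx.symbolic_trace(TinyNet())
recorder = ShapeRecorder(gm)
out = recorder.run(torch.randn(2, 4))  # example input drives the interpretation

for name, shape in recorder.shapes.items():
    print(name, shape)
```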
+ | |
+Examples | |
+~~~~~~~~ | |
+ | |
+- `Shape | |
+ Propagation <https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/shape_prop.py>`__ | |
+- `Roofline | |
+ Analyzer <https://github.com/pytorch/pytorch/blob/a9f88511b8155ba9620730fb175dee8c54e346d5/torch/fx/experimental/cost_model.py>`__ | |
Debugging |