Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Last active August 15, 2018 00:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/e3247ff6b8de52f5adf31773dc3a41e2 to your computer and use it in GitHub Desktop.
Save jamesr66a/e3247ff6b8de52f5adf31773dc3a41e2 to your computer and use it in GitHub Desktop.
commit 60a1d2169b3e1efff3a08b5b534a4967863a4dfa
Author: James Reed <jamesreed@fb.com>
Date: Tue Aug 14 18:32:08 2018 -0400
Peephole pass to erase aten::index before ONNX export
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index f0a9ee4eb..93eee7e7a 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -702,6 +702,15 @@ class TestCaffe2Backend(unittest.TestCase):
x = torch.randn(*shape)
self.run_model_test(MyModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
+    def test_index_desugar(self):
+        # Exports x[y, y] so the aten::index rewrite runs during ONNX export.
+        class TestIndexing(torch.nn.Module):
+            def forward(self, x, y):
+                return x[y, y]
+        x = torch.rand(3, 4, 5)
+        y = torch.arange(2)
+        self.run_model_test(TestIndexing(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False)
+
def test_repeat(self):
class MyModel(torch.nn.Module):
def __init__(self):
diff --git a/test/test_jit.py b/test/test_jit.py
index 74611b738..f7d19e00a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5302,6 +5302,23 @@ def func(t):
data = reader.get_record_with_key(offset)
assert(data == buffers[i])
+    def test_remove_index_lists(self):
+        # The traced output must be the same before and after the graph passes.
+        def index_twice(x, y):
+            return x[y, y]
+
+        x = torch.rand(3, 4, 5)
+        y = torch.arange(2)
+        traced = torch.jit.trace(x, y)(index_twice)
+        expected = traced(x, y)
+
+        for run_pass in (torch._C._jit_pass_eliminate_index_lists,
+                         torch._C._jit_pass_dce,
+                         torch._C._jit_pass_lint):
+            run_pass(traced.graph)
+        actual = traced(x, y)
+        np.testing.assert_allclose(expected, actual)
+
class TestEndToEndHybridFrontendModels(JitTestCase):
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 98f77213c..138c1625b 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -158,6 +158,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/passes/loop_unrolling.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/lower_grad_of.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/lower_tuples.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/eliminate_index_lists.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/peephole.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/remove_expands.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/shape_analysis.cpp
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 1b715c4b3..f483de2a2 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -11,6 +11,7 @@
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/erase_number_types.h"
+#include "torch/csrc/jit/passes/eliminate_index_lists.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/canonicalize.h"
@@ -82,6 +83,7 @@ void initJITBindings(PyObject *module) {
.def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) {
return ConstantPropagation(g);
})
+ .def("_jit_pass_eliminate_index_lists", eraseIndexWithLists)
.def("_jit_run_cpp_tests", [] {
// We have to release the GIL inside this method, because if we happen to
// initialize the autograd engine in these tests, the newly spawned worker threads will
diff --git a/torch/csrc/jit/passes/eliminate_index_lists.cpp b/torch/csrc/jit/passes/eliminate_index_lists.cpp
new file mode 100644
index 000000000..1e1bc93ec
--- /dev/null
+++ b/torch/csrc/jit/passes/eliminate_index_lists.cpp
@@ -0,0 +1,41 @@
+#include "torch/csrc/jit/passes/eliminate_index_lists.h"
+
+namespace torch { namespace jit {
+
+// Rewrites each prim::ListConstruct whose only use is an aten::index node into
+// a chain of aten::index_select calls, one per list entry (entry i selects
+// along dim i). Nested blocks are handled first; DCE removes the old nodes.
+// NOTE(review): aten::index appears to broadcast its index tensors together,
+// while chained index_select picks each dim independently; confirm x[y, y].
+void eraseIndexWithLists(Block* block) {
+  auto g = block->owningGraph();
+  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;) {
+    Node* n = *it;
+    ++it; // step past n before the rewrite below touches the node list
+    for (auto b : n->blocks()) {
+      eraseIndexWithLists(b);
+    }
+    if (n->kind() != prim::ListConstruct || n->output()->uses().size() != 1 ||
+        n->output()->uses()[0].user->kind() != aten::index) {
+      continue;
+    }
+    Node* index_node = n->output()->uses()[0].user;
+    // Insert at the index node: `self` and all list entries are defined there.
+    WithInsertPoint guard(index_node);
+    Value* self = index_node->inputs()[0]; // carried across iterations
+    for (size_t dim = 0; dim < n->inputs().size(); ++dim) {
+      Value* dim_const = g->insertConstant(static_cast<int64_t>(dim));
+      // Pass the braced list directly; a named at::ArrayRef would dangle.
+      self = g->insert(aten::index_select,
+        {NamedValue(self), NamedValue(dim_const), NamedValue(n->inputs()[dim])});
+    }
+    index_node->output()->replaceAllUsesWith(self);
+  }
+}
+
+// Applies the rewrite starting from the top-level block of `graph`.
+void eraseIndexWithLists(Graph* graph) {
+  eraseIndexWithLists(graph->block());
+}
+
+}} // namespace torch::jit
diff --git a/torch/csrc/jit/passes/eliminate_index_lists.h b/torch/csrc/jit/passes/eliminate_index_lists.h
new file mode 100644
index 000000000..8fef2825e
--- /dev/null
+++ b/torch/csrc/jit/passes/eliminate_index_lists.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "torch/csrc/jit/ir.h"
+
+namespace torch { namespace jit {
+
+// Replaces list-based aten::index nodes in `graph` with index_select chains.
+TORCH_API void eraseIndexWithLists(Graph* graph);
+
+}} // namespace torch::jit
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index b770b900c..62966702d 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -120,6 +120,10 @@ def _optimize_graph(graph, operator_export_type):
torch._C._jit_pass_peephole(graph)
torch._C._jit_pass_lint(graph)
+ torch._C._jit_pass_eliminate_index_lists(graph)
+ torch._C._jit_pass_dce(graph)
+ torch._C._jit_pass_lint(graph)
+
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
torch._C._jit_pass_peephole(graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment