Last active
August 15, 2018 00:52
-
-
Save jamesr66a/e3247ff6b8de52f5adf31773dc3a41e2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 60a1d2169b3e1efff3a08b5b534a4967863a4dfa
Author: James Reed <jamesreed@fb.com>
Date:   Tue Aug 14 18:32:08 2018 -0400

    Peephole pass to erase aten::index before ONNX export
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py | |
index f0a9ee4eb..93eee7e7a 100644 | |
--- a/test/onnx/test_pytorch_onnx_caffe2.py | |
+++ b/test/onnx/test_pytorch_onnx_caffe2.py | |
@@ -702,6 +702,15 @@ class TestCaffe2Backend(unittest.TestCase): | |
x = torch.randn(*shape) | |
self.run_model_test(MyModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False) | |
+    def test_index_desugar(self):
+        class TestIndexing(torch.nn.Module):
+            def forward(self, x, y):
+                return x[y]
+
+        x = torch.rand(3, 4, 5)
+        y = torch.arange(2)
+        self.run_model_test(TestIndexing(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False)
+ | |
def test_repeat(self): | |
class MyModel(torch.nn.Module): | |
def __init__(self): | |
diff --git a/test/test_jit.py b/test/test_jit.py | |
index 74611b738..f7d19e00a 100644 | |
--- a/test/test_jit.py | |
+++ b/test/test_jit.py | |
@@ -5302,6 +5302,23 @@ def func(t): | |
data = reader.get_record_with_key(offset) | |
assert(data == buffers[i]) | |
+    def test_remove_index_lists(self):
+        def test_indexing(x, y):
+            return x[y]
+
+        x = torch.rand(3, 4, 5)
+        y = torch.arange(2)
+        test_indexing = torch.jit.trace(x, y)(test_indexing)
+        reference = test_indexing(x, y)
+
+        torch._C._jit_pass_eliminate_index_lists(test_indexing.graph)
+        torch._C._jit_pass_dce(test_indexing.graph)
+        torch._C._jit_pass_lint(test_indexing.graph)
+        self.assertIn('aten::index_select', str(test_indexing.graph))
+
+        compare = test_indexing(x, y)
+        np.testing.assert_allclose(reference, compare)
+ | |
class TestEndToEndHybridFrontendModels(JitTestCase): | |
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt | |
index 98f77213c..138c1625b 100644 | |
--- a/torch/CMakeLists.txt | |
+++ b/torch/CMakeLists.txt | |
@@ -158,6 +158,7 @@ set(TORCH_SRCS | |
${TORCH_SRC_DIR}/csrc/jit/passes/loop_unrolling.cpp | |
${TORCH_SRC_DIR}/csrc/jit/passes/lower_grad_of.cpp | |
${TORCH_SRC_DIR}/csrc/jit/passes/lower_tuples.cpp | |
+ ${TORCH_SRC_DIR}/csrc/jit/passes/eliminate_index_lists.cpp | |
${TORCH_SRC_DIR}/csrc/jit/passes/peephole.cpp | |
${TORCH_SRC_DIR}/csrc/jit/passes/remove_expands.cpp | |
${TORCH_SRC_DIR}/csrc/jit/passes/shape_analysis.cpp | |
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp | |
index 1b715c4b3..f483de2a2 100644 | |
--- a/torch/csrc/jit/init.cpp | |
+++ b/torch/csrc/jit/init.cpp | |
@@ -11,6 +11,7 @@ | |
#include "torch/csrc/jit/passes/onnx.h" | |
#include "torch/csrc/jit/passes/dead_code_elimination.h" | |
#include "torch/csrc/jit/passes/erase_number_types.h" | |
+#include "torch/csrc/jit/passes/eliminate_index_lists.h" | |
#include "torch/csrc/jit/passes/common_subexpression_elimination.h" | |
#include "torch/csrc/jit/passes/peephole.h" | |
#include "torch/csrc/jit/passes/canonicalize.h" | |
@@ -82,6 +83,7 @@ void initJITBindings(PyObject *module) { | |
.def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) { | |
return ConstantPropagation(g); | |
}) | |
+ .def("_jit_pass_eliminate_index_lists", eraseIndexWithLists) | |
.def("_jit_run_cpp_tests", [] { | |
// We have to release the GIL inside this method, because if we happen to | |
// initialize the autograd engine in these tests, the newly spawned worker threads will | |
diff --git a/torch/csrc/jit/passes/eliminate_index_lists.cpp b/torch/csrc/jit/passes/eliminate_index_lists.cpp | |
new file mode 100644 | |
index 000000000..1e1bc93ec | |
--- /dev/null | |
+++ b/torch/csrc/jit/passes/eliminate_index_lists.cpp | |
@@ -0,0 +1,65 @@
+#include "torch/csrc/jit/passes/eliminate_index_lists.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace torch { namespace jit {
+
+// Rewrites `aten::index(self, prim::ListConstruct(idx))` into
+// `aten::index_select(self, 0, idx)` in `block` and, recursively, in the
+// sub-blocks of its nodes.
+//
+// Only single-element index lists are rewritten. With several index tensors,
+// aten::index follows NumPy advanced-indexing semantics: the index tensors are
+// broadcast together and paired element-wise, so for x of shape (3, 4, 5) and
+// y = arange(2), x[y, y] has shape (2, 5). A chain of per-dimension
+// index_select ops takes the outer product instead (shape (2, 2, 5)), so that
+// rewrite would silently change results.
+//
+// NOTE(review): index_select expects a 1-D integer index, whereas aten::index
+// also accepts N-D and byte/bool-mask indices. The rank and dtype of the index
+// are not checked here -- confirm exported graphs only hit 1-D long indices.
+//
+// Matched nodes are not removed; run dead code elimination afterwards.
+static void eraseIndexWithLists(Block* block) {
+  Graph* g = block->owningGraph();
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    Node* n = *it;
+    // Advance before rewriting, since new nodes get inserted into the block.
+    ++it;
+
+    for (Block* b : n->blocks()) {
+      eraseIndexWithLists(b);
+    }
+
+    if (n->kind() != prim::ListConstruct || n->inputs().size() != 1 ||
+        n->output()->uses().size() != 1) {
+      continue;
+    }
+    Node* index_node = n->output()->uses()[0].user;
+    if (index_node->kind() != aten::index) {
+      continue;
+    }
+
+    // Insert directly before the index node, not before the list: the indexed
+    // tensor may be defined between the two, while every input of index_node
+    // is already defined at this point.
+    WithInsertPoint guard(index_node);
+    // Owning storage: an at::ArrayRef initialized from a braced list would
+    // dangle once the temporary initializer_list is destroyed.
+    std::vector<NamedValue> args = {
+        /*self=*/NamedValue(index_node->inputs()[0]),
+        /*dim=*/NamedValue(g->insertConstant(static_cast<int64_t>(0))),
+        /*index=*/NamedValue(n->inputs()[0]),
+    };
+    Value* selected = g->insert(aten::index_select, args);
+    index_node->output()->replaceAllUsesWith(selected);
+  }
+}
+
+void eraseIndexWithLists(Graph* graph) {
+  eraseIndexWithLists(graph->block());
+}
+
+}} // namespace torch::jit
diff --git a/torch/csrc/jit/passes/eliminate_index_lists.h b/torch/csrc/jit/passes/eliminate_index_lists.h | |
new file mode 100644 | |
index 000000000..8fef2825e | |
--- /dev/null | |
+++ b/torch/csrc/jit/passes/eliminate_index_lists.h | |
@@ -0,0 +1,10 @@ | |
+#pragma once | |
+ | |
+#include "torch/csrc/jit/ir.h" | |
+ | |
+namespace torch { namespace jit { | |
+// Rewrites prim::ListConstruct -> aten::index patterns in `graph` (and nested
+// blocks) into aten::index_select ops. Run DCE afterwards to remove dead nodes.
+TORCH_API void eraseIndexWithLists(Graph* graph); | |
+ | |
+}} // namespace torch::jit | |
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py | |
index b770b900c..62966702d 100644 | |
--- a/torch/onnx/utils.py | |
+++ b/torch/onnx/utils.py | |
@@ -120,6 +120,10 @@ def _optimize_graph(graph, operator_export_type): | |
torch._C._jit_pass_peephole(graph) | |
torch._C._jit_pass_lint(graph) | |
+ torch._C._jit_pass_eliminate_index_lists(graph) | |
+ torch._C._jit_pass_dce(graph) | |
+ torch._C._jit_pass_lint(graph) | |
+ | |
# onnx only supports tensors, so we turn all out number types into tensors | |
torch._C._jit_pass_erase_number_types(graph) | |
torch._C._jit_pass_peephole(graph) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment