diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 59eb7ca11..75abe0097 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -77,7 +77,8 @@ bool isDifferentiable(Node * n) {
     "aten::trunc(Tensor self) -> Tensor",
     "aten::log_softmax(Tensor self, int dim) -> Tensor",
     "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
-    "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)"
+    "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+    "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
   };

   // TODO: add support for the following fusible operators.
@@ -88,7 +89,8 @@ bool isDifferentiable(Node * n) {
   if (n->kind() == prim::Constant ||
       n->kind() == prim::AutogradAdd ||
-      n->kind() == prim::ConstantChunk)
+      n->kind() == prim::ConstantChunk ||
+      n->kind() == prim::TupleConstruct)
     return true;
   if (differentiable_ops.find(n))
     return true;
@@ -144,6 +146,17 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
   } else if (node->kind() == prim::AutogradAdd) {
     return {grads.at(0), grads.at(0)};
+  } else if (node->kind() == prim::TupleConstruct) {
+    auto graph = node->owningGraph();
+    auto unpack = graph->createTupleUnpack(grads.at(0).value());
+    graph->insertNode(unpack);
+    auto unpack_outputs = unpack->outputs();
+    std::vector<SymbolicVariable> sym_outputs;
+    for (auto output : unpack_outputs) {
+      sym_outputs.emplace_back(output);
+    }
+    return sym_outputs;
+
   } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
     return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};
@@ -395,6 +408,25 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     });
     return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
+  } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+    auto graph = node->owningGraph();
+
+    auto convNode = graph->create(aten::thnn_conv2d_backward, 3);
+    convNode->addInput(grads.at(0).value());
+    convNode->addInput(inputs.at(0));
+    convNode->addInput(inputs.at(1));
+    convNode->addInput(node->namedInput(attr::kernel_size));
+    convNode->addInput(node->namedInput(attr::stride));
+    convNode->addInput(node->namedInput(attr::padding));
+    convNode->addInput(outputs.at(1));
+    convNode->addInput(outputs.at(2));
+    convNode->addInput(graph->insertConstant(std::vector<bool>{true, true, true}));
+    graph->insertNode(convNode);
+    auto conv_grads = convNode->outputs();
+    JIT_ASSERT(conv_grads.size() == size_t(3));
+
+    return {conv_grads[0], conv_grads[1], nullptr, conv_grads[2], nullptr, nullptr};
+
   } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
     JIT_ASSERT(grads.size() == 1);
     auto graph = node->owningGraph();
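Notes on the two new gradient rules above.

For prim::TupleConstruct, the gradient of building a tuple is just unpacking the incoming tuple gradient: a TupleUnpack node is inserted into the graph and its outputs become the per-element gradients.

For aten::thnn_conv2d_forward, the backward feeds grad_output, the original input and weight, the kernel_size/stride/padding arguments, and the forward's two buffer outputs (finput, fgrad_input) into aten::thnn_conv2d_backward, with an output_mask constant of {true, true, true} requesting all three gradients. The returned vector lines up positionally with the forward inputs (self, weight, kernel_size, bias, stride, padding), so the non-differentiable int[] arguments get nullptr: {grad_input, grad_weight, nullptr, grad_bias, nullptr, nullptr}.

A minimal libtorch sketch of the three conv gradients the backward is asked to produce. This goes through eager autograd and the dispatcher-level torch::conv2d rather than the JIT symbolic path touched by the patch, and the shapes are made up for illustration:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Three differentiable inputs, mirroring self/weight/bias in the schema.
      auto input  = torch::randn({1, 3, 8, 8}, torch::requires_grad());
      auto weight = torch::randn({4, 3, 3, 3}, torch::requires_grad());
      auto bias   = torch::randn({4},          torch::requires_grad());

      // kernel_size/stride/padding are plain int[] arguments with no gradient,
      // hence the nullptr entries in the symbolic gradient above.
      auto out = torch::conv2d(input, weight, bias,
                               /*stride=*/{1, 1}, /*padding=*/{1, 1});

      // Backward populates .grad() on input, weight, and bias: the same three
      // tensors selected by the output_mask constant {true, true, true}.
      out.sum().backward();

      std::cout << "grad_input:  " << input.grad().sizes()  << "\n"
                << "grad_weight: " << weight.grad().sizes() << "\n"
                << "grad_bias:   " << bias.grad().sizes()   << "\n";
      return 0;
    }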