@soumith
Created November 12, 2018 06:56
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 59eb7ca11..75abe0097 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -77,7 +77,8 @@ bool isDifferentiable(Node * n) {
"aten::trunc(Tensor self) -> Tensor",
"aten::log_softmax(Tensor self, int dim) -> Tensor",
"aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
- "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)"
+ "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+ "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
};
// TODO: add support for the following fusible operators.
@@ -88,7 +89,8 @@ bool isDifferentiable(Node * n) {
if (n->kind() == prim::Constant ||
n->kind() == prim::AutogradAdd ||
- n->kind() == prim::ConstantChunk)
+ n->kind() == prim::ConstantChunk ||
+ n->kind() == prim::TupleConstruct)
return true;
if (differentiable_ops.find(n))
return true;
@@ -144,6 +146,17 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
} else if (node->kind() == prim::AutogradAdd) {
return {grads.at(0), grads.at(0)};
+ } else if (node->kind() == prim::TupleConstruct) {
+   // The gradient of tuple construction is tuple unpacking: the i-th element
+   // of the incoming tuple gradient flows back to the i-th input.
+   auto graph = node->owningGraph();
+   auto unpack = graph->createTupleUnpack(grads.at(0).value());
+   graph->insertNode(unpack);
+   std::vector<SymbolicVariable> sym_outputs;
+   for (auto output : unpack->outputs()) {
+     sym_outputs.emplace_back(output);
+   }
+   return sym_outputs;
+
} else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};
@@ -395,6 +408,25 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
});
return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
+ } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+ auto graph = node->owningGraph();
+
+ auto convNode = graph->create(aten::thnn_conv2d_backward, 3);
+ convNode->addInput(grads.at(0).value());
+ convNode->addInput(inputs.at(0));
+ convNode->addInput(inputs.at(1));
+ convNode->addInput(node->namedInput(attr::kernel_size));
+ convNode->addInput(node->namedInput(attr::stride));
+ convNode->addInput(node->namedInput(attr::padding));
+ convNode->addInput(outputs.at(1));
+ convNode->addInput(outputs.at(2));
+ convNode->addInput(graph->insertConstant(std::vector<bool>{true, true, true}));
+ graph->insertNode(convNode);
+ auto outputs = convNode->outputs();
+ JIT_ASSERT(outputs.size() == size_t(3));
+
+ return {outputs[0], outputs[1], nullptr, outputs[2], nullptr, nullptr};
+
} else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
JIT_ASSERT(grads.size() == 1);
auto graph = node->owningGraph();
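
For reference, the conv2d piece of this patch maps the three outputs of aten::thnn_conv2d_backward onto the differentiable inputs of aten::thnn_conv2d_forward (self, weight, bias), while the TupleConstruct piece simply unpacks the incoming tuple gradient. The standalone sketch below exercises the same forward/backward pair eagerly through ATen so the argument order and the (grad_input, grad_weight, grad_bias) mapping can be sanity-checked outside the JIT. It is illustrative only and not part of the patch: the at::thnn_conv2d_forward / at::thnn_conv2d_backward C++ signatures are assumed from the schemas the patch registers, and these THNN entry points may have been renamed or removed in newer PyTorch releases.

// Illustrative only: exercises the forward/backward pair the patch wires
// together symbolically, via the (assumed) eager ATen entry points.
#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  // A tiny conv: N=1, C_in=3, 8x8 input, 4 output channels, 3x3 kernel.
  at::Tensor input  = at::randn({1, 3, 8, 8});
  at::Tensor weight = at::randn({4, 3, 3, 3});
  at::Tensor bias   = at::randn({4});

  // thnn_conv2d_forward returns (output, finput, fgrad_input); only `output`
  // is user-visible, the other two are buffers the backward call reuses,
  // which is why the patch feeds outputs.at(1) and outputs.at(2) back in.
  at::Tensor output, finput, fgrad_input;
  std::tie(output, finput, fgrad_input) = at::thnn_conv2d_forward(
      input, weight, /*kernel_size=*/{3, 3}, bias,
      /*stride=*/{1, 1}, /*padding=*/{0, 0});

  // Pretend the loss is output.sum(), so grad_output is all ones.
  at::Tensor grad_output = at::ones_like(output);

  // Backward mirrors the addInput() order in the patch and returns
  // (grad_input, grad_weight, grad_bias), which the patch maps onto the
  // forward inputs as {grad_input, grad_weight, nullptr, grad_bias,
  // nullptr, nullptr}.
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
      grad_output, input, weight, {3, 3}, {1, 1}, {0, 0},
      finput, fgrad_input, /*output_mask=*/{{true, true, true}});

  // Each gradient should have the shape of the forward input it belongs to.
  AT_ASSERT(grad_input.sizes().equals(input.sizes()));
  AT_ASSERT(grad_weight.sizes().equals(weight.sizes()));
  AT_ASSERT(grad_bias.sizes().equals(bias.sizes()));
  std::cout << "gradient shapes match the forward inputs" << std::endl;
  return 0;
}

Built against a libtorch/ATen checkout of the same vintage, the assertions should pass if the assumed signatures match; the point is simply to make the input/output wiring of the symbolic gradient above easy to verify by hand.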