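// Build and differentiate a thnn_conv2d_forward call with the PyTorch JIT.
//
// The snippet first runs the convolution forward and backward eagerly
// through ATen, then constructs the same forward op as a JIT graph, lowers
// its tuple return, and hands the graph to the symbolic autodiff pass
// (differentiate). It targets PyTorch's internal C++ APIs as of late 2018,
// so it is meant to be built inside a PyTorch source checkout.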
// Headers assume a 2018-era PyTorch source tree; the include paths for the
// JIT internals have moved around between versions.
#include <torch/torch.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>

#include <iostream>
#include <vector>

using namespace torch::jit;

int main() {
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  std::vector<int64_t> kernel_size = {3, 5};
  std::vector<int64_t> stride = {1, 2};
  std::vector<int64_t> padding = {2, 1};
  constexpr int out_channels = 5;

  // make inputs; weight is (out_channels, in_channels, kH, kW)
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
  at::Tensor bias = torch::randn({out_channels});
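  // thnn_conv2d_forward returns (output, finput, fgrad_input): besides the
  // convolution result it exposes two workspace buffers (in the THNN
  // implementation, finput holds the unfolded/im2col view of the input)
  // that the backward pass reuses.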
  // run forward eagerly
  at::Tensor output, finput, fgradinput;
  std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
                                                                 bias, stride, padding);
  // make grad_outputs
  at::Tensor grad_output = torch::randn_like(output);
  at::Tensor grad_finput = torch::zeros_like(finput);
  at::Tensor grad_fgradinput = torch::zeros_like(fgradinput);
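  // grad_finput / grad_fgradinput mirror the forward's buffer outputs but
  // are not consumed by the backward call below. The {true, true, true}
  // output mask asks thnn_conv2d_backward to compute all three gradients:
  // grad_input, grad_weight, and grad_bias.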
  // run backward eagerly
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
                                                                          kernel_size, stride, padding,
                                                                          finput, fgradinput, {true, true, true});
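  // The same computation can also be expressed as a JIT graph. Note that
  // kernel_size, stride, and padding are baked in as graph constants, while
  // the tensors remain free graph inputs bound at execution time.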
  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto ksz_val = graph->insertConstant(IValue(kernel_size));
  auto kst_val = graph->insertConstant(IValue(stride));
  auto pad_val = graph->insertConstant(IValue(padding));
  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");
  Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
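  // The inserted call yields a single tuple-typed Value; TupleIndex picks
  // out element 0 (the conv output), which becomes the graph's sole output.
  // LowerAllTuples then flattens the tuple construction/indexing away, and
  // lint() verifies the graph's invariants before dumping it.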
  auto tuple_index_node = graph->insertNode(graph->createTupleIndex(conv->node()->output(0), 0));
  graph->registerOutput(tuple_index_node->output(0));
  LowerAllTuples(graph);
  graph->lint();
  std::cout << *graph << std::endl;
  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
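  // differentiate() runs the JIT's symbolic autodiff: it returns a Gradient
  // struct whose `f` is the rewritten forward graph and whose `df` computes
  // the backward. LowerGradOf expands the prim::GradOf blocks that autodiff
  // emits into explicit control flow on whether the gradient is defined.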
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);
  return 0;
}
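// To inspect what autodiff produced, both halves of the Gradient can be
// dumped before returning (assuming the field names described above):
//
//   std::cout << *grad_spec.f << *grad_spec.df << std::endl;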