@soumith
Created November 21, 2018 00:18
// NOTE: the include paths below match the PyTorch source tree around late 2018
// (when thnn_conv2d_forward was exposed); they may differ in newer releases.
#include <torch/torch.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>

using namespace torch::jit;

// The code below is meant to run inside a function body (e.g. main()).
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
std::vector<int64_t> kernel_size = {3, 5};
std::vector<int64_t> stride = {1, 2};
std::vector<int64_t> padding = {2, 1};
constexpr int out_channels = 5;
// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});
// run forward eagerly
at::Tensor output, finput, fgradinput;
std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
                                                               bias, stride, padding);
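// for these sizes the output should come out as 4 x 5 x 17 x 8 (N x C_out x H_out x W_out)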
// make grad_outputs
at::Tensor grad_output = torch::randn_like(output);
at::Tensor grad_finput = torch::zeros_like(finput);
at::Tensor grad_fgradinput = torch::zeros_like(fgradinput);
// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
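// the {true, true, true} mask below requests all of grad_input, grad_weight and grad_bias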
std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
                                                                        kernel_size, stride, padding,
                                                                        finput, fgradinput, {true, true, true});
// make JIT graph
auto graph = std::make_shared<Graph>();
auto ksz_val = graph->insertConstant(IValue(kernel_size));
auto kst_val = graph->insertConstant(IValue(stride));
auto pad_val = graph->insertConstant(IValue(padding));
auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
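// emit the forward op; like the eager call above, it produces a tuple of (output, finput, fgrad_input),
// and the tuple-index node below selects element 0 (the convolution output) as the graph output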
Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto tuple_index_node = graph->insertNode(graph->createTupleIndex(conv->node()->output(0), 0));
graph->registerOutput(tuple_index_node->output(0));
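// flatten tuple construction/indexing so the graph only passes around plain values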
LowerAllTuples(graph);
graph->lint();
std::cout << *graph << std::endl;
// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
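// grad_spec.f is the rewritten forward graph, grad_spec.df its derivative graph;
// LowerGradOf expands the prim::GradOf blocks in df into explicit conditionals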
LowerGradOf(*grad_spec.df);
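// (optional) dump both halves produced by differentiate() for inspection,
// the same way the combined graph was printed above
std::cout << *grad_spec.f << std::endl;
std::cout << *grad_spec.df << std::endl;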