Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@mikehamer
Created April 9, 2019 13:28
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mikehamer/df0af5ec7ff98d3cae487975d0c921df to your computer and use it in GitHub Desktop.
Save mikehamer/df0af5ec7ff98d3cae487975d0c921df to your computer and use it in GitHub Desktop.
An example of using the PyTorch C++ API to implement a custom forward and backward function
// An example of using the PyTorch C++ API to implement a custom forward and backward function
#include <iostream>
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/functions/utils.h>
using torch::Tensor;
using at::Scalar;
using torch::autograd::Function;
using torch::autograd::deleteFunction;
using torch::autograd::SavedVariable;
using torch::autograd::variable_list;
using torch::autograd::tensor_list;
using torch::autograd::as_variable;
using torch::autograd::as_variable_ref;
using torch::autograd::compute_requires_grad;
using torch::autograd::collect_next_edges;
using torch::autograd::flatten_tensor_args;
/// Autograd node for MyPowForward: maps the gradient of the output of
/// `self.pow(exponent)` back to a gradient for `self`.
struct MyPowBackward : public Function {
  // State captured by MyPowForward so the backward pass can evaluate the derivative.
  SavedVariable self_;  // the forward input (saved with is_output = false)
  Scalar exponent_;     // the exponent the forward pass raised `self` to

  /// Invoked by the autograd engine during the backward pass.
  /// @param grads incoming gradients, one per forward output; the forward op
  ///              has a single output, so only grads[0] is read.
  /// @return one gradient per forward input; slot 0 corresponds to `self` and
  ///         is left undefined when that input does not need a gradient.
  variable_list apply(variable_list&& grads) override {
    std::cout << "-> Computing MyPow Backward!" << std::endl;

    auto& upstream = grads[0];
    // Recover the forward input from its SavedVariable wrapper.
    auto input = self_.unpack();
    const double p = exponent_.toDouble();

    variable_list result(1);
    if (!should_compute_output(0)) {
      return result;
    }

    // d(x^p)/dx = p * x^(p-1); the p == 0 case is a constant, so its derivative is zero.
    if (p == 0.0) {
      result[0] = torch::zeros_like(input);
    } else {
      result[0] = upstream * p * input.pow(p - 1);
    }
    return result;
  }

  /// Frees the saved state held by this node's SavedVariables.
  /// NOTE(review): reset_grad_function() no longer exists in newer PyTorch
  /// releases; this file targets the older autograd API.
  void release_variables() override {
    self_.reset_data();
    self_.reset_grad_function();
  }
};
/// Custom differentiable power op: computes `self` raised to `exponent` and,
/// when `self` requires grad, attaches a MyPowBackward node as the result's history.
/// @param self     input tensor (passed through as_variable_ref, so it is expected to be a Variable)
/// @param exponent scalar power applied elementwise
/// @return the powered tensor, wrapped via as_variable
Tensor MyPowForward(const Tensor & self, Scalar exponent) {
  std::cout << "-> Computing MyPow Forward!" << std::endl;

  // Operate on the underlying data; the autograd history is attached manually below
  // (presumably so the inner pow is not recorded itself — confirm against the API docs).
  auto& input_var = as_variable_ref(self);
  auto raw_output = input_var.data().pow(exponent);
  auto result = as_variable(raw_output);

  // No input needs a gradient: nothing to wire into the graph.
  if (!compute_requires_grad(self)) {
    return result;
  }

  // Create the backward node, using deleteFunction as the shared_ptr deleter.
  std::shared_ptr<MyPowBackward> grad_fn(new MyPowBackward(), deleteFunction);
  // Link the node to the producers of `self` so gradients continue upstream.
  grad_fn->set_next_edges(collect_next_edges(self));
  // Store what apply() needs: the input (not an output of this op) and the exponent.
  grad_fn->self_ = SavedVariable(self, false);
  grad_fn->exponent_ = exponent;
  // Register MyPowBackward as the grad_fn of `result`.
  set_history(flatten_tensor_args(result), grad_fn);
  return result;
}
/// Demo: squares a 3x3 tensor of 3s with the custom op, sums it, backpropagates,
/// and prints the resulting gradient of the input.
int main() {
  // Leaf tensor tracked by autograd.
  auto input = 3*torch::ones({3,3});
  input.set_requires_grad(true);

  std::cout << "Begin Forward Pass" << std::endl;
  auto squared = MyPowForward(input, 2);
  auto loss = squared.sum();

  std::cout << "Begin Backward Pass" << std::endl;
  loss.backward();

  // d/dx sum(x^2) = 2x, so each entry is expected to print as 6.
  std::cout << input.grad() << std::endl;
}
@tigerneil
Copy link

`reset_grad_function` has also been removed in newer versions of PyTorch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment