Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
An example of using the PyTorch C++ API to implement a custom forward and backward function
// An example of using the PyTorch C++ API to implement a custom forward and backward function
#include <iostream>
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/functions/utils.h>
using torch::Tensor;
using at::Scalar;
using torch::autograd::Function;
using torch::autograd::deleteFunction;
using torch::autograd::SavedVariable;
using torch::autograd::variable_list;
using torch::autograd::tensor_list;
using torch::autograd::as_variable;
using torch::autograd::as_variable_ref;
using torch::autograd::compute_requires_grad;
using torch::autograd::collect_next_edges;
using torch::autograd::flatten_tensor_args;
/// Backward node for MyPowForward: given dL/d(out) where out = self^exponent,
/// produces dL/d(self) = grad * exponent * self^(exponent - 1)
/// (or zeros when exponent == 0).
///
/// NOTE(review): written against the older libtorch autograd API, where the
/// node base class was called `Function` (later renamed `Node`) — confirm
/// against the libtorch version you build with.
struct MyPowBackward : public Function {
  // Forward-pass state consumed by apply(); filled in by MyPowForward.
  SavedVariable self_;   // the forward input
  Scalar exponent_;      // the exponent the forward pass used

  /// Called by the autograd engine during backward().
  /// @param grads  incoming gradients, one per forward output (exactly one here)
  /// @return       gradients w.r.t. the forward inputs (exactly one here); the
  ///               entry is left default-constructed if should_compute_output(0)
  ///               is false.
  variable_list apply(variable_list&& grads) override {
    std::cout << "-> Computing MyPow Backward!" << std::endl;
    auto& upstream = grads[0];

    // Unpack before checking should_compute_output so the saved variable is
    // always validated, exactly as before.
    auto base = self_.unpack();
    const double power = exponent_.toDouble();

    variable_list input_grads(1);
    if (!should_compute_output(0)) {
      return input_grads;
    }

    if (power == 0.0) {
      // d/dx x^0 is identically zero.
      input_grads[0] = torch::zeros_like(base);
    } else {
      input_grads[0] = upstream * power * base.pow(power - 1);
    }
    return input_grads;
  }

  /// Drops the saved forward state once the graph no longer needs it.
  /// NOTE(review): reset_grad_function() is not available in newer libtorch
  /// releases — confirm against the target version.
  void release_variables() override {
    self_.reset_data();
    self_.reset_grad_function();
  }
};
/// Forward pass of a hand-written elementwise pow that attaches its own
/// backward node (MyPowBackward) to the autograd graph.
///
/// @param self      input tensor (treated as an autograd Variable via
///                  as_variable_ref under this older API)
/// @param exponent  scalar power applied elementwise
/// @return          self^exponent wrapped as a Variable; when
///                  compute_requires_grad(self) is true its history points at
///                  a freshly created MyPowBackward node.
Tensor MyPowForward(const Tensor & self, Scalar exponent) {
  std::cout << "-> Computing MyPow Forward!" << std::endl;

  // Do the math on the underlying data rather than on the Variable itself —
  // presumably so the built-in pow does not record its own grad_fn; ours is
  // attached below instead.
  const auto& input_var = as_variable_ref(self);
  auto raw_output = input_var.data().pow(exponent);
  auto output = as_variable(raw_output);

  // Nothing to wire into the graph when no input needs a gradient.
  if (!compute_requires_grad(self)) {
    return output;
  }

  // Owned via shared_ptr with deleteFunction as its deleter.
  std::shared_ptr<MyPowBackward> node(new MyPowBackward(), deleteFunction);
  // Link this node to whatever produced `self`.
  node->set_next_edges(collect_next_edges(self));
  // Stash what apply() needs; `false` marks `self` as an input, not an output.
  node->self_ = SavedVariable(self, false);
  node->exponent_ = exponent;
  // Point the output's history at our node so backward() reaches it.
  set_history(flatten_tensor_args(output), node);
  return output;
}
/// Demo: runs MyPowForward(x, 2), sums the result, backpropagates, and prints
/// the gradient that lands on x.
int main() {
  // Leaf tensor of 3s that autograd should track.
  auto x = 3 * torch::ones({3, 3});
  x.set_requires_grad(true);

  std::cout << "Begin Forward Pass" << std::endl;
  auto squared = MyPowForward(x, 2);
  auto loss = squared.sum();

  std::cout << "Begin Backward Pass" << std::endl;
  loss.backward();

  // d/dx sum(x^2) = 2x, so every entry is expected to be 6.
  std::cout << x.grad() << std::endl;
}
@janursa
Copy link

janursa commented Nov 26, 2019

Hi,
Is this code still valid? I tried it, and most of the elements used are not recognized, such as "deleteFunction" and "Function".
Best

@A-Malone
Copy link

A-Malone commented Dec 11, 2019

It is, but Function has since been renamed to Node, along with all the functions that handle Functions. Try "deleteNode" and "Node".

@vokhidovhusan
Copy link

vokhidovhusan commented Apr 23, 2021

What about as_variable? It says as_variable has not been declared.

@tigerneil
Copy link

tigerneil commented Aug 13, 2021

reset_grad_function has also been removed now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment