@DuaneNielsen
Created October 4, 2020 00:29
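Example of creating a Graphviz DOT file of a libtorch autograd computation graph in C++. The program builds a small graph from two tensors that require gradients. It then walks the graph backward from the output's grad_fn() through each node's next_edges(), printing the nodes and edges in DOT syntax to stdout. Redirect stdout to a .dot file to save the graph.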
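//
// Created by duane on 10/2/20.
//
#include <torch/torch.h>

#include <deque>
#include <iostream>
#include <memory>
#include <unordered_set>

int main(int argc, char *argv[]) {
    // Build a small graph: two leaf tensors that require gradients,
    // combined through several differentiable ops.
    auto x = torch::randn(3, torch::requires_grad());
    auto z = torch::randn(3, torch::requires_grad());
    auto y = x * 2;
    y = y * 2;
    y = y * z;
    y += z;
    y = y * 2;

    // Walk the backward graph breadth-first from y's grad_fn and print it
    // in Graphviz DOT syntax. Each node is identified by its address.
    std::cout << "digraph G {" << std::endl;
    std::deque<std::shared_ptr<torch::autograd::Node>> nodes;
    std::unordered_set<torch::autograd::Node *> seen;
    nodes.push_front(y.grad_fn());
    seen.insert(y.grad_fn().get());
    while (!nodes.empty()) {
        auto node = nodes.back();
        nodes.pop_back();
        std::cout << "\"" << node.get() << "\" [label=\"" << node->name() << "\"]" << std::endl;
        for (const auto &edge : node->next_edges()) {
            // Inputs that do not require grad produce edges with no function.
            if (!edge.function) continue;
            std::cout << "\"" << node.get() << "\" -> \"" << edge.function.get() << "\"" << std::endl;
            // Expand each node only once, even if several paths reach it
            // (z's AccumulateGrad node is reached twice in this graph).
            if (seen.insert(edge.function.get()).second) {
                nodes.push_front(edge.function);
            }
        }
    }
    std::cout << "}" << std::endl;
}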
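The same walk can be checked without linking libtorch. Below is a minimal, dependency-free sketch of the breadth-first traversal. It assumes a hypothetical FakeNode struct standing in for torch::autograd::Node: its name field plays the role of name(), and its next vector plays the role of the targets of next_edges(), with nullptr modelling an edge that has no function. The sketch declares each node once even when several paths reach it, skips null edges, and writes to any std::ostream rather than only std::cout. It illustrates the technique and is not libtorch's API.

```cpp
#include <cassert>
#include <deque>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for torch::autograd::Node (illustration only):
// a display name plus the nodes its outgoing edges point to.
// A nullptr entry models an edge with no function.
struct FakeNode {
    std::string name;
    std::vector<std::shared_ptr<FakeNode>> next;
};

// Breadth-first walk from root, writing the graph in Graphviz DOT syntax.
// Every node is declared exactly once; node IDs are their addresses.
void write_dot(const std::shared_ptr<FakeNode>& root, std::ostream& out) {
    out << "digraph G {\n";
    std::deque<const FakeNode*> queue;
    std::unordered_set<const FakeNode*> seen;
    if (root) {
        queue.push_back(root.get());
        seen.insert(root.get());
    }
    while (!queue.empty()) {
        const FakeNode* node = queue.front();
        queue.pop_front();
        out << "  \"" << node << "\" [label=\"" << node->name << "\"]\n";
        for (const auto& child : node->next) {
            if (!child) continue;  // skip edges that lead nowhere
            out << "  \"" << node << "\" -> \"" << child.get() << "\"\n";
            // Enqueue a child only the first time it is reached.
            if (seen.insert(child.get()).second) queue.push_back(child.get());
        }
    }
    out << "}\n";
}

// Convenience wrapper that returns the DOT text as a string.
std::string to_dot_string(const std::shared_ptr<FakeNode>& root) {
    std::ostringstream os;
    write_dot(root, os);
    return os.str();
}
```

Passing std::cout to write_dot mirrors the stdout output of the libtorch program. Passing a std::ofstream writes the .dot file directly, and Graphviz can then render it, for example with dot -Tpng graph.dot -o graph.png.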