#include <iostream>
#include <cmath>
#include <vector>
#include <cassert>
#include <memory>

enum NodeOp {
    OP_CONST,
    OP_ADD,
    OP_MUL,
    OP_EXP
};

class NodeData;

class Node {
public:
    // A shared pointer does reference counting, so I don't have to
    // worry about memory leaks; the graph is acyclic (children only
    // point downward), so there are no reference cycles to leak.
    std::shared_ptr<NodeData> const ptr;

    double grad() const;
    double value() const;
    void set_value(double x);

    // Sets the gradient to zero on every node in the expression graph.
    void reset_grad();
    // Backpropagates the gradients; g is dL/d(this node).
    void backward(double g = 1.0);
    // Forward propagation of the values; re-evaluates a precomputed
    // expression graph after leaf values change.
    double forward();

    // A constructor for constant values and variables.
    Node(double c);

    Node operator*(const Node& that) const;
    Node operator+(const Node& that) const;
    Node exp() const;

private:
    Node(std::shared_ptr<NodeData> const ptr);
};
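
// A minimal usage sketch (hypothetical values; the full demo is in
// main() below):
//
//   Node a(2.0), b(3.0);
//   Node z = a * b + a;  // builds the graph; z.value() == 8
//   z.reset_grad();
//   z.backward();        // a.grad() == 4 (= b + 1), b.grad() == 2 (= a)
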
class NodeData {
public:
    NodeOp op;
    double value;
    // This variable will be mutated during the backpropagation pass.
    double grad;
    std::vector<Node> children;
};

Node::Node(double c)
    : ptr(std::shared_ptr<NodeData>(
          new NodeData { OP_CONST, c, 0.0, {} }
      )) { }

Node::Node(std::shared_ptr<NodeData> const ptr) : ptr(ptr) { }

double Node::grad() const {
    return ptr->grad;
}

double Node::value() const {
    return ptr->value;
}

void Node::set_value(double x) {
    ptr->value = x;
}

////////////////////////////////////////////////////////////////
// Backpropagation
////////////////////////////////////////////////////////////////
void Node::reset_grad() {
    ptr->grad = 0.0;
    for (auto& c : ptr->children) {
        c.reset_grad();
    }
}

void Node::backward(double g) {
    ptr->grad += g;
    if (ptr->op == OP_CONST) {
        // Leaf node: nothing to propagate further.
    }
    else if (ptr->op == OP_ADD) {
        // d(x + y)/dx = 1, so every child receives g unchanged.
        for (auto& c : ptr->children) {
            c.backward(g);
        }
    }
    else if (ptr->op == OP_MUL) {
        // It is easier to assume there are always exactly 2 children than
        // to handle the general case: with n children I would have to
        // multiply all arguments but one, which is messy.
        assert(ptr->children.size() == 2);
        Node x = ptr->children[0];
        Node y = ptr->children[1];
        // Let f(x, y) = x * y and z = f(x, y).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * y (and symmetrically for y).
        x.backward(g * y.value());
        y.backward(g * x.value());
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        Node x = ptr->children[0];
        // Let f(x) = exp(x) and z = f(x).
        // We know dL/dz = g,
        // so dL/dx = dL/dz * df/dx = g * exp(x).
        x.backward(g * std::exp(x.value()));
    }
    else assert(false);
}
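
// Worked example (the same loss main() uses below):
//   L(x, y) = 3x^2 + y^2 + xy + 10
//   dL/dx = 6x + y,  dL/dy = 2y + x
// At (x, y) = (2, 1) this gives dL/dx = 13 and dL/dy = 4, which is
// exactly what backward(1.0) accumulates into the grad fields.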

// Forward pass: recomputes every node's value bottom-up.
double Node::forward() {
    if (ptr->op == OP_CONST) {
        // Leaf node: the value is whatever was stored last.
    }
    else if (ptr->op == OP_ADD) {
        double sum = 0.0;
        for (auto& c : ptr->children) {
            sum += c.forward();
        }
        set_value(sum);
    }
    else if (ptr->op == OP_MUL) {
        double product = 1.0;
        for (auto& c : ptr->children) {
            product *= c.forward();
        }
        set_value(product);
    }
    else if (ptr->op == OP_EXP) {
        assert(ptr->children.size() == 1);
        // Recurse into the child first: reading value() here would use
        // a stale value instead of re-evaluating the subtree.
        set_value(std::exp(ptr->children[0].forward()));
    }
    else assert(false);
    return value();
}
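
// Re-evaluation sketch (hypothetical values): after changing a leaf
// with set_value, a single forward() call refreshes the whole graph:
//
//   Node x(2.0);
//   Node y = x * x;   // y.value() == 4
//   x.set_value(3.0);
//   y.forward();      // y.value() == 9
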
////////////////////////////////////////////////////////////////
// Operators
////////////////////////////////////////////////////////////////
Node Node::operator*(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_MUL, ptr->value * that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::operator+(const Node& that) const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_ADD, ptr->value + that.ptr->value, 0.0, { *this, that } }
        )
    };
}

Node Node::exp() const {
    return Node {
        std::shared_ptr<NodeData>(
            new NodeData { OP_EXP, std::exp(ptr->value), 0.0, { *this } }
        )
    };
}

Node operator+(double x, const Node& n) {
    return Node(x) + n;
}

Node operator+(const Node& n, double x) {
    return Node(x) + n;
}

Node operator*(double x, const Node& n) {
    return Node(x) * n;
}

Node operator*(const Node& n, double x) {
    return Node(x) * n;
}
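
// These overloads let plain doubles appear in expressions such as
// 3 * x or x + 10: each double is wrapped in a constant Node, so the
// literal becomes an OP_CONST leaf of the graph. Taking the Node by
// const reference also lets temporaries bind, e.g. in 3 * x * x + 10.
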
////////////////////////////////////////////////////////////////
int main() {
    Node x(2);
    Node y(1);

    // Precomputed expression graph. Precomputing is not *necessary*;
    // the loop below also builds a fresh graph on every iteration.
    Node l0 = 3 * x * x + y * y + x * y + 10;

    for (int i = 0; i < 30; i++) {
        // Some loss function with a global minimum.
        Node loss = 3 * x * x + y * y + x * y + 10;

        // Check that the precomputed expression graph produces the same result.
        l0.forward();
        assert(std::abs(l0.value() - loss.value()) < 0.00001);

        loss.reset_grad();
        // Backpropagation step.
        loss.backward(1.0);

        // One SGD step.
        x.set_value(x.value() - x.grad() * 0.01);
        y.set_value(y.value() - y.grad() * 0.01);

        std::cout << i << " " << loss.value() << " " << x.grad() << " " << y.grad() << std::endl;
    }
}
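
// Expected behavior: the gradient (6x + y, x + 2y) vanishes only at
// (x, y) = (0, 0), where the loss attains its minimum value of 10, so
// the printed loss should decrease toward 10 as SGD runs.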