// Source: GitHub Gist Daiver/61a91c67d36ba086eaa04a45d351ec42 by @Daiver,
// created August 15, 2018 03:57.
#include <iostream>
#include <memory>
#include <type_traits>
class Node;  // forward declaration so NodePtr can name Node before its definition

using Scalar  = float;                  // numeric type for every value and gradient
using NodePtr = std::shared_ptr<Node>;  // shared, type-erased handle to any graph node
/// Abstract base class for every node of the expression graph.
///
/// A node reports its forward value through value() and receives an
/// upstream gradient ("sensitivity") through backward().
class Node
{
public:
    /// @param requiresGrad whether gradients should flow into this node.
    explicit Node(const bool requiresGrad = false):
        m_requiresGrad(requiresGrad)
    {
    }

    virtual ~Node() = default;

    /// Forward value of this node.
    virtual Scalar value() const = 0;

    /// Receive the upstream gradient `sensitivity` and propagate/accumulate it.
    virtual void backward(const Scalar &sensitivity) = 0;

    /// Reset any stored gradient; the default implementation stores none.
    virtual void zeroGrad() {}

    bool requiresGrad() const { return m_requiresGrad; }
    void setRequiresGrad(bool requiresGrad) { m_requiresGrad = requiresGrad; }

protected:
    bool m_requiresGrad;
};
class VarNode : public Node
{
public:
VarNode(const Scalar value, const bool requiresGrad = false):
Node(requiresGrad)
{
this->m_value = value;
this->zeroGrad();
}
virtual Scalar value() const {return m_value;}
virtual void backward(const Scalar &sensetivity)
{
this->m_grad += sensetivity;
}
virtual void zeroGrad() {this->m_grad = 0;}
Scalar grad() const {return m_grad;}
protected:
Scalar m_value;
Scalar m_grad;
};
class BinNode : public Node
{
public:
BinNode(const NodePtr &lhs, const NodePtr &rhs):
Node(lhs->requiresGrad() || rhs->requiresGrad())
{
m_lhs = lhs;
m_rhs = rhs;
}
template<typename Lhs, typename Rhs>
BinNode(const Lhs &lhs, const Rhs &rhs):
Node(lhs.requiresGrad() || rhs.requiresGrad())
{
m_lhs = std::make_shared<Lhs>(lhs);
m_rhs = std::make_shared<Rhs>(rhs);
}
protected:
NodePtr m_lhs;
NodePtr m_rhs;
};
class AddNode : public BinNode
{
public:
template<typename Lhs, typename Rhs>
AddNode(const Lhs &lhs, const Rhs &rhs):
BinNode(lhs, rhs)
{
this->m_value = lhs.value() + rhs.value();
}
virtual Scalar value() const {return m_value;}
virtual void backward(const Scalar &sensetivity)
{
this->m_lhs->backward(sensetivity);
this->m_rhs->backward(sensetivity);
}
protected:
Scalar m_value;
};
template<typename Lhs, typename Rhs>
AddNode operator +(const Lhs &lhs, const Rhs &rhs)
{
return AddNode(lhs, rhs);
}
int main()
{
VarNode x(5, true);
VarNode y(2, true);
auto res = x + y;
res.backward(1);
std::cout << "res " << res.value() << std::endl;
std::cout << "x.grad " << x.grad() << std::endl;
std::cout << "y.grad " << y.grad() << std::endl;
return 0;
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment