Created
August 15, 2018 03:57
-
-
Save Daiver/61a91c67d36ba086eaa04a45d351ec42 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream>
#include <memory>
#include <type_traits>
// Scalar type used for every value and gradient in the graph.
using Scalar = float;

class Node;
// Shared handle through which graph nodes refer to their inputs.
using NodePtr = std::shared_ptr<Node>;

/// Abstract vertex of a scalar computation graph.
///
/// A subclass reports its forward value through value() and receives a
/// sensitivity in backward(). Every node also carries a requiresGrad flag
/// that can be queried and changed after construction.
class Node
{
public:
    /// @param requiresGrad initial state of the requiresGrad flag (false by default).
    Node(const bool requiresGrad = false):
        m_requiresGrad(requiresGrad)
    {
    }

    virtual ~Node() = default;

    /// @return the forward value of this node.
    virtual Scalar value() const = 0;

    /// Handles an incoming sensitivity during the backward pass.
    virtual void backward(const Scalar &sensitivity) = 0;

    /// Resets stored gradient state; the base version is a no-op.
    virtual void zeroGrad() {}

    /// @return the current requiresGrad flag.
    bool requiresGrad() const
    {
        return m_requiresGrad;
    }

    /// Overwrites the requiresGrad flag.
    void setRequiresGrad(bool requiresGrad)
    {
        m_requiresGrad = requiresGrad;
    }

protected:
    bool m_requiresGrad;
};
class VarNode : public Node | |
{ | |
public: | |
VarNode(const Scalar value, const bool requiresGrad = false): | |
Node(requiresGrad) | |
{ | |
this->m_value = value; | |
this->zeroGrad(); | |
} | |
virtual Scalar value() const {return m_value;} | |
virtual void backward(const Scalar &sensetivity) | |
{ | |
this->m_grad += sensetivity; | |
} | |
virtual void zeroGrad() {this->m_grad = 0;} | |
Scalar grad() const {return m_grad;} | |
protected: | |
Scalar m_value; | |
Scalar m_grad; | |
}; | |
class BinNode : public Node | |
{ | |
public: | |
BinNode(const NodePtr &lhs, const NodePtr &rhs): | |
Node(lhs->requiresGrad() || rhs->requiresGrad()) | |
{ | |
m_lhs = lhs; | |
m_rhs = rhs; | |
} | |
template<typename Lhs, typename Rhs> | |
BinNode(const Lhs &lhs, const Rhs &rhs): | |
Node(lhs.requiresGrad() || rhs.requiresGrad()) | |
{ | |
m_lhs = std::make_shared<Lhs>(lhs); | |
m_rhs = std::make_shared<Rhs>(rhs); | |
} | |
protected: | |
NodePtr m_lhs; | |
NodePtr m_rhs; | |
}; | |
class AddNode : public BinNode | |
{ | |
public: | |
template<typename Lhs, typename Rhs> | |
AddNode(const Lhs &lhs, const Rhs &rhs): | |
BinNode(lhs, rhs) | |
{ | |
this->m_value = lhs.value() + rhs.value(); | |
} | |
virtual Scalar value() const {return m_value;} | |
virtual void backward(const Scalar &sensetivity) | |
{ | |
this->m_lhs->backward(sensetivity); | |
this->m_rhs->backward(sensetivity); | |
} | |
protected: | |
Scalar m_value; | |
}; | |
template<typename Lhs, typename Rhs> | |
AddNode operator +(const Lhs &lhs, const Rhs &rhs) | |
{ | |
return AddNode(lhs, rhs); | |
} | |
int main() | |
{ | |
VarNode x(5, true); | |
VarNode y(2, true); | |
auto res = x + y; | |
res.backward(1); | |
std::cout << "res " << res.value() << std::endl; | |
std::cout << "x.grad " << x.grad() << std::endl; | |
std::cout << "y.grad " << y.grad() << std::endl; | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment