Created
June 25, 2015 18:52
-
-
Save alcinos/3bedb2f7c4518fa93220 to your computer and use it in GitHub Desktop.
API proposal for arrayfire_ml's deep neural networks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* @author alcinos | |
* @date Thu Jun 25 11:02:48 2015 | |
* | |
* @brief Proposal for the API of a deep neural network library
* | |
* | |
*/ | |
class ComputationNode | |
{ | |
public: | |
ComputationNode(std::string name,NodeParam& params); | |
ComputationNode(const ComputationNode& other); | |
ComputationNode(ComputationNode&& other); | |
ComputationNode& operator=(const ComputationNode&); | |
virtual ~ComputationNode(); | |
virtual void computeForward(const std::vector<std::shared_ptr<array> >& input, | |
const std::vector<std::shared_ptr<array> >& output); | |
virtual void computeBackward(const std::vector<std::shared_ptr<array> >& input, | |
const std::vector<std::shared_ptr<array> >& output); | |
//the name of the node | |
std::string getName(); | |
//number of arguments of the node (ie size of the "input" vector) | |
virtual unsigned getArity() = 0; | |
}; | |
class Network | |
{ | |
public: | |
Network(const Network& other); | |
Network(Network&& other); | |
Network& operator=(const Network&); | |
void addNode(const std::shared_ptr<array>& node, bool inputNode = false); | |
/**create a directed link between two nodes. The parameter inputPos is used when the destination node | |
*has several inputs (arity > 1), to select which of its inputs we are feeding (the function computed by | |
*the node may not be symmetric with respect to its inputs) | |
*This function has to check that the parameter inputPos is consistent with the node's arity | |
*It must also check that this particuliar input is not fed by anyone else. | |
*/ | |
void addLink(const std::string& nameOrigin, const std::string& nameDest, unsigned inputPos = 0 ); | |
void addLink(const std::shared_ptr<ComputationNode>& origin, | |
const std::shared_ptr<ComputationNode>& dest); | |
//check that all nodes have all their inputs fed. (except input nodes) | |
bool check(); | |
std::shared_ptr<ComputationNode> getNodeByName(const std::string& name); | |
void computeForward(const std::vector<std::shared_ptr<array> >& input, | |
const std::vector<std::shared_ptr<array> >& output); | |
void computeBackward(const std::vector<std::shared_ptr<array> >& error); | |
protected: | |
std::vector<std::shared_ptr<ComputationNode> > m_nodes; | |
//the unsigned is just the position in the m_nodes vector | |
std::unordered_map<std::string,unsigned > m_nodesByName; | |
//ids of the input nodes | |
std::vector<unsigned> m_inputNodes; | |
std::vector<std::vector<unsigned> > m_adjencyList; | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment