Skip to content

Instantly share code, notes, and snippets.

@alcinos
Created June 25, 2015 18:52
Show Gist options
  • Save alcinos/3bedb2f7c4518fa93220 to your computer and use it in GitHub Desktop.
Save alcinos/3bedb2f7c4518fa93220 to your computer and use it in GitHub Desktop.
API proposal for arrayfire_ml's deep neural networks
/**
* @author alcinos
* @date Thu Jun 25 11:02:48 2015
*
* @brief Proposal for the API of a deep neural network library
*
*
*/
/**
 * Abstract base class for a single node (operation) of a computation graph.
 *
 * A node reads its operands through `input` and produces its results through
 * `output`. Concrete node types must implement getArity().
 */
class ComputationNode
{
public:
    /**
     * @param name   the name of the node (the value getName() reports).
     * @param params node configuration.
     *               NOTE(review): presumably only read — consider taking it by
     *               const reference so temporaries/const params can be passed.
     */
    ComputationNode(std::string name,NodeParam& params);
    ComputationNode(const ComputationNode& other);
    ComputationNode(ComputationNode&& other);
    ComputationNode& operator=(const ComputationNode&);
    // Rule of Five: because the copy constructor, copy assignment and
    // destructor are user-declared, no move assignment is generated
    // implicitly; without this declaration `a = std::move(b)` would silently
    // fall back to the copy assignment above.
    ComputationNode& operator=(ComputationNode&& other);
    // Virtual so that derived nodes are destroyed correctly when deleted
    // through a ComputationNode pointer.
    virtual ~ComputationNode();

    /**
     * Forward pass: compute the node's result(s) from its input(s).
     * @param input  the node's operands; getArity() entries, one per input slot.
     * @param output destination(s) for the result.
     *               NOTE(review): the vector is const, so results presumably
     *               get written into the pointed-to arrays — confirm.
     */
    virtual void computeForward(const std::vector<std::shared_ptr<array> >& input,
                                const std::vector<std::shared_ptr<array> >& output);

    /**
     * Backward pass of the node.
     * NOTE(review): the signature mirrors computeForward; how the incoming
     * error/gradient and the produced gradients map onto `input`/`output`
     * is not specified by this proposal — confirm.
     */
    virtual void computeBackward(const std::vector<std::shared_ptr<array> >& input,
                                 const std::vector<std::shared_ptr<array> >& output);

    // The name of the node.
    std::string getName();

    // Number of arguments of the node (i.e. the size of the "input" vector).
    virtual unsigned getArity() = 0;
};
/**
 * A directed graph of ComputationNode objects whose nodes can be looked up
 * by name and evaluated forward/backward as a whole.
 */
class Network
{
public:
    // The user-declared copy/move constructors below suppress the implicitly
    // declared default constructor; without this one a Network could never be
    // created in the first place. An empty graph is the natural default.
    Network() = default;
    Network(const Network& other);
    Network(Network&& other);
    Network& operator=(const Network&);
    // Rule of Five: the user-declared copy operations suppress the implicit
    // move assignment, so declare it explicitly rather than silently copying.
    Network& operator=(Network&& other);

    /**
     * Add a node to the network.
     * @param node      the node to add. (Previously declared as
     *                  std::shared_ptr<array>, which cannot be stored in
     *                  m_nodes and made it impossible to add a node.)
     * @param inputNode true if the node is fed from outside the network;
     *                  such nodes are exempt from the "all inputs fed" check.
     *                  NOTE(review): presumably its position is recorded in
     *                  m_inputNodes — confirm.
     */
    void addNode(const std::shared_ptr<ComputationNode>& node, bool inputNode = false);

    /** Create a directed link between two nodes. The parameter inputPos is used when the destination node
     * has several inputs (arity > 1), to select which of its inputs we are feeding (the function computed by
     * the node may not be symmetric with respect to its inputs).
     * This function has to check that the parameter inputPos is consistent with the node's arity.
     * It must also check that this particular input is not fed by anyone else.
     */
    void addLink(const std::string& nameOrigin, const std::string& nameDest, unsigned inputPos = 0 );

    /** Same as the name-based overload, but addressing the nodes directly.
     * inputPos selects which input of `dest` is fed; it defaults to 0, so
     * existing two-argument calls keep their meaning, and nodes with
     * arity > 1 can now be wired through this overload as well.
     */
    void addLink(const std::shared_ptr<ComputationNode>& origin,
                 const std::shared_ptr<ComputationNode>& dest,
                 unsigned inputPos = 0);

    // Check that all nodes have all their inputs fed (except input nodes).
    bool check();

    // Look up a node by its name.
    // NOTE(review): the result for an unknown name is not specified — confirm.
    std::shared_ptr<ComputationNode> getNodeByName(const std::string& name);

    // Forward pass through the whole network.
    // NOTE(review): `input` presumably holds one array per input node — confirm.
    void computeForward(const std::vector<std::shared_ptr<array> >& input,
                        const std::vector<std::shared_ptr<array> >& output);

    // Backward pass through the whole network.
    // NOTE(review): `error` is presumably the error at the network outputs — confirm.
    void computeBackward(const std::vector<std::shared_ptr<array> >& error);

protected:
    std::vector<std::shared_ptr<ComputationNode> > m_nodes;
    // Maps a node name to the node's position in the m_nodes vector.
    std::unordered_map<std::string,unsigned > m_nodesByName;
    // Positions (in m_nodes) of the input nodes.
    std::vector<unsigned> m_inputNodes;
    // Adjacency list, presumably indexed by position in m_nodes.
    // NOTE(review): the original "adjency" spelling is kept so derived classes
    // still compile. Only node ids are stored, not the destination input slot
    // (inputPos), so the "input not already fed" check required by addLink
    // cannot be answered from this structure alone — consider storing
    // (destination, inputPos) pairs.
    std::vector<std::vector<unsigned> > m_adjencyList;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment