@zoq
Created June 24, 2018 13:11
network.hpp
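#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>

using namespace mlpack;
using namespace mlpack::ann;

int main()
{
  // One sample (one column) with 20 rows, all set to 1.
  arma::mat input = arma::ones(20, 1);
  // One class label per sample. mlpack 3.x NegativeLogLikelihood expects
  // 1-based labels, so 1 is the first of the 5 output classes.
  arma::mat target = arma::ones(1, 1);
  // Receives the network's predictions.
  arma::mat output;

  // User sub-network: select the user part of the input, embed it into
  // 20 dimensions (vocabulary of 30 IDs), then project to 10 units.
  Sequential<>* userModel = new Sequential<>();
  userModel->Add<Subview<> >(1, 0);
  userModel->Add<Embedding<> >(30, 20);
  userModel->Add<Linear<> >(20, 10);

  // Item sub-network: same structure, applied to the item part of the input.
  Sequential<>* itemModel = new Sequential<>();
  itemModel->Add<Subview<> >(1, 10);
  itemModel->Add<Embedding<> >(30, 20);
  itemModel->Add<Linear<> >(20, 10);

  // Merge the user and item sub-networks: both receive the same input and
  // their 10-unit outputs are stacked into a single 20-unit vector.
  Concat<>* mergeModel = new Concat<>(true, true);
  mergeModel->Add(userModel);
  mergeModel->Add(itemModel);

  // Create the main network: merged features -> 5 scores -> log-probabilities.
  FFN<NegativeLogLikelihood<>, RandomInitialization> network;
  network.Add<IdentityLayer<> >();
  network.Add(mergeModel);
  network.Add<IdentityLayer<> >();
  network.Add<Linear<> >(20, 5);
  network.Add<SigmoidLayer<> >();
  network.Add<LogSoftMax<> >();

  // Train on the single sample, then predict on it.
  network.Train(input, target);
  network.Predict(input, output);

  return 0;
}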
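Each sub-network turns an ID into a dense vector and then projects it. The standalone sketch below mimics that pipeline in plain standard C++ rather than mlpack, only to make the shapes concrete: a lookup table with 30 rows of 20 values (matching `Embedding<>(30, 20)` under the assumption that the first argument is the vocabulary size) followed by a 20-to-10 dense layer. `EmbedId` and `LinearForward` are hypothetical helpers, not mlpack functions, and the 0-based IDs are an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;
using Vector = std::vector<double>;

// Embedding lookup: the table holds one row per ID (vocabulary-size rows,
// embedding-size columns); embedding an ID returns that row.
// IDs are 0-based here (assumption of this sketch).
Vector EmbedId(const Matrix& table, std::size_t id)
{
  assert(id < table.size());
  return table[id];
}

// Dense layer y = W x + b, with W of shape (outSize x inSize).
Vector LinearForward(const Matrix& w, const Vector& b, const Vector& x)
{
  assert(w.size() == b.size());
  Vector y(b);
  for (std::size_t i = 0; i < w.size(); ++i)
  {
    assert(w[i].size() == x.size());
    for (std::size_t j = 0; j < x.size(); ++j)
      y[i] += w[i][j] * x[j];
  }
  return y;
}
```

For example, with a 30 x 20 table filled with 0.5, a 10 x 20 all-ones weight matrix and a zero bias, `LinearForward(w, b, EmbedId(table, 3))` returns ten values of 10.0, the sum of twenty 0.5 entries.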
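The merge step explains why the first `Linear<>` layer of the main network takes 20 inputs: each sub-network ends in a 10-unit layer, and the two outputs are stacked. The plain C++ sketch below shows that stacking for a single sample. `Concatenate` is a hypothetical helper, not part of mlpack, and it assumes the branches are stacked in the order they were added (user first, then item).

```cpp
#include <cassert>
#include <vector>

using Vector = std::vector<double>;

// Stack two sub-network outputs into one vector, first branch first,
// as a concat layer does for a single sample.
Vector Concatenate(const Vector& user, const Vector& item)
{
  Vector merged;
  merged.reserve(user.size() + item.size());
  merged.insert(merged.end(), user.begin(), user.end());
  merged.insert(merged.end(), item.begin(), item.end());
  assert(merged.size() == user.size() + item.size());
  return merged;
}
```

Two 10-element inputs give a 20-element result whose first ten entries come from the user branch and last ten from the item branch.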
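The last layer and the loss work together: `LogSoftMax` turns the 5 scores into log-probabilities, and `NegativeLogLikelihood` takes the negative log-probability of the target class. A minimal standard C++ sketch of that computation follows. `LogSoftmaxVector` and `NegativeLogLikelihoodLoss` are hypothetical standalone functions, not mlpack's classes, and the label is 0-based here, whereas mlpack's 3.x loss layer used 1-based labels.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vector = std::vector<double>;

// Log-softmax: x_i - log(sum_j exp(x_j)), with the maximum subtracted
// before exponentiating for numerical stability.
Vector LogSoftmaxVector(const Vector& x)
{
  assert(!x.empty());
  const double maxVal = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x)
    sum += std::exp(v - maxVal);
  const double logSum = maxVal + std::log(sum);
  Vector out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] - logSum;
  return out;
}

// Negative log-likelihood of the true class given log-probabilities.
// The label is 0-based (assumption of this sketch).
double NegativeLogLikelihoodLoss(const Vector& logProbs, std::size_t label)
{
  assert(label < logProbs.size());
  return -logProbs[label];
}
```

With five equal scores every class gets probability 1/5, so the loss is log(5), roughly 1.609. In the network above a `SigmoidLayer` sits before `LogSoftMax`, so the scores entering the log-softmax lie in (0, 1); with 5 classes this caps any single class probability below e/(e + 4), about 0.40.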