// convnet_example.hpp
// GitHub gist by @zoq, created June 17, 2019 13:21
// (gist id 87070ff2a4bf769d2264527b2f67b035).
arma::cube imageA(28, 28, 3);
imageA.fill(1.0);
arma::cube imageB(28, 28, 3);
imageB.fill(0.5);
arma::mat imageData(28*28*3, 100);
imageData.fill(1.0);
arma::mat imageLabels = arma::zeros<arma::mat>(1, 100);
for (size_t i = 0; i < 100; i++)
{
if (i < 50)
{
imageLabels(i) = 1;
}
else
{
imageLabels(i) = 2;
}
}
for (size_t i = 0; i < 50; i++)
imageData.col(i) = arma::vectorise(imageB);
FFN<NegativeLogLikelihood<>, RandomInitialization> model;
model.Add<Convolution<> >(1 * 3, 8 * 3, 5, 5, 1, 1, 0, 0, 28, 28);
model.Add<ReLULayer<> >();
model.Add<MaxPooling<> >(8, 8, 2, 2);
model.Add<Convolution<> >(8 * 3, 12 * 3, 2, 2);
model.Add<ReLULayer<> >();
model.Add<MaxPooling<> >(2, 2, 2, 2);
model.Add<Linear<> >(192 * 3, 20);
model.Add<ReLULayer<> >();
model.Add<Linear<> >(20, 10);
model.Add<ReLULayer<> >();
model.Add<Linear<> >(10, 2);
model.Add<LogSoftMax<> >();
// Train for only 8 epochs.
ens::RMSProp opt(0.001, 1, 0.88, 1e-8, 8 * nPoints, -1);
double objVal = model.Train(imageData, imageLabels, opt);
// Test that objective value returned by FFN::Train() is finite.
BOOST_REQUIRE_EQUAL(std::isfinite(objVal), true);
arma::mat predictionTemp;
model.Predict(imageData, predictionTemp);
arma::mat prediction = arma::zeros<arma::mat>(1, predictionTemp.n_cols);
for (size_t i = 0; i < predictionTemp.n_cols; ++i)
{
prediction(i) = arma::as_scalar(arma::find(
arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1;
}
size_t correct = 0;
for (size_t i = 0; i < X.n_cols; i++)
{
if (prediction(i) == imageLabels(i))
correct++;
}
double classificationError = 1 - double(correct) / X.n_cols;
std::cout << classificationError << std::endl;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment