Skip to content

Instantly share code, notes, and snippets.

@yko
Created April 17, 2016 15:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yko/286e478455268a7c63b4a3871fead277 to your computer and use it in GitHub Desktop.
Save yko/286e478455268a7c63b4a3871fead277 to your computer and use it in GitHub Desktop.
Backpropagation neural network trained on the MNIST database
use strict;
use warnings;
use lib 'lib/perl5';
use Fcntl qw(SEEK_SET SEEK_CUR SEEK_END);
use Getopt::Long;
# Command-line handling. `-d` turns on a dump of every misclassified test
# image. The script needs AI::NeuralNet::Simple to be loadable and exactly
# four existing files (train images, train labels, test images, test
# labels); otherwise it prints usage instructions and exits with status 1.
GetOptions( d => \my($DEBUG) );
my $have_nn_module = eval { require AI::NeuralNet::Simple; 1; };
my $nn_load_error  = $@;    # capture immediately; $@ is a clobberable global
if (!$have_nn_module || @ARGV != 4 || grep { !-f $_ } @ARGV) {
    # Tell the user why the module check failed instead of swallowing $@.
    warn "Cannot load AI::NeuralNet::Simple: $nn_load_error" unless $have_nn_module;
    print <<EOF;
Usage:
# Download MNIST database
wget http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
gzip -d {train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
# Install AI::NeuralNet::Simple
curl -L https://cpanmin.us | perl - -L . AI::NeuralNet::Simple
# Train a simple NN agains train-images-idx3-ubyte train-labels-idx1-ubyte
# and test it agains t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
perl $0 train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
# Read more about MNIST db at http://yann.lecun.com/exdb/mnist/
# Read more about AI::NeuralNet::Simple at https://metacpan.org/pod/AI::NeuralNet::Simple
EOF
    # Invalid invocation: signal failure to the calling shell.
    exit 1;
}
my ($pixels, $labels, $test_pixels, $test_labels) = @ARGV;
my $net;
{
    # Training phase: load the training set, encode each label as a
    # 10-element one-hot target vector and train the network on the
    # (image, target) pairs.
    my ($train_images, $train_answers, $width, $height) = load_data($pixels, $labels);
    print "Loaded train data (img[" . @$train_images . "]=$pixels, answers[" . @$train_answers . "]=$labels)...\n";

    # Flat list of alternating input / target array refs, the layout
    # handed to train_set() below.
    my @train_set;
    for my $idx (0 .. $#{$train_images}) {
        # Map an answer to a 10-element array of booleans
        my @output = map { $_ == $train_answers->[$idx] ? 1 : 0 } 0 .. 9;
        push @train_set, $train_images->[$idx], \@output;
    }

    print "Training...\n";
    # NN with input layer of number of pixels in an image,
    # 50 is a randomly picked number of neurons for a hidden layer,
    # 10 - number of possible outputs
    $net = AI::NeuralNet::Simple->new($width * $height, 50, 10);
    # 30 / 0.01 are presumably the iteration count and target MSE per the
    # AI::NeuralNet::Simple train_set() docs - confirm before tuning.
    $net->train_set(\@train_set, 30, 0.01);
    print "Done with training.\n";
};
{
    # Evaluation phase: classify every test image with the trained network
    # and report the share of correct answers. With -d, each misclassified
    # image is rendered as ASCII art ('#' for a set pixel).
    my ($test_images, $test_answers, $width, $height) = load_data($test_pixels, $test_labels);
    print "Loaded test data (img[". @$test_images . "]=$test_pixels, answers[" . @$test_answers . "]=$test_labels)...\n";

    # Start at zero so the summary never formats undef.
    my ($right_answers, $wrong_answers) = (0, 0);
    for my $idx (0 .. $#{$test_images}) {
        my $image        = $test_images->[$idx];
        my $right_answer = $test_answers->[$idx];
        my $net_answered = $net->winner($image);

        if ($net_answered == $right_answer) {
            $right_answers++;
            next;
        }
        $wrong_answers++;
        next unless $DEBUG;

        printf "Expected answer is %i, NN answered %i\n", $right_answer, $net_answered;
        # Assumes the flat pixel list is row-major with $width pixels per
        # row, so row $y starts at offset $y * $width.
        for my $y (0 .. $height - 1) {
            for my $x (0 .. $width - 1) {
                print( $image->[$y * $width + $x] ? "#" : " " );
            }
            print "\n";
        }
        print '=' x 30, "\n";
    }

    # Avoid dividing by zero when the test set is empty.
    my $total = $right_answers + $wrong_answers;
    my $ratio = $total ? 100 * $right_answers / $total : 0;
    printf "Correct answer ratio=%0.2f%% (%i right vs %i wrong)\n", $ratio, $right_answers, $wrong_answers;
}
# Load an MNIST image file (IDX3, unsigned bytes) and its label file (IDX1).
#
#   $images - path to the image file, e.g. train-images-idx3-ubyte
#   $labels - path to the matching label file, e.g. train-labels-idx1-ubyte
#
# Returns (\@images, \@labels, $width, $height). Each element of @images is
# a reference to a flat, row-major list of $width * $height pixels, binarised
# so that any non-zero byte becomes 1. $labels[$i] is the label byte for
# $images[$i], in file order. Dies if a file cannot be opened or read, has
# the wrong magic number or is truncated, or if the two files disagree on
# the number of items.
sub load_data {
    my ($images, $labels) = @_;
    open my $PIXELS, '<:raw', $images or die "Cannot open $images: $!";
    open my $LABELS, '<:raw', $labels or die "Cannot open $labels: $!";

    # IDX3 header: magic, item count, rows, columns (big-endian uint32).
    my ($img_magic, $img_count, $height, $width) =
        unpack 'N4', _read_exactly($PIXELS, 16, $images);
    # IDX1 header: magic, item count (big-endian uint32).
    my ($lbl_magic, $lbl_count) =
        unpack 'N2', _read_exactly($LABELS, 8, $labels);

    die "$images: not an IDX3 unsigned-byte file (magic $img_magic)\n"
        unless $img_magic == 0x00000803;
    die "$labels: not an IDX1 unsigned-byte file (magic $lbl_magic)\n"
        unless $lbl_magic == 0x00000801;
    die "$images has $img_count images but $labels has $lbl_count labels\n"
        unless $img_count == $lbl_count;

    print "Images are sized $width x $height\n";

    my $imglen = $width * $height;
    my @images;
    for (1 .. $img_count) {
        my @pixels = unpack 'C*', _read_exactly($PIXELS, $imglen, $images);
        # Pixel bytes are integers 0..255, so "< 0.9" maps only 0 to 0.
        push @images, [ map { $_ < 0.9 ? 0 : 1 } @pixels ];
    }
    my @labels = unpack 'C*', _read_exactly($LABELS, $img_count, $labels);

    close $PIXELS;
    close $LABELS;
    return \@images, \@labels, $width, $height;
}

# Read exactly $len bytes from $fh and return them; die (naming $name) on a
# read error or a short read.
sub _read_exactly {
    my ($fh, $len, $name) = @_;
    my $buf = '';
    my $got = read $fh, $buf, $len;
    die "Error reading $name: $!\n" unless defined $got;
    die "$name is truncated (wanted $len bytes, got $got)\n" unless $got == $len;
    return $buf;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment