Created
April 17, 2016 15:28
-
-
Save yko/286e478455268a7c63b4a3871fead277 to your computer and use it in GitHub Desktop.
Backprop neural net trained on the MNIST database
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use strict; | |
use warnings; | |
use lib 'lib/perl5'; | |
use Fcntl qw(SEEK_SET SEEK_CUR SEEK_END); | |
use Getopt::Long; | |
# Parse the -d (debug) flag, then verify that everything needed to run is in
# place:
#   * the command-line options parsed cleanly,
#   * AI::NeuralNet::Simple can be loaded,
#   * exactly four existing files were given.
# If any check fails, print setup/usage instructions and exit with a
# non-zero status, so that shells and scripts can tell the run did not
# happen.
my $options_ok = GetOptions( d => \my($DEBUG) );
if (   !$options_ok
    || !eval { require AI::NeuralNet::Simple; 1; }
    || @ARGV != 4
    || grep { !-f $_ } @ARGV)
{
    print <<EOF;
Usage:
# Download MNIST database
wget http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
gzip -d {train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
# Install AI::NeuralNet::Simple
curl -L https://cpanmin.us | perl - -L . AI::NeuralNet::Simple
# Train a simple NN agains train-images-idx3-ubyte train-labels-idx1-ubyte
# and test it agains t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
perl $0 train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
# Read more about MNIST db at http://yann.lecun.com/exdb/mnist/
# Read more about AI::NeuralNet::Simple at https://metacpan.org/pod/AI::NeuralNet::Simple
EOF
    # A usage error is a failure, not a successful run.
    exit 1;
}
# Positional arguments: training images, training labels, test images and
# test labels (all uncompressed MNIST IDX files).
my ($pixels, $labels, $test_pixels, $test_labels) = @ARGV;
my $net;

# Build a training set from the training files and train the network on it.
{
    my ($train_images, $train_answers, $width, $height) = load_data($pixels, $labels);
    print "Loaded train data (img[" . @$train_images . "]=$pixels, answers[" . @$train_answers . "]=$labels)...\n";

    # train_set() receives a flat list that alternates an input array ref
    # (one image's pixels) with its expected-output array ref.
    my @train_set;
    for my $idx (0 .. $#$train_images) {
        # Map an answer to a 10-element array of booleans: 1 at the position
        # of the expected digit, 0 everywhere else.
        my @output = map { $_ == $train_answers->[$idx] ? 1 : 0 } 0 .. 9;
        push @train_set, $train_images->[$idx], \@output;
    }

    print "Training...\n";
    # NN with an input layer of one neuron per image pixel,
    # 50 neurons in the hidden layer (an arbitrarily picked size),
    # and 10 outputs, one per possible digit 0..9.
    $net = AI::NeuralNet::Simple->new($width * $height, 50, 10);
    # 30 and 0.01 are presumably the iteration count and the target mean
    # squared error -- TODO confirm against the AI::NeuralNet::Simple docs.
    $net->train_set(\@train_set, 30, 0.01);
    print "Done with training.\n";
}
# Run the trained network over the test set and report how often it picks
# the right digit. With -d, every misclassified image is drawn as ASCII art
# ('#' for a set pixel, ' ' for an empty one).
{
    my ($test_images, $test_answers, $width, $height) = load_data($test_pixels, $test_labels);
    print "Loaded test data (img[". @$test_images . "]=$test_pixels, answers[" . @$test_answers . "]=$test_labels)...\n";

    # Start both counters at zero so that a perfect (or fully wrong) run
    # does not print or add undefined values.
    my ($right_answers, $wrong_answers) = (0, 0);
    for my $idx (0 .. $#$test_images) {
        my $image        = $test_images->[$idx];
        my $right_answer = $test_answers->[$idx];
        # The network has one output per digit, so the index of the winning
        # output is compared directly against the expected digit.
        my $net_answered = $net->winner($image);

        if ($net_answered == $right_answer) {
            $right_answers++;
            next;
        }
        $wrong_answers++;
        next unless $DEBUG;

        printf "Expected answer is %i, NN answered %i\n", $right_answer, $net_answered;
        for my $y (0 .. $height - 1) {
            # Pixels are stored row by row, so each row is $width entries long.
            print join('', map { $image->[$y * $width + $_] ? "#" : " " } 0 .. $width - 1), "\n";
        }
        print '=' x 30, "\n";
    }

    my $total = $right_answers + $wrong_answers;
    die "No test images found in $test_pixels\n" unless $total;
    printf "Correct answer ratio=%0.2f%% (%i right vs %i wrong)\n", 100 * $right_answers / $total, $right_answers, $wrong_answers;
}
# Load an MNIST image/label file pair (IDX format) into memory.
#
# $images - path to an idx3-ubyte image file. It starts with a big-endian
#           header (magic 2051, image count, rows, columns), followed by one
#           unsigned byte per pixel, stored row by row.
# $labels - path to the matching idx1-ubyte label file. It starts with a
#           big-endian header (magic 2049, item count), followed by one
#           unsigned byte per label.
#
# Returns (\@images, \@labels, $width, $height). Each element of @images is
# an array ref of $width * $height pixels, binarized so that a zero byte
# becomes 0 and any non-zero byte becomes 1. $labels[$i] is the label of
# $images[$i]. Both lists are in file order.
#
# Dies if a file cannot be opened or is truncated, if a magic number is
# wrong, or if the two files disagree on the number of items.
sub load_data {
    my ($images, $labels) = @_;

    open my $PIXELS, '<:raw', $images or die "Cannot open $images: $!";
    open my $LABELS, '<:raw', $labels or die "Cannot open $labels: $!";

    # The image header stores the row count (height) before the column
    # count (width).
    my ($image_magic, $image_count, $height, $width) =
        unpack 'N4', _read_exactly($PIXELS, 16, $images);
    my ($label_magic, $label_count) =
        unpack 'N2', _read_exactly($LABELS, 8, $labels);

    die "$images is not an IDX image file (magic $image_magic)\n"
        unless $image_magic == 2051;
    die "$labels is not an IDX label file (magic $label_magic)\n"
        unless $label_magic == 2049;
    die "$images holds $image_count images but $labels holds $label_count labels\n"
        unless $image_count == $label_count;

    print "Images are sized $width x $height\n";

    my $imglen = $width * $height;
    my @images;
    # Read one image at a time so that only a single image's worth of
    # unpacked pixels is held in a temporary list.
    for (1 .. $image_count) {
        my $bytes = _read_exactly($PIXELS, $imglen, $images);
        # Binarize: any non-zero intensity byte counts as an "on" pixel.
        push @images, [ map { $_ ? 1 : 0 } unpack 'C*', $bytes ];
    }
    my @answers = unpack 'C*', _read_exactly($LABELS, $label_count, $labels);

    close $PIXELS;
    close $LABELS;

    return \@images, \@answers, $width, $height;
}

# Read exactly $length bytes from $fh and return them. Dies, naming $path,
# if the read fails or the file ends early.
sub _read_exactly {
    my ($fh, $length, $path) = @_;
    my $buffer = '';
    my $got = read $fh, $buffer, $length;
    die "Error reading $path: $!" unless defined $got;
    die "Unexpected end of file in $path (wanted $length bytes, got $got)\n"
        unless $got == $length;
    return $buffer;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment