Skip to content

Instantly share code, notes, and snippets.

Created July 27, 2019 00:48
Show Gist options
  • Save john-/11c8a886aff761cfe2573a0421251d3c to your computer and use it in GitHub Desktop.
Save john-/11c8a886aff761cfe2573a0421251d3c to your computer and use it in GitHub Desktop.
#!/usr/bin/env perl
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::Gluon qw(gluon);
use AI::MXNet::AutoGrad qw(autograd);
use AI::MXNet::Gluon::NN qw(nn);
use AI::MXNet::Base;
use Getopt::Long qw(HelpMessage);
# Command-line options. Each option is bound to a file-level lexical that is
# declared in place together with its default value:
#   --lr            learning rate handed to the trainer
#   --log-interval  print training accuracy every N batches
#   --momentum      momentum handed to the trainer
#   --hybridize     non-zero: call hybridize() on the network
#   --cuda          non-zero: train on GPU 0 instead of the CPU
#   --load_params   non-zero: load 'mnist.params' into the network first
#   --batch-size    batch size for both data loaders
#   --epochs        number of training epochs
#   --help          show usage and exit with status 0
# Any parsing error shows usage and exits with status 1.
GetOptions(
    'lr=f'           => \(my $lr           = 0.1),
    'log-interval=i' => \(my $log_interval = 100),
    'momentum=f'     => \(my $momentum     = 0.9),
    'hybridize=i'    => \(my $hybridize    = 0  ),
    'cuda=i'         => \(my $cuda         = 0  ),
    'load_params=i'  => \(my $load_params  = 0  ),
    'batch-size=i'   => \(my $batch_size   = 100),
    'epochs=i'       => \(my $epochs       = 1  ),
    'help'           => sub { HelpMessage(0) },
) or HelpMessage(1);
# define network
# Build the model: a sequential stack of fully connected layers, added
# inside the network's own name scope.
my $net = nn->Sequential();
$net->name_scope(sub {
    $net->add(nn->Dense(128, activation => 'relu'));
    $net->add(nn->Dense(64,  activation => 'relu'));
    # NOTE(review): no explicit output layer is added here, so the network
    # emits 64 values per sample. Confirm that this matches the number of
    # classes in the training folders, or add a final Dense(<num_classes>).
});

# Optional switches driven by the command-line flags.
$net->hybridize() if $hybridize;
$net->load_parameters('mnist.params') if $load_params;
# data
# Per-sample transform passed to the datasets.
#
# Arguments:
#   $data  - the sample; it is reshaped with [-1], cast to 'float32' and
#            divided by 255
#   $label - the sample's label, passed through untouched
#
# Returns the list ($data, $label).
sub transformer {
    my ($data, $label) = @_;
    $data = $data->reshape([-1])->astype('float32') / 255;
    # Disabled by the author; kept for reference:
    #my $label = $label_tmp->at(0);
    return ($data, $label);
}
# Training batches: read from the 'train' image folder, shuffled, with an
# incomplete final batch discarded.
my $train_data = gluon->data->DataLoader(
    #gluon->data->vision->MNIST('./data', train=>1, transform => \&transformer),
    gluon->data->vision->ImageFolderDataset(
        root      => '~/cart/training/train',
        flag      => 0,
        transform => \&transformer,
    ),
    batch_size => $batch_size, shuffle => 1, last_batch => 'discard'
);

# Validation batches: read from the 'test' image folder in a fixed order.
my $val_data = gluon->data->DataLoader(
    #gluon->data->vision->MNIST('./data', train=>0, transform=> \&transformer),
    gluon->data->vision->ImageFolderDataset(
        root      => '~/cart/training/test',
        flag      => 0,
        transform => \&transformer,
    ),
    batch_size => $batch_size, shuffle => 0
);
# train
# Evaluate the file-level $net on every batch of $val_data.
#
# Arguments:
#   $ctx - device context each batch is moved to before the forward pass
#
# Returns whatever the accuracy metric's get() returns; train() unpacks it
# as a (name, value) pair.
sub test {
    my $ctx = shift;
    my $metric = mx->metric->Accuracy();

    while (defined(my $d = <$val_data>)) {
        my ($data, $label) = @$d;
        $data  = $data->as_in_context($ctx);
        $label = $label->as_in_context($ctx);
        my $output = $net->($data);
        $metric->update([$label], [$output]);
    }

    return $metric->get;
}
# Train the file-level $net on $train_data and report accuracy.
#
# Arguments:
#   $epochs - number of passes over $train_data
#   $ctx    - device context the parameters and every batch are placed on
#
# Reads the file-level $lr, $momentum and $log_interval. After each epoch it
# prints the running training accuracy and the validation accuracy obtained
# from test(). Callers use it in void context.
sub train {
    my ($epochs, $ctx) = @_;

    # Collect all parameters from net and its children, then initialize them.
    $net->initialize(mx->init->Xavier(magnitude => 2.24), ctx => $ctx);

    # Trainer is for updating parameters with gradient.
    my $trainer = gluon->Trainer(
        $net->collect_params(), 'sgd',
        { learning_rate => $lr, momentum => $momentum }
    );
    my $metric = mx->metric->Accuracy();
    my $loss   = gluon->loss->SoftmaxCrossEntropyLoss();

    for my $epoch (0 .. $epochs - 1) {
        # Reset the metric at the beginning of each epoch so the reported
        # training accuracy covers this epoch only.
        $metric->reset();

        enumerate(sub {
            my ($i, $d) = @_;
            my ($data, $label) = @$d;
            $data  = $data->as_in_context($ctx);
            $label = $label->as_in_context($ctx);

            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            my $output;
            autograd->record(sub {
                $output = $net->($data);
                my $L = $loss->($output, $label);
                $L->backward;
            });

            # Take a gradient step with batch_size equal to data.shape[0].
            $trainer->step($data->shape->[0]);

            # Update metric at last.
            $metric->update([$label], [$output]);

            # A zero interval would make the modulus below die, so treat it
            # as "no per-batch logging".
            if ($log_interval and $i > 0 and $i % $log_interval == 0) {
                my ($name, $acc) = $metric->get();
                print "[Epoch $epoch Batch $i] Training: $name=$acc\n";
            }
        }, \@{ $train_data });

        my ($name, $acc) = $metric->get();
        print "[Epoch $epoch] Training: $name=$acc\n";

        my ($val_name, $val_acc) = test($ctx);
        print "[Epoch $epoch] Validation: $val_name=$val_acc\n";
    }

    return;
}
# Entry point: train on GPU 0 when --cuda is non-zero, otherwise on the CPU.
my $device = $cuda ? mx->gpu(0) : mx->cpu;
train($epochs, $device);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment