Skip to content

Instantly share code, notes, and snippets.

Last active October 25, 2017 12:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adash333/fca0128a93ffd7d68617ab373da032ac to your computer and use it in GitHub Desktop.
Save adash333/fca0128a93ffd7d68617ab373da032ac to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# Adapted from the official Chainer MNIST example (the source link was lost from this copy).
from __future__ import print_function

import argparse
import os

try:
    # Select a non-interactive backend so PlotReport can write PNG files on
    # headless machines; plotting is optional, so a failure here is ignored.
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.datasets import tuple_dataset
from chainer.training import extensions
import cv2
import numpy as np
# Network definition
class MLP(chainer.Chain):
    """Multi-layer perceptron: two ReLU hidden layers, then a linear output.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units produced by the final layer.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # Passing None as the input size lets Chainer infer each layer's
            # input width from the first batch it sees.
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        """Return the raw (pre-activation) outputs of ``l3`` for batch *x*."""
        hidden = x
        # l1 and l2 are each followed by a ReLU; l3 is left linear.
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)
def _load_image_dataset(root='./', n_classes=2):
    """Load grayscale images from ``root/<label>/`` and split train/test.

    Class ``i`` is read from the sub-directory named ``str(i)`` for each
    ``i in range(n_classes)``.  Within each class, files are taken in sorted
    name order; entries with index below ``n - n // 5`` (where ``n`` is the
    number of directory entries) go to the training split and the rest to
    the test split.  Entries that OpenCV cannot decode are skipped, but they
    still count towards ``n``.

    Args:
        root: Directory holding one sub-directory per class label.
        n_classes: Number of class sub-directories (labels 0..n_classes-1).

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of lists.  The ``x``
        lists hold 28x28 grayscale image arrays; the ``y`` lists hold the
        integer labels.

    Raises:
        OSError: If a class sub-directory cannot be listed.
    """
    x_train, y_train, x_test, y_test = [], [], [], []
    for label in range(n_classes):
        class_dir = os.path.join(root, str(label))
        # Sort so the train/test split does not depend on the arbitrary
        # order in which os.listdir returns entries.
        names = sorted(os.listdir(class_dir))
        # Hold out the last fifth of each class's listing for testing.
        n_train = len(names) - len(names) // 5
        for index, name in enumerate(names):
            image = cv2.imread(os.path.join(class_dir, name), 0)  # grayscale
            if image is None:  # not an image OpenCV can read
                continue
            if image.shape != (28, 28):
                # The network input is flattened to 784 = 28 * 28 features.
                image = cv2.resize(image, (28, 28))
            if index < n_train:
                x_train.append(image)
                y_train.append(label)
            else:
                x_test.append(image)
                y_test.append(label)
    return x_train, y_train, x_test, y_test


def _to_features(images):
    """Flatten 28x28 images into a float32 ``(N, 784)`` matrix divided by 255."""
    flat = np.asarray(images, dtype=np.float32).reshape((len(images), 784))
    return flat / 255


def main():
    """Train an MLP classifier on a local image dataset with Chainer.

    Command-line options select the mini-batch size, number of epochs,
    snapshot frequency, GPU id, output directory, a snapshot to resume from,
    hidden-layer width and the dataset root.  Training logs, snapshots and
    plots go to ``--out``; the final model is written to ``my_mnist.model``
    in the current working directory.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--dataset', '-d', default='./',
                        help='Directory containing one sub-directory per '
                             'class label (0, 1, ...)')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    # NOTE(review): the output layer keeps 10 units although only labels 0
    # and 1 are loaded below; the extra outputs are simply never a target.
    # Changing it would alter the shape of the saved model, so confirm first.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Copy the model's parameters to the specified GPU.
        model.to_gpu(args.gpu)

    # Set up an optimizer and attach it to the model's parameters; the
    # updater below relies on optimizer.target being set.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Step 1: build the dataset from image files on disk.
    x_train, y_train, x_test, y_test = _load_image_dataset(args.dataset)
    if not x_train:
        # An empty training set would otherwise fail deep inside the updater.
        parser.error('no readable training images found under {!r}'
                     .format(args.dataset))

    train = tuple_dataset.TupleDataset(_to_features(x_train),
                                       np.asarray(y_train, dtype=np.int32))
    test = tuple_dataset.TupleDataset(_to_features(x_test),
                                      np.asarray(y_test, dtype=np.int32))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer.
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch.
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration.
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch.
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir (only if matplotlib is usable).
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout.
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout.
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training, then save the trained model.
    trainer.run()
    chainer.serializers.save_npz('my_mnist.model', model)
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment