Last active
October 25, 2017 12:09
-
-
Save adash333/fca0128a93ffd7d68617ab373da032ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# original code from http://www.mathgram.xyz/entry/chainer/bake/part3 | |
from __future__ import print_function | |
try: | |
import matplotlib | |
matplotlib.use('Agg') | |
except ImportError: | |
pass | |
import argparse | |
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
from chainer import training | |
from chainer.training import extensions | |
import os | |
import cv2 | |
import numpy as np | |
from chainer.datasets import tuple_dataset | |
# Network definition
class MLP(chainer.Chain):
    """Fully connected network: two ReLU hidden layers and a linear output.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units (one logit per class).

    Every layer is created with an input size of ``None`` so Chainer infers
    it from the first mini-batch it sees.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # The link names l1/l2/l3 are the keys under which parameters are
            # saved, so they must stay stable for models written by save_npz.
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)    # n_units -> n_out

    def __call__(self, x):
        """Return the raw (pre-softmax) output scores for the batch ``x``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)
def _load_dataset(root, n_classes=2, size=(28, 28)):
    """Load grayscale images from numbered class folders and split them 4:1.

    Folder ``root/<label>`` holds the images for integer ``label``; in this
    script ``0`` holds the negative examples and ``1`` the target class.
    Within each folder the file names are sorted (``os.listdir`` order is
    arbitrary) so the split is reproducible. Files OpenCV cannot decode are
    skipped, and the first four fifths of the remaining images go to the
    training set, the rest to the test set.

    Args:
        root: Directory containing the numbered class folders.
        n_classes: Number of class folders to read (labels ``0..n_classes-1``).
        size: ``(width, height)`` every image is resized to before flattening,
            so images of any size yield equal-length feature vectors.

    Returns:
        ``(X_train, y_train, X_test, y_test)``: the ``X`` arrays are float32 of
        shape ``(N, width * height)`` scaled to [0, 1]; the ``y`` arrays are
        int32 labels.

    Raises:
        OSError: if a class folder does not exist.
    """
    n_pixels = size[0] * size[1]
    x_train, y_train, x_test, y_test = [], [], [], []
    for label in range(n_classes):
        class_dir = os.path.join(root, str(label))
        images = []
        for name in sorted(os.listdir(class_dir)):
            img = cv2.imread(os.path.join(class_dir, name), cv2.IMREAD_GRAYSCALE)
            # imread returns None instead of raising when it cannot decode a
            # file (stray non-image files, sub-folders), so skip those.
            if img is None:
                continue
            images.append(cv2.resize(img, size))
        # Split the readable images 4:1; integer division keeps the
        # boundary an int on both Python 2 and 3.
        n_train = len(images) - len(images) // 5
        x_train.extend(images[:n_train])
        y_train.extend([label] * n_train)
        x_test.extend(images[n_train:])
        y_test.extend([label] * (len(images) - n_train))

    def _as_features(imgs):
        # Flatten each image to one row and scale 8-bit pixels to [0, 1].
        flat = np.asarray(imgs, dtype=np.float32).reshape((len(imgs), n_pixels))
        return flat / 255

    return (_as_features(x_train), np.asarray(y_train, dtype=np.int32),
            _as_features(x_test), np.asarray(y_test, dtype=np.int32))


def main():
    """Train the MLP on an image-folder dataset and save it to my_mnist.model.

    Adapted from Chainer's MNIST example: instead of MNIST, the data comes
    from the numbered class folders ``0/`` and ``1/`` under ``--dataset``
    (see ``_load_dataset``). Training logs, plots and snapshots go to
    ``--out``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--dataset', '-d', default='./',
                        help='Directory containing the class folders 0/ and 1/')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    # NOTE: the 10 output units come from the MNIST template; only labels 0
    # and 1 occur in this dataset, so 8 logits are never targets. The width
    # is kept so models saved by earlier runs keep the same parameter shapes.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the two-class image-folder dataset (replaces get_mnist()).
    X_train, y_train, X_test, y_test = _load_dataset(args.dataset)
    if len(X_train) == 0:
        # An empty training set would otherwise fail deep inside the iterator.
        parser.error('no readable training images under {!r}'.format(args.dataset))
    train = tuple_dataset.TupleDataset(X_train, y_train)
    test = tuple_dataset.TupleDataset(X_test, y_test)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
    chainer.serializers.save_npz('my_mnist.model', model)
if __name__ == '__main__':
    # Train only when executed as a script, not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment