Last active
June 4, 2017 14:02
-
-
Save ytbilly3636/fcc44484f739d58b8861f694061e4560 to your computer and use it in GitHub Desktop.
caffe2用
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# 読み込むモジュール | |
import numpy as np | |
from caffe2.python import core, cnn, workspace | |
from matplotlib import pyplot | |
# Global Caffe2 initialization. The argument is an argv-style list:
# 'caffe2' sits in the program-name slot, followed by Caffe2 flags.
core.GlobalInit([
    'caffe2',
    '--caffe2_log_level=0',  # log level 0
])
# Path to a Caffe2 checkout. It is not referenced by the code shown here, and
# the '~' is a literal character (this string is never passed through
# os.path.expanduser).
caffe2_root = '~/caffe2'
def AddInput(model, batch_size, db, db_type):
    """Attach a database reader to *model* and return its input blobs.

    Args:
        model: model helper the operators are added to.
        batch_size: number of records read per iteration.
        db: name/path of the database to read from.
        db_type: database backend name (this script passes 'minidb').

    Returns:
        Tuple ``(data, label)`` of blob references. ``label`` is the raw
        label cast to INT32; ``data`` is wrapped in StopGradient.
    """
    # Read image tensors and raw labels from the DB into the blobs
    # 'data' and 'label_'.
    images, raw_label = model.TensorProtosDBInput(
        [], ['data', 'label_'],
        batch_size=batch_size,
        db=db,
        db_type=db_type,
    )

    # Labels are cast to int32 for the LabelCrossEntropy / Accuracy ops.
    int_label = model.Cast(raw_label, "label", to=core.DataType.INT32)

    # The input itself has no trainable parameters, so stop backprop here
    # (in place on the 'data' blob).
    images = model.StopGradient(images, images)

    return images, int_label
def AddNetwork(model, data, num_classes=10):
    """Build the CNN body on top of *data* and return the softmax blob.

    Architecture: three stages of 5x5 conv (stride 1, pad 2) + ReLU +
    2x2 max-pool (stride 2), then FC(1024 -> 100) + ReLU and
    FC(100 -> num_classes) + Softmax.

    Shapes in the comments are NxCxHxW and assume a bx3x32x32 input. The
    first conv has dim_in=3, and three stride-2 pools must leave 4x4 maps
    for the first FC's 64*4*4 input.

    Args:
        model: model helper the operators are added to.
        data: input blob reference.
        num_classes: number of output classes (width of the final FC layer).
            Defaults to 10, the value previously hard-coded.

    Returns:
        Blob reference 'softmax' holding the class probabilities.
    """
    # Stage 1: bx3x32x32 -conv1-> bx32x32x32 -pool1-> bx32x16x16
    conv1 = model.Conv(data, 'conv1', dim_in=3, dim_out=32, kernel=5, stride=1, pad=2)
    pool1 = model.MaxPool(conv1, 'pool1', kernel=2, stride=2)
    # ReLU comes after pooling here but before pooling in later stages. Both
    # orders give the same result because ReLU is monotonic.
    relu1 = model.Relu(pool1, 'relu1')

    # Stage 2: bx32x16x16 -conv2-> bx64x16x16 -pool2-> bx64x8x8
    conv2 = model.Conv(relu1, 'conv2', dim_in=32, dim_out=64, kernel=5, stride=1, pad=2)
    relu2 = model.Relu(conv2, 'relu2')
    pool2 = model.MaxPool(relu2, 'pool2', kernel=2, stride=2)

    # Stage 3: bx64x8x8 -conv3-> bx64x8x8 -pool3-> bx64x4x4
    conv3 = model.Conv(pool2, 'conv3', dim_in=64, dim_out=64, kernel=5, stride=1, pad=2)
    relu3 = model.Relu(conv3, 'relu3')
    pool3 = model.MaxPool(relu3, 'pool3', kernel=2, stride=2)

    # Classifier: the 64*4*4 = 1024 pooled features -> 100 hidden units
    # -> num_classes logits.
    fc3 = model.FC(pool3, 'fc3', 64 * 4 * 4, 100)
    fc3 = model.Relu(fc3, fc3)  # in-place ReLU on the 'fc3' blob
    pred = model.FC(fc3, 'pred', 100, num_classes)
    softmax = model.Softmax(pred, 'softmax')
    return softmax
def AddAccuracy(model, softmax, label):
    """Add an Accuracy operator scoring *softmax* against *label*.

    Args:
        model: model helper the operator is added to.
        softmax: blob with the predicted class probabilities.
        label: blob with the ground-truth labels.

    Returns:
        Blob reference named 'accuracy'.
    """
    return model.Accuracy([softmax, label], 'accuracy')
def AddTrainingOperators(model, softmax, label, base_lr=-0.075, stepsize=100, gamma=0.999):
    """Add loss, accuracy, gradient and SGD-update operators to *model*.

    Args:
        model: model helper the operators are added to.
        softmax: blob with the predicted class probabilities.
        label: blob with the int32 ground-truth labels.
        base_lr: base learning rate. It is negative on purpose: the update
            below computes ``param = param + LR * grad``, so a negative LR
            moves the parameters down the gradient (gradient descent).
        stepsize: iteration interval of the 'step' learning-rate policy.
        gamma: decay factor of the 'step' policy
            (lr = base_lr * gamma ** floor(iter / stepsize)).

    The defaults reproduce the previously hard-coded schedule.
    """
    # Cross-entropy per example, then averaged over the batch into 'loss'.
    xent = model.LabelCrossEntropy([softmax, label], 'xent')
    loss = model.AveragedLoss(xent, 'loss')

    # Also report accuracy (blob 'accuracy') during training.
    AddAccuracy(model, softmax, label)

    # Add the backward pass: gradients of 'loss' w.r.t. every parameter.
    model.AddGradientOperators([loss])

    # Learning-rate schedule driven by the iteration counter.
    iteration = model.Iter('iter')
    lr = model.LearningRate(
        iteration, 'LR',
        base_lr=base_lr, policy='step', stepsize=stepsize, gamma=gamma,
    )

    # Constant weight 1.0 for the WeightedSum update. It is created in
    # param_init_net, so it is filled once at initialization.
    one = model.param_init_net.ConstantFill([], 'ONE', shape=[1], value=1.0)

    # In-place SGD update for every parameter:
    #   param = 1.0 * param + LR * param_grad
    for param in model.params:
        param_grad = model.param_to_grad[param]
        model.WeightedSum([param, one, param_grad, lr], param)
# Build the training model (NCHW layout).
train_model = cnn.CNNModelHelper(order='NCHW', name='cifar_train')
# Input pipeline: batches of 50 from the minidb training set.
data, label = AddInput(train_model, batch_size=50, db='cifar_train.minidb', db_type='minidb')
# Network body.
softmax = AddNetwork(train_model, data)
# Loss, gradients and parameter updates.
AddTrainingOperators(train_model, softmax, label)
# Run parameter initialization once, then instantiate the training net.
workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net)

# Per-iteration history used for the live plot.
total_iters = 2500
accuracy = np.zeros(total_iters)
loss = np.zeros(total_iters)

# Training loop. `range` instead of the Python-2-only `xrange` lets the
# script run on Python 3 too. On Python 2 it builds a 2500-element list,
# which is negligible.
for i in range(total_iters):
    workspace.RunNet(train_model.net.Proto().name)

    # Record this iteration's metrics.
    accuracy[i] = workspace.FetchBlob('accuracy')
    loss[i] = workspace.FetchBlob('loss')

    # Redraw the graph. Only the iterations completed so far are plotted.
    # The unfilled tail of the preallocated arrays is zeros and would
    # otherwise show up as a bogus flat line at 0. The x-axis is pinned to
    # the full run length so the view does not rescale every iteration.
    done = i + 1
    pyplot.clf()
    pyplot.plot(accuracy[:done], 'r')
    pyplot.plot(loss[:done], 'b')
    pyplot.xlim(0, total_iters)
    pyplot.legend(('accuracy', 'loss'), loc='upper right')
    pyplot.pause(.01)

pyplot.savefig("graph.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment