Last active
January 11, 2017 09:07
-
-
Save kdplus/7c0b9451ad16e1cf80d2eff6f98241e3 to your computer and use it in GitHub Desktop.
Gets 94% accuracy on my test data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'torch' | |
require 'xlua' | |
require 'image' | |
require 'cunn' | |
require 'cudnn' | |
require 'nn' | |
require 'torch' | |
require 'optim' | |
require 'paths' | |
-- Training configuration ----------------------------------------------------
-- Every switch below is a 0/1 flag, tested with `== 1` / `== 0` further down.
isnewnet   = 0 -- 1: build a fresh network; 0: resume from training_model.t7
isvgg      = 0 -- 1: VGG16 backbone, 0: AlexNet (only read when isnewnet == 1)
showpred   = 1 -- 1: print every prediction/label pair while training
showlogger = 0 -- 1: plot accuracy/loss curves with optim.Logger
usetrainer = 0 -- 1: also run nn.StochasticGradient on each batch
------------------------------------------------------------------------------
-- Load data ------------------------------------------------------------------
print("Training start")
-- Channels x height x width that every image is resized to.
-- (A throw-away `torch.Tensor(1)` used to be assigned here first and was
-- immediately overwritten.)
loadSize = {3, 224, 224}
ori_imgs = torch.load('ori-imgs.t7')
-- imgs/labels reference the tensors stored inside ori_imgs, so in-place
-- writes through imgs[i] / labels[i] also change ori_imgs.
imgs = ori_imgs.data
labels = ori_imgs.label
-- testset is only used as a container: get_vali_acc() overwrites its
-- data and label fields with freshly generated validation samples.
testset = torch.load('cifar10-test.t7')
print("imgs load finish")
-- Image processing -----------------------------------------------------------
-- Load every *.jpg in the working directory into imgs (starting at index 1,
-- overwriting what ori-imgs.t7 held at those indices), give it label 1, and
-- write the result back to ori-imgs.t7.
-- NOTE(review): paths.files() does not promise any ordering, so which file
-- lands at which index may differ between machines. The index later decides
-- whether an image is used for training or validation.
do
  local count = 0
  for file in paths.files(".") do
    -- Match the literal ".jpg" extension; the old pattern "jpg$" also
    -- accepted names such as "notajpg".
    if file:find("%.jpg$") then
      count = count + 1
      -- image.load returns values in [0, 1] for non-byte tensor types. The
      -- rest of the pipeline appears to work in [0, 255]: gamma() divides by
      -- 255, and the normalization constants (~118 / ~62) are on that scale.
      -- Rescale here, and force 3 channels so the result fits imgs[count].
      local img = image.load(file, 3, 'double'):mul(255)
      imgs[count] = image.scale(img, loadSize[2], loadSize[3])
      labels[count] = 1
    end
  end
end
ori_imgs.data = imgs
ori_imgs.label = labels
torch.save("ori-imgs.t7", ori_imgs)
--- Gamma-correct image i of the global imgs tensor.
-- Computes 255 * (x / 255) ^ (1 / g), i.e. it treats pixel values as 0-255.
-- With g > 1 the exponent is < 1 and the image gets brighter; g < 1 darkens.
-- @param i index into imgs
-- @param g gamma value
-- @return a new DoubleTensor holding the corrected image
function gamma(i, g)
  local input = imgs[i]:double()
  -- `dst` used to be assigned without `local`, leaking a global.
  local dst = 255 * torch.pow(input / 255, 1 / g)
  return dst
end
-- Augmented image set: each of the 35 source images appears once per gamma
-- value, giving 35 * 3 = 105 samples (training and validation images alike).
-- Sample (src, k) is stored at index (src - 1) * #gamma_list + k.
gamma_list = {1, 1/1.4, 4.2}
do
  local numSrcImgs = 35
  local numGammas = #gamma_list
  trainimg = torch.Tensor(numSrcImgs * numGammas, 3, 224, 224)
  trainlabel = torch.Tensor(numSrcImgs * numGammas)
  for src = 1, numSrcImgs do
    local offset = (src - 1) * numGammas
    for k, gval in ipairs(gamma_list) do
      trainimg[offset + k] = gamma(src, gval)
      trainlabel[offset + k] = labels[src]
    end
  end
end
-- Normalize trainimg in place --------------------------------------------------
-- Per-channel mean/std are computed and printed as diagnostics only. Every
-- channel is actually normalized with the same fixed mean/std pair below,
-- presumably so that a resumed model (isnewnet == 0) keeps seeing inputs on
-- the scale it was trained with.
-- (Previously a commented-out variant sat here inside a nested --[[ ]] block.
-- Lua block comments do not nest, so that only parsed by accident.)
do
  local NORM_MEAN = 118.380948
  local NORM_STD = 61.896913
  mean = {}
  stdv = {}
  for i = 1, 3 do
    -- View into trainimg; add/div below modify trainimg itself.
    local channel = trainimg[{ {}, {i}, {}, {} }]
    mean[i] = channel:mean()
    print('Channel ' .. i .. ', Mean: ' .. mean[i])
    channel:add(-NORM_MEAN)
    stdv[i] = channel:std()
    print('Channel ' .. i .. ', Standard Deviation:' .. stdv[i])
    channel:div(NORM_STD)
  end
end
--itorch.image(trainimg[1])
--trainset = torch.Tensor(2160, 3, 224, 224)

--- Rotate normalized sample i of trainimg, then zoom into its centre.
-- The rotation in degrees is x * 18 + (delta - 1): angle step x of 18 degrees,
-- plus a -1 / 0 / +1 degree jitter for delta = 0 / 1 / 2.
-- The rotated image is cropped to the region (55, 55)-(169, 169) and scaled
-- back to 224x224, presumably to drop the empty corners left by the rotation.
-- @param i index into trainimg
-- @param x angle step
-- @param delta jitter selector
-- @return the rotated, cropped and rescaled image
function rotate(i, x, delta)
  local degrees = x * 18 + delta - 1
  local rotated = image.rotate(trainimg[i], degrees * math.pi / 180, 'bilinear')
  local centre = image.crop(rotated, 55, 55, 169, 169)
  return image.scale(centre, 224, 224)
end
--itorch.image(rotate(180 ,50))
-- Load Net ---------------------------------------------------------------------
-- Number of output classes. The labels produced for training and validation
-- lie in 1..5: 20 angle steps, x and x+180 folded together, and neighbouring
-- steps merged in pairs.
local NUM_CLASSES = 5
cudnn.fastest = true
--net = torch.load('best_model.t7')
if isnewnet == 1 then
  if isvgg == 1 then
    net = torch.load('vgg16.t7')
    -- Strip the pretrained classifier (the last 8 modules; the first two
    -- removed are the SoftMax and the final Linear) and rebuild it.
    -- NOTE(review): assumes vgg16.t7 ends with exactly these 8 classifier
    -- modules after a 25088-wide flatten — confirm against the model file.
    for _ = 1, 8 do
      net:remove()
    end
    net:add(nn.Linear(25088, 4096))
    net:add(nn.ReLU())
    net:add(nn.Dropout(0.5))
    net:add(nn.Linear(4096, 4096))
    net:add(nn.ReLU())
    net:add(nn.Dropout(0.5))
    -- Previously Linear(4096, 10) + SoftMax. That disagreed with the 5-class
    -- labels, and ClassNLLCriterion expects log-probabilities (LogSoftMax).
    net:add(nn.Linear(4096, NUM_CLASSES))
    net:add(cudnn.LogSoftMax())
  else
    net = torch.load'./alexnet_torch.t7':unpack()
    net:remove()
    -- New classifier head on top of the 256x6x6 convolutional features.
    local cnn1 = nn.Sequential()
    cnn1:add(nn.View(256*6*6))
    cnn1:add(nn.Dropout(0.5))
    cnn1:add(nn.Linear(256*6*6, 4096))
    cnn1:add(nn.Threshold(0, 1e-6))
    cnn1:add(nn.Dropout(0.5))
    cnn1:add(nn.Linear(4096, 4096))
    cnn1:add(nn.Threshold(0, 1e-6))
    cnn1:add(nn.Linear(4096, NUM_CLASSES))
    cnn1:add(cudnn.LogSoftMax())
    net:add(cnn1)
  end
else
  net = torch.load('training_model.t7')
  print("load training model")
end
print(net)
net = net:cuda()
print("load net finish")
-- Loss: ClassNLLCriterion expects log-probabilities (LogSoftMax output).
criterion = nn.ClassNLLCriterion():cuda()
-- No extra conversion is needed when usetrainer == 1. net was already moved
-- to the GPU with net:cuda(), and criterion is converted just above.
-- trainimg must stay a CPU tensor: rotate() passes it to image.rotate /
-- image.crop / image.scale, which work on CPU tensor types and not on
-- CudaTensor. Each batch is moved to the GPU with :cuda() right before
-- training instead. trainlabel is not read after it is built, so it is not
-- converted either.
-- Training / validation history ------------------------------------------------
-- Optional optim.Logger plots: one for accuracy, one for loss.
if showlogger == 1 then
  logger = optim.Logger('accuracy.log')
  logger:setNames{'Training acc.', 'vali acc.'}
  logger2 = optim.Logger('loss.log')
  logger2:setNames{'Training loss', 'vali loss'}
end
-- A fresh network starts with empty histories; a resumed one reloads them.
if isnewnet ~= 1 then
  trainAccLog = torch.load('trainAccLog.t7')
  valiAccLog = torch.load('valiAccLog.t7')
  trainLossLog = torch.load('trainLossLog.t7')
  valiLossLog = torch.load('valiLossLog.t7')
else
  trainAccLog, valiAccLog, trainLossLog, valiLossLog = {}, {}, {}, {}
end
-- Replay the stored history into the loggers, then draw both plots.
if showlogger == 1 then
  for epoch, acc in pairs(trainAccLog) do
    logger:add{acc, valiAccLog[epoch]}
    logger2:add{trainLossLog[epoch], valiLossLog[epoch]}
  end
  for _, lg in ipairs({logger, logger2}) do
    lg:style{'+-', '+-'}
    lg:plot()
  end
end
-- train Batch ------------------------------------------------------------------
-- Adam settings consumed by optim.adam in trainBatch().
-- optim.adam reads `beta1`/`beta2`. The previous keys were misspelled
-- `bate1`/`bate2` and silently ignored; adam's own defaults happen to be the
-- same 0.9 / 0.999, so training behaviour does not change.
optimState = {
  learningRate = 1e-5,
  learningRateDecay = 1e-4,
  beta1 = 0.9,
  beta2 = 0.999,
  epsilon = 1e-8,
  --weightDecay = 0.0005,
  --momentum = 0.9
}
--if isnewnet==0 then
--  optimState.learningRate = 1e-4
--end
-- Flattened parameter / gradient views of `net`. getParameters() re-flattens
-- the whole network each time it is called, so it is called once and cached.
-- The cache is refreshed only if the global `net` is replaced by another object.
local trainBatchNet, trainBatchParams, trainBatchGradParams

--- Run one Adam step on a mini-batch.
-- Reads the globals net, criterion, optimState and showpred.
-- @param batchdata input batch (already on the GPU)
-- @param batchlabel class labels, one per sample
-- @return the batch loss, and the number of samples whose arg-max prediction
--   equals their label
function trainBatch(batchdata, batchlabel)
  if trainBatchNet ~= net then
    trainBatchParams, trainBatchGradParams = net:getParameters()
    trainBatchNet = net
  end
  local params, gradParams = trainBatchParams, trainBatchGradParams
  local loss = 0
  local hit = 0
  -- Use the batch's own size rather than the global `batchsize`.
  local n = batchlabel:size(1)

  -- Closure evaluated by optim.adam: returns loss and d(loss)/d(params).
  -- Declared local; it used to leak as a global, as did indices and fs.
  local function feval(_)
    gradParams:zero()
    local outputs = net:forward(batchdata)
    loss = criterion:forward(outputs, batchlabel)
    -- for debug, show the prediction in training
    local _, indices = torch.max(outputs, 2)
    for i = 1, n do
      if indices[i][1] == batchlabel[i] then hit = hit + 1 end
      if showpred == 1 then
        print("pred: ", indices[i][1], " label: ", batchlabel[i])
      end
    end
    print("Hit: ", hit)
    local dloss_doutputs = criterion:backward(outputs, batchlabel)
    net:backward(batchdata, dloss_doutputs)
    return loss, gradParams
  end

  optim.adam(feval, params, optimState)
  return loss, hit
end
-- process index for trainer (if you use trainer to optim) ----------------------
--- Make the global `trainset` usable as an nn.StochasticGradient dataset.
-- After the call, trainset[i] returns {data[i], label[i]} and trainset:size()
-- returns the number of samples. If `trainset` does not exist yet, an empty
-- table is created; the main loop assigns data/label after calling this.
function trainsetIndex()
  trainset = trainset or {}
  setmetatable(trainset, {
    __index = function(t, i)
      -- rawget: a missing data/label field must not re-enter __index
      -- (reading t.data here used to recurse forever when data was unset).
      local data, label = rawget(t, 'data'), rawget(t, 'label')
      if data == nil or label == nil then
        return nil
      end
      return {data[i], label[i]}
    end
  })
  -- Convert the data from a ByteTensor to a DoubleTensor, if there is any yet.
  local data = rawget(trainset, 'data')
  if data ~= nil then
    trainset.data = data:double()
  end
  function trainset:size()
    return self.data:size(1)
  end
end
-- Validation setup --------------------------------------------------------------
-- Best validation accuracy so far, in percent: 0 for a fresh network,
-- otherwise restored from the previous run (0 is truthy, so and/or is safe).
best_vali_acc = (isnewnet == 1) and 0 or torch.load('best_vali_acc.t7')
print("load best acc:", best_vali_acc)
-- The validation set is evaluated in chunks of `valisize` samples because it
-- does not fit in memory at once. valiBase is the offset of the current chunk
-- into valiorder.
allValiSize = 900
valisize = 300
valiBase = 0
valiorder = torch.randperm(allValiSize)
--- Evaluate the network on one chunk of `valisize` validation samples.
-- The chunk uses positions valiBase+1 .. valiBase+valisize of valiorder.
-- Each entry is offset by 5400, so validation samples never overlap the
-- training samples 1..5400. Sample k maps to trainimg image floor((k-1)/60)+1
-- and angle level (k-1) % 60.
-- Side effects: replaces testset.data / testset.label, switches net to
-- evaluation mode for the measurement and back to training mode afterwards.
-- @param isforall 0 to draw a fresh random valiorder first; any other value
--   keeps the current order (used when sweeping all chunks)
-- @return accuracy in [0, 1], and the mean loss over the chunk
function get_vali_acc(isforall)
  local validata = torch.Tensor(valisize, 3, 224, 224)
  local valilabel = torch.Tensor(valisize)
  -- do not update the valiorder if it is test on all validata
  if isforall == 0 then
    valiorder = torch.randperm(allValiSize)
  end
  local base = 5400
  for i = 1, valisize do
    local sample = valiorder[i + valiBase] + base
    local index = math.floor((sample - 1) / 60) + 1
    local angleLevel = (sample - 1) % 60
    local angle = math.floor(angleLevel / 3)
    local angleDelta = angleLevel % 3
    validata[i] = rotate(index, angle, angleDelta)
    -- 20 angle steps: x and x+180 share a class (1..10), then neighbouring
    -- steps are merged in pairs (1..5).
    local label = angle + 1
    if label > 10 then
      label = label - 10
    end
    valilabel[i] = math.floor((label - 1) / 2) + 1
  end
  testset.data = validata
  testset.label = valilabel

  -- Dropout must be off while measuring; it was previously left active.
  net:evaluate()
  local correct = 0
  local valiLoss = 0
  -- One reusable 1x3x224x224 input buffer instead of a new tensor per sample.
  local single = torch.Tensor(1, 3, 224, 224)
  for i = 1, valisize do
    local groundtruth = testset.label[i]
    single[1] = testset.data[i]
    local prediction = net:forward(single:cuda())
    -- Local: this used to overwrite the global `loss`.
    valiLoss = valiLoss + criterion:forward(prediction, groundtruth)
    -- true means sort in descending order.
    -- NOTE(review): assumes prediction is 1-D (one score per class) so that
    -- indices[1] is a number — confirm for each network head.
    local _, indices = torch.sort(prediction, true)
    if groundtruth == indices[1] then
      correct = correct + 1
    end
  end
  net:training()

  valiLoss = valiLoss / valisize
  print("Hit",correct,"in",valisize, " ", 100*correct/valisize .. ' % ')
  return correct/valisize, valiLoss
end
-- manage the vali batch when test on all vali data
--- Evaluate on the whole validation set chunk by chunk and track the best model.
-- When the mean accuracy (as a percentage) beats best_vali_acc, it updates
-- best_vali_acc and best_model and saves both to best_model.t7 /
-- best_vali_acc.t7. Resets valiBase to 0 before returning.
-- @return mean accuracy in [0, 1], and the mean loss over the evaluated chunks
function get_all_vali_acc()
  -- Only whole chunks can be evaluated, so average over the chunks actually
  -- run. Dividing by the raw ratio skewed the mean when it was not an integer.
  local numChunks = math.floor(allValiSize / valisize)
  if numChunks < 1 then
    error("valisize (" .. valisize .. ") exceeds allValiSize (" .. allValiSize .. ")")
  end
  local acc, loss = 0, 0
  for chunk = 1, numChunks do
    -- update the new valiBase
    valiBase = (chunk - 1) * valisize
    local chunkAcc, chunkLoss = get_vali_acc(1)
    acc = acc + chunkAcc
    loss = loss + chunkLoss
  end
  acc = acc / numChunks
  loss = loss / numChunks
  -- update best model and best acc
  if acc * 100 > best_vali_acc then
    best_vali_acc = acc * 100
    print("It is best now!", best_vali_acc)
    best_model = net
    torch.save("best_model.t7", best_model)
    torch.save("best_vali_acc.t7", best_vali_acc)
  else
    print("not the best!")
  end
  print("best_vali_acc = ", best_vali_acc .. "%")
  valiBase = 0
  return acc, loss
end
-- use batchtraining ----------------------------------------------------------------
-- 5400 training samples = trainimg images 1..90 x 60 angle levels each.
batchsize = 10
trainsize = 5400
-- NOTE(review): a resumed run (isnewnet == 0) draws and saves a new sample
-- order, while a fresh network (isnewnet == 1) reloads batchorder.t7. That is
-- the reverse of how the logs and best accuracy are initialized, and a first
-- run with isnewnet == 1 needs batchorder.t7 to already exist. Confirm intent.
if isnewnet ~= 0 then
  batchorder = torch.load('batchorder.t7')
else
  batchorder = torch.randperm(trainsize)
  torch.save('batchorder.t7', batchorder)
end
-- CPU buffers reused for every mini-batch.
batchdata = torch.Tensor(batchsize, 3, 224, 224)
batchlabel = torch.Tensor(batchsize)
-- Learning-rate schedule state.
lastValiAcc = -1
decayCnt = 0
-- Main training loop -------------------------------------------------------------
-- How many times the learning rate may be divided by 5 when validation
-- accuracy stops improving. The old guard `decayCnt < 0` could never be true
-- (decayCnt starts at 0), so decay was disabled; 0 keeps that explicitly.
local maxLrDecays = 0
local numEpochs = 1000

-- Map an angle step (0..19) to a class in 1..5: x and x+180 share a class,
-- and neighbouring steps are merged in pairs.
local function angleToClass(angle)
  local label = angle + 1
  if label > 10 then
    label = label - 10
  end
  return math.floor((label - 1) / 2) + 1
end

for e = 1, numEpochs do
  local lossSum = 0
  local hit = 0
  local trainAcc, trainLoss
  for i = 1, (trainsize / batchsize) do
    print(" ")
    print("epoch: ", e, " batch: ", i, "------------------------------------")
    -- prepare the traindata for this batch
    for j = 1, batchsize do
      local sample = batchorder[(i - 1) * batchsize + j]
      -- each image index has 20 angle steps x 3 jitters = 60 levels
      local index = math.floor((sample - 1) / 60) + 1
      local angleLevel = (sample - 1) % 60
      local angle = math.floor(angleLevel / 3)
      local angleDelta = angleLevel % 3
      batchdata[j] = rotate(index, angle, angleDelta)
      batchlabel[j] = angleToClass(angle)
    end
    -- if use trainer optim
    if usetrainer == 1 then
      trainsetIndex()
      trainset.data = batchdata:cuda()
      trainset.label = batchlabel:cuda()
      local trainer = nn.StochasticGradient(net, criterion)
      trainer.learningRate = 1e-5
      trainer.maxIteration = 2
      local x = trainer:train(trainset)
      print(x)
    end
    local loss, hitOnce = trainBatch(batchdata:cuda(), batchlabel:cuda())
    lossSum = lossSum + loss
    hit = hit + hitOnce
    trainAcc = hit / i / batchsize
    trainLoss = lossSum / i
    print("lr: ", optimState.learningRate, "lrdecayCnt:", decayCnt)
    print("HitSum: ", hit, "Size:", i * batchsize)
    print ("train Loss: ", trainLoss, "train acc: ", trainAcc)
  end
  local valiAllAcc, valiAllLoss = get_all_vali_acc()
  print("valiAcc and valiLoss", valiAllAcc, valiAllLoss)
  if showlogger == 1 then
    -- These used to log the undefined globals valiLoss / valiAcc (always nil).
    logger2:add{trainLoss, valiAllLoss}
    logger2:style{'+-', '+-'}
    logger2:plot()
    logger:add{trainAcc, valiAllAcc}
    logger:style{'+-', '+-'}
    logger:plot()
  end
  table.insert(valiAccLog, valiAllAcc)
  table.insert(valiLossLog, valiAllLoss)
  table.insert(trainAccLog, trainAcc)
  table.insert(trainLossLog, trainLoss)
  torch.save("valiAccLog.t7", valiAccLog)
  torch.save("valiLossLog.t7", valiLossLog)
  torch.save("trainAccLog.t7", trainAccLog)
  torch.save("trainLossLog.t7", trainLossLog)
  torch.save("training_model.t7", net)
  -- update learning rate
  if decayCnt < maxLrDecays then
    print("Last Vali Acc is ", lastValiAcc, "Now Vali Acc is ", valiAllAcc)
    if valiAllAcc <= lastValiAcc then
      optimState.learningRate = optimState.learningRate / 5
      print("Change the lr to ", optimState.learningRate)
      decayCnt = decayCnt + 1
    else
      print("keep the lr")
    end
  end
  lastValiAcc = valiAllAcc
  print ("epoch:", e, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
  print(lossSum)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment