Skip to content

Instantly share code, notes, and snippets.

@szagoruyko
Forked from anonymous/train.lua
Created March 12, 2016 00:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save szagoruyko/b9766469c2ede8799447 to your computer and use it in GitHub Desktop.
-- Dependencies: xlua for the progress bar, optim for SGD and the confusion
-- matrix, torch-autograd for automatic differentiation, and penlight's
-- tablex for deep-copying optimizer configs.
require 'xlua'
require 'optim' -- BUG FIX: was missing, yet train()/test() call optim.sgd / optim.ConfusionMatrix
local grad = require 'autograd'
local tablex = require 'pl.tablex'
grad.optimize(true)
-- Recursively convert a value (or every value inside a nested table) to
-- float, in place for tables. Non-table leaves are assumed to support a
-- :float() method (torch tensors do).
local function cast(x)
  if type(x) ~= 'table' then
    return x:float()
  end
  for key, value in pairs(x) do
    x[key] = cast(value)
  end
  return x
end
-- Mini-batch size and the parameter table. params.x holds the current
-- input batch (the loss function reads it when no explicit input is
-- given), so it lives alongside the layer parameters.
local bs = 8
local params = {
  x = torch.randn(bs, 28 * 28),
}
-- grad.nn.Linear returns (forward function, parameter table {weight, bias}).
-- BUG FIX: the original also hand-built params.L1/params.L2 weight tables,
-- which were dead code — immediately overwritten by these assignments.
local L1, L2
L1, params.L1 = grad.nn.Linear(28 * 28, 512)
L2, params.L2 = grad.nn.Linear(512, 10)
local nonlin = grad.nn.Tanh()
cast(params)
print(params)
-- Forward pass: Linear(784 -> 512) -> Tanh -> Linear(512 -> 10).
-- The input batch is taken from `inputs` when provided; when omitted (as
-- the existing call sites do via f(params)) it falls back to params.x.
-- BUG FIX: the original declared `inputs`/`targets` but ignored them
-- entirely; `targets` remains accepted only for signature compatibility.
local f = function(params, inputs, targets)
  local x = inputs or params.x
  local hidden = nonlin(L1(params.L1, x))
  return L2(params.L2, hidden)
end
-- Gradient (w.r.t. everything in params) of the cross-entropy loss of the
-- forward pass. BUG FIX: the original called f(params) and silently
-- dropped `inputs`/`targets`; they are now forwarded (harmless extras for
-- an f that ignores them).
local g = grad(function(params, inputs, targets)
  local y = f(params, inputs, targets)
  return grad.loss.crossEntropy(y, targets)
end)
-- Load the serialized MNIST bundle (expected to expose trainData/testData
-- with .data and .labels fields) and convert its tensors to float.
local provider = torch.load('../cifar2.torch/datasets/mnist.t7')
cast(provider)
print(provider)
-- Shared SGD hyper-parameters.
local opt = {
  learningRate = 1e-3,
  momentum = 0.9,
  weightDecay = 0.0005,
  dampening = 0,
}
-- One independent optimizer state per tensor (weight, bias) per layer.
-- optim.sgd mutates its state table in place, hence the deep copies.
local states = {}
for _, layer in ipairs({ 'L1', 'L2' }) do
  states[layer] = {
    tablex.deepcopy(opt),
    tablex.deepcopy(opt),
  }
end
-- One epoch of mini-batch SGD over the training set.
-- Returns the cross-entropy loss summed over all processed batches.
function train()
  local total_loss = 0
  local targets = cast(torch.Tensor(bs, 10))
  -- Shuffle, split into batches of bs, and drop the last (possibly
  -- incomplete) batch so every batch matches params.x's fixed size.
  local indices = torch.randperm(provider.trainData.data:size(1)):long():split(bs)
  indices[#indices] = nil
  for t, batch in ipairs(indices) do
    xlua.progress(t, #indices)
    local inputs = provider.trainData.data:index(1, batch)
    local target_idx = provider.trainData.labels:index(1, batch)
    -- One-hot encode the integer class labels into `targets`.
    targets:zero():scatter(2, target_idx:long():view(bs, 1), 1)
    params.x:copy(inputs)
    local grads, batch_loss = g(params, inputs, targets)
    -- BUG FIX: the original wrote `local grads, loss = g(...)`, shadowing
    -- the outer accumulator, so train() always returned 0.
    total_loss = total_loss + batch_loss
    -- Separate SGD step (and state) for each layer's weight [1] and bias [2].
    -- (Also renamed the inner loop vars: the original reused `v`, shadowing
    -- the batch-index tensor from the outer loop.)
    for layer, layer_states in pairs(states) do
      optim.sgd(function() return batch_loss, grads[layer][1] end, params[layer][1], layer_states[1])
      optim.sgd(function() return batch_loss, grads[layer][2] end, params[layer][2], layer_states[2])
    end
  end
  return total_loss
end
-- Evaluate on the test set. Prints the confusion matrix and returns the
-- overall accuracy in percent (the original computed but discarded it).
function test()
  local confusion = optim.ConfusionMatrix(10)
  local n = provider.testData.data:size(1)
  for i = 1, n, bs do
    -- Guard: params.x has exactly bs rows, so a trailing incomplete batch
    -- cannot be copied into it — skip it rather than erroring.
    if i + bs - 1 > n then break end
    local inputs = provider.testData.data:narrow(1, i, bs)
    local targets = provider.testData.labels:narrow(1, i, bs)
    params.x:copy(inputs)
    confusion:batchAdd(f(params), targets)
  end
  confusion:updateValids()
  local test_acc = confusion.totalValid * 100
  print(confusion)
  -- BUG FIX: removed the dead trailing `local outputs = f(params)` and
  -- return the accuracy instead of discarding it.
  return test_acc
end
-- Entry point: run a single training epoch, then evaluate on the test set.
do
  train()
  test()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment