require 'image'
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'recurrent'
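-- NOTE: RNN and nn.RecurrentContainer below come from the external
-- 'recurrent' Torch package required above; their constructor signatures
-- are assumed from how this script uses them.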
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Training a simple LSTM to add two elements of a random sequence')
cmd:text()
cmd:text('Options')
cmd:option('-batch_size',16,'number of sequences to train on in parallel')
cmd:option('-seq_length',20,'number of timesteps to unroll to')
cmd:option('-rnn_size',10,'size of LSTM internal state')
cmd:option('-num_layers',1,'number of layers of LSTM')
cmd:option('-max_epochs',1000,'number of full passes through the training data')
cmd:option('-savefile','model_autosave','filename to autosave the model to (a .t7 extension is appended)')
cmd:option('-save_every',100,'save every 100 steps, overwriting the existing file')
cmd:option('-print_every',100,'how many steps/minibatches between printing out the loss')
cmd:option('-seed',123,'torch manual random number generator seed')
cmd:option('-cuda',false,'use CUDA')
cmd:text()
-- parse input params
local opt = cmd:parse(arg)
if opt.cuda then
  require 'cutorch'
  require 'cunn'
end
-- preparation stuff:
torch.manualSeed(opt.seed)
-- define the model for ONE timestep; the recurrent container
-- takes care of unrolling it through time
local rnn = RNN(1, opt.rnn_size, opt.num_layers, 0.0)
local recurrent = nn.RecurrentContainer(rnn.rnnModule)
recurrent:setState(rnn.initState, opt.batch_size)
recurrent:single()
recurrent:training()
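-- single() makes the container consume one timestep per forward() call,
-- carrying hidden state between calls (assumed semantics of the
-- 'recurrent' package); training() keeps the history needed for BPTT.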
local model = nn.Sequential()
model:add(recurrent)
local proj = nn.Linear(opt.rnn_size, 1)
local criterion = nn.MSECriterion()
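-- The LSTM reads the whole sequence; only its final hidden state is
-- projected to a single scalar by proj and scored against the target
-- sum with mean squared error.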
if opt.cuda then
  model:cuda()
  proj:cuda()   -- proj is trained too, so it must live on the GPU as well
  criterion:cuda()
end
-- put the recurrent net AND the projection into one flattened parameters
-- tensor, so the optimizer updates both (proj is not inside model)
local container = nn.Container():add(model):add(proj)
local params, grad_params = container:getParameters()
collectgarbage()
function make_batch(batch_size, seq_length, ind)
  local xs = {}
  -- targets have shape (batch_size, 1) to match the projection output
  local y = torch.Tensor(batch_size, 1)
  for t=1,seq_length do
    local x = torch.rand(batch_size, 1)
    table.insert(xs, x)
  end
  for b=1,batch_size do
    y[b][1] = xs[ind[1]][b][1] + xs[ind[2]][b][1]
  end
  return xs, y
end
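-- Example (hypothetical call): make_batch(2, 5, {2, 4}) returns five 2x1
-- input tensors and a 2x1 target y with y[b][1] = x2[b] + x4[b], so the
-- network must remember the inputs seen at timesteps 2 and 4.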
--dofile('test.lua')
-- do fwd/bwd and return loss, grad_params
function feval(params_)
  if params_ ~= params then
    params:copy(params_)
  end
  grad_params:zero()
  ------------------ get minibatch -------------------
  local x, y = make_batch(opt.batch_size, opt.seq_length, {2, 7})
  if opt.cuda then
    -- x is a table of per-timestep tensors; y is a single target tensor
    for t=1,#x do
      x[t] = x[t]:cuda()
    end
    y = y:cuda()
  end
  ------------------- forward pass -------------------
  -- only the output at the last timestep is used for the prediction
  local last_output
  for t=1,opt.seq_length do
    last_output = model:forward(x[t])
  end
  local final_prediction = proj:forward(last_output)
  local loss = criterion:forward(final_prediction, y)
  local doutput = criterion:backward(final_prediction, y)
  local dproj = proj:backward(last_output, doutput)
  ------------------ backward pass -------------------
  -- only the final timestep feeds the loss, so earlier timesteps receive a
  -- zero external gradient; gradients through time flow inside the container
  local zero_grad = dproj:clone():zero()
  for t=opt.seq_length,1,-1 do
    model:backward(x[t], t == opt.seq_length and dproj or zero_grad)
  end
  -- clip gradients element-wise to stabilize training
  grad_params:clamp(-5, 5)
  return loss, grad_params
end
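-- feval has the (params) -> loss, gradients signature that optim expects,
-- so it can be passed straight to optim.adadelta below.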
paths.mkdir('log')  -- make sure the log directory exists
local accLogger = optim.Logger(paths.concat('log', 'accuracy.log'))
-- optimization stuff
local losses = {}
local optim_state = {learningRate = 1e-2}
local iterations = opt.max_epochs * opt.batch_size
for i = 1, iterations do
  -- adadelta returns the updated parameters and a table holding the loss
  local _, fs = optim.adadelta(feval, params, optim_state)
  losses[#losses + 1] = fs[1]
  if i % opt.save_every == 0 then
    torch.save(opt.savefile .. '.t7', {model = model, proj = proj})
  end
  if i % opt.print_every == 0 then
    print(string.format("iteration %4d, loss = %6.8f, gradnorm = %6.4e", i, losses[#losses], grad_params:norm()))
    --accLogger:add{['train loss'] = losses[#losses]}
    --accLogger:style{['train loss'] = '-'}
    --accLogger:plot()
  end
end
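-- A minimal evaluation sketch (not in the original gist): run the trained
-- model on one fresh batch and report its MSE. Uses the same hard-coded
-- target indices {2, 7} as feval above; hidden-state handling is assumed
-- to work as in the training loop.
model:evaluate()
local xs, y = make_batch(opt.batch_size, opt.seq_length, {2, 7})
if opt.cuda then
  for t = 1, #xs do xs[t] = xs[t]:cuda() end
  y = y:cuda()
end
local last
for t = 1, opt.seq_length do
  last = model:forward(xs[t])
end
local pred = proj:forward(last)
print(string.format('held-out MSE = %6.8f', criterion:forward(pred, y)))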