Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Created February 16, 2016 03:02
Show Gist options
  • Save jcjohnson/7dd6bfc107b26207b1e6 to your computer and use it in GitHub Desktop.
Save jcjohnson/7dd6bfc107b26207b1e6 to your computer and use it in GitHub Desktop.
require 'torch'
require 'nn'
require 'cutorch'
require 'cunn'
require 'cudnn'
require 'optim'
require 'hdf5'
require 'image'
-- Command-line interface: each entry is {flag, default value}.
local cmd = torch.CmdLine()
local option_defaults = {
  {'-data_h5_file', 'cv/TinyImageNetA.h5'},   -- HDF5 file read by load_data
  {'-output_h5_file', 'cv/torch_model.h5'},   -- HDF5 export written by save_model
  {'-output_t7_file', 'cv/torch_model.t7'},   -- serialized model written by torch.save
}
for _, spec in ipairs(option_defaults) do
  cmd:option(spec[1], spec[2])
end
-- Parsed options, accessed as opt.data_h5_file etc.
local opt = cmd:parse(arg)
--- Load the training and validation splits from an HDF5 file.
-- Prints the file name, then reads the four datasets '/X_train', '/y_train',
-- '/X_val' and '/y_val' in full.
-- @param data_file path of the HDF5 file to open
-- @return table with fields X_train, y_train, X_val, y_val
local function load_data(data_file)
  print(data_file)
  local h5 = hdf5.open(data_file)

  -- Read an entire dataset stored at the given path inside the file.
  local function read_all(path)
    return h5:read(path):all()
  end

  -- Labels are shifted by one (presumably from 0-based labels on disk to the
  -- 1-based class indices Torch criteria expect -- confirm against the data).
  local X_train = read_all('/X_train')
  local y_train = read_all('/y_train') + 1
  local X_val = read_all('/X_val')
  local y_val = read_all('/y_val') + 1
  h5:close()

  return {
    X_train = X_train,
    y_train = y_train,
    X_val = X_val,
    y_val = y_val,
  }
end
--- Draw a random minibatch (sampling with replacement) from a dataset.
-- @param X data tensor; examples are indexed along dimension 1
-- @param y label tensor aligned with X along dimension 1
-- @param batch_size number of examples to draw
-- @return X_batch, y_batch gathered with the same random indices
local function get_minibatch(X, y, batch_size)
  local num_examples = X:size(1)
  -- batch_size indices drawn independently by random(num_examples).
  local sample_idx = torch.LongTensor(batch_size):random(num_examples)
  return X:index(1, sample_idx), y:index(1, sample_idx)
end
--- Random-crop a batch of images and, with probability 1/2, mirror the whole
-- batch horizontally.
-- The crop is (H - crop_size) x (W - crop_size) at a random offset; one offset
-- and one flip decision are shared by every image in the batch.
-- The result is a fresh copy: slicing a tensor yields a view that shares
-- storage with `batch`, so without the clone the in-place flip below would
-- overwrite the caller's data and the returned tensor would alias it.
-- @param batch 4-D tensor; dims 3 and 4 are treated as height and width
-- @param crop_size (optional) pixels removed from each spatial dim, default 16
-- @return cropped (and possibly flipped) copy of the batch
function random_crop_flip(batch, crop_size)
  crop_size = crop_size or 16
  local H, W = batch:size(3), batch:size(4)
  local h, w = H - crop_size, W - crop_size
  -- Offsets are drawn from 1 .. crop_size + 1 so the window always fits
  -- (the largest last index is crop_size + 1 + h - 1 == H).
  local x0 = torch.random(1 + crop_size)
  local y0 = torch.random(1 + crop_size)
  local cropped = batch[{{}, {}, {y0, y0 + h - 1}, {x0, x0 + w - 1}}]:clone()
  if torch.random(2) == 1 then
    for i = 1, cropped:size(1) do
      cropped[i] = image.hflip(cropped[i])
    end
  end
  return cropped
end
--- Return a copy of the batch in which each image is independently mirrored
-- horizontally with probability 1/2. The input batch is left untouched.
-- @param batch tensor whose first dimension indexes images
-- @return new tensor of the same size as batch
function random_flip(batch)
  local out = batch:clone()
  local num_images = batch:size(1)
  for idx = 1, num_images do
    -- One draw per image decides whether that image is flipped.
    local do_flip = (torch.random(2) == 1)
    if do_flip then
      out[idx] = image.hflip(out[idx])
    end
  end
  return out
end
--- Crop the central (H - crop_size) x (W - crop_size) window of every image.
-- @param batch 4-D tensor; dims 3 and 4 are treated as height and width
-- @param crop_size (optional) pixels removed from each spatial dim, default 16
-- @return a fresh copy of the central window (does not alias batch)
function center_crop(batch, crop_size)
  crop_size = crop_size or 16
  local H, W = batch:size(3), batch:size(4)
  local h, w = H - crop_size, W - crop_size
  assert(h > 0 and w > 0, 'crop_size must be smaller than the image size')
  -- Tensor ranges are 1-based and inclusive, so skipping floor(crop_size / 2)
  -- pixels means starting at index floor(crop_size / 2) + 1. For a 64-pixel
  -- side and crop_size 16 this selects 9..56 (8 pixels trimmed on each side);
  -- starting at crop_size / 2 itself would select 8..55, one pixel off-center.
  -- The floor also keeps the index integral for odd crop sizes.
  local start = math.floor(crop_size / 2) + 1
  local cropped = batch[{{}, {}, {start, start + h - 1}, {start, start + w - 1}}]
  return cropped:clone()
end
--- Estimate classification accuracy on randomly sampled minibatches.
-- Switches the model to evaluate mode and leaves it there; the caller is
-- responsible for calling model:training() again afterwards.
-- @param X data tensor passed to get_minibatch
-- @param y label tensor passed to get_minibatch
-- @param model network to evaluate (forward is called on CUDA inputs)
-- @param batch_size number of examples per sampled minibatch
-- @return fraction of sampled examples whose arg-max score equals the label
local function check_accuracy(X, y, model, batch_size)
  local num_batches = 20
  model:evaluate()

  local correct, total = 0, 0
  for batch_idx = 1, num_batches do
    local X_batch, y_batch = get_minibatch(X, y, batch_size)
    -- Center cropping is currently disabled:
    -- X_batch = center_crop(X_batch)
    local inputs = X_batch:cuda()
    local targets = y_batch:cuda()

    local scores = model:forward(inputs)
    -- Arg-max along dimension 2 (assumes scores is batch x classes).
    local _, predictions = scores:max(2)
    local hits = torch.eq(predictions, targets):sum()

    correct = correct + hits
    total = total + batch_size
  end

  return correct / total
end
--- Build the convolutional classifier as an nn.Sequential.
-- The network is a stack of [conv -> spatial batch norm -> ReLU -> dropout]
-- stages followed by a two-layer fully connected head.
-- @return nn.Sequential model (CPU; the caller moves it to the GPU)
local function build_model()
  --[[
  -- This is what the Python code expects now
  local num_filters = {64, 64, 128, 128, 256, 256, 512}
  local filter_sizes = {5, 3, 3, 3, 3, 3, 3}
  local filter_strides = {2, 1, 2, 1, 2, 1, 2}
  local num_classes = 100
  local hidden_dim = 1024
  local image_size = 64 - 16
  --]]

  -- One row per conv stage: {output channels, kernel size, stride, dropout p}.
  local conv_stages = {
    {64,   5, 2, 0.1},
    {64,   3, 1, 0.1},
    {128,  3, 2, 0.2},
    {128,  3, 1, 0.2},
    {256,  3, 2, 0.3},
    {256,  3, 1, 0.3},
    {512,  3, 2, 0.4},
    {512,  3, 1, 0.4},
    {1024, 3, 2, 0.5},
  }
  local num_classes = 100
  local hidden_dim = 512
  local image_size = 64

  local net = nn.Sequential()
  local channels = 3            -- channels entering the next conv stage
  local spatial = image_size    -- side length of the current feature map

  for _, stage in ipairs(conv_stages) do
    local out_channels, kernel, stride, drop_p =
      stage[1], stage[2], stage[3], stage[4]
    -- Padding of (kernel - 1) / 2 on each side.
    local pad = (kernel - 1) / 2
    net:add(nn.SpatialConvolution(channels, out_channels,
                                  kernel, kernel, stride, stride, pad, pad))
    net:add(nn.SpatialBatchNormalization(out_channels))
    net:add(nn.ReLU(true))
    net:add(nn.Dropout(drop_p))
    channels = out_channels
    -- Track the feature-map side: every stride-2 stage halves it.
    if stride == 2 then
      spatial = spatial / 2
    end
  end

  -- Fully connected head on the flattened feature map.
  local fan_in = spatial * spatial * channels
  net:add(nn.View(-1):setNumInputDims(3))
  net:add(nn.Linear(fan_in, hidden_dim))
  net:add(nn.BatchNormalization(hidden_dim))
  net:add(nn.Dropout(0.8))
  net:add(nn.ReLU(true))
  net:add(nn.Linear(hidden_dim, num_classes))
  return net
end
--- Export the learnable parameters and batch-norm statistics to an HDF5 file.
-- Convolution / linear layers are written as /W<k> and /b<k>; batch-norm
-- layers as /gamma<k>, /beta<k>, /running_mean<k> and /running_var<k>, where
-- each family is numbered independently in layer order starting at 1.
-- @param model container exposing size() and get(i) (e.g. nn.Sequential)
-- @param out_file path of the HDF5 file to create
local function save_model(model, out_file)
  local next_weight_idx = 1
  local next_bn_idx = 1
  local f = hdf5.open(out_file, 'w')

  -- Run all writes under pcall so the file handle is closed even if a write
  -- (or a missing field) raises.
  local ok, err = pcall(function()
    -- model:size() instead of #model: the length operator on a table only
    -- honours __len on Lua builds with 5.2 semantics.
    for i = 1, model:size() do
      local layer = model:get(i)
      if torch.isTypeOf(layer, nn.SpatialConvolution) or
         torch.isTypeOf(layer, nn.Linear) then
        f:write(string.format('/W%d', next_weight_idx), layer.weight:float())
        f:write(string.format('/b%d', next_weight_idx), layer.bias:float())
        next_weight_idx = next_weight_idx + 1
      elseif torch.isTypeOf(layer, nn.SpatialBatchNormalization) or
             torch.isTypeOf(layer, nn.BatchNormalization) then
        f:write(string.format('/gamma%d', next_bn_idx), layer.weight:float())
        f:write(string.format('/beta%d', next_bn_idx), layer.bias:float())
        f:write(string.format('/running_mean%d', next_bn_idx),
                layer.running_mean:float())

        -- Pick the statistic by the field that exists rather than by class:
        -- when SpatialBatchNormalization derives from BatchNormalization a
        -- class test on BatchNormalization also matches spatial layers, which
        -- carry running_var and no running_std.
        local running_var
        if layer.running_var then
          running_var = layer.running_var:float()
        else
          assert(layer.running_std,
                 'batch-norm layer has neither running_var nor running_std')
          -- Same conversion as before: running_std^-2 - eps.
          running_var =
            torch.pow(layer.running_std, -2.0):add(-layer.eps):float()
        end
        f:write(string.format('/running_var%d', next_bn_idx), running_var)
        next_bn_idx = next_bn_idx + 1
      end
    end
  end)

  f:close()
  if not ok then
    error(err, 0)
  end
end
-- Optimization hyperparameters.
local num_iterations = 120000              -- total number of SGD steps
local reg = 1e-3                           -- coefficient of the reg * params term added to the gradient in f
local batch_size = 50                      -- examples per minibatch
local config = { learningRate = 1e-1 }     -- optimizer state/config table passed to optim.sgd
local t = 0                                -- iteration counter, advanced by the training loop

-- Data and model; the model is printed before it is converted to cudnn.
local dset = load_data(opt.data_h5_file)
local model = build_model()
print(model)
cudnn.convert(model, cudnn)
model:cuda()
model:training()

-- Loss on the GPU, plus the flattened parameter / gradient tensors the
-- optimizer operates on.
local crit = nn.CrossEntropyCriterion():cuda()
local params, gradParams = model:getParameters()
--- Objective closure handed to the optimizer.
-- Samples a randomly flipped training minibatch, runs a forward/backward pass
-- and adds the reg * params term to the gradient. Every 100 iterations it
-- prints the iteration number, the data loss and the mean absolute gradient.
-- @param w parameter tensor supplied by the optimizer; must be `params`
-- @return data_loss (without the regularization term), gradParams
local function f(w)
  gradParams:zero()

  local X_cpu, y_cpu = get_minibatch(dset.X_train, dset.y_train, batch_size)
  -- Alternative augmentation, currently disabled:
  -- X_batch = random_crop_flip(X_batch):cuda()
  local inputs = random_flip(X_cpu):cuda()
  local targets = y_cpu:cuda()
  assert(w == params)

  local scores = model:forward(inputs)
  local data_loss = crit:forward(scores, targets)
  local grad_scores = crit:backward(scores, targets)
  model:backward(inputs, grad_scores)

  -- add regularization
  gradParams:add(reg, params)

  local should_log = (t % 100 == 0)
  if should_log then
    print(t, data_loss, torch.abs(gradParams):mean())
  end
  return data_loss, gradParams
end
-- Main training loop: plain SGD with periodic accuracy checks and a
-- step-wise learning-rate decay, followed by exporting the trained model.
local eval_interval = 200       -- iterations between accuracy reports
local decay_interval = 7500     -- iterations between learning-rate decays
local decay_factor = 1.5        -- divisor applied to the learning rate

while t < num_iterations do
  t = t + 1
  -- optim.adam(f, params, config)
  optim.sgd(f, params, config)

  -- Check training and validation accuracy once in a while
  if t % eval_interval == 0 then
    local train_acc =
      check_accuracy(dset.X_train, dset.y_train, model, batch_size)
    local val_acc =
      check_accuracy(dset.X_val, dset.y_val, model, batch_size)
    print('train acc: ', train_acc, 'val_acc: ', val_acc)
    -- check_accuracy leaves the model in evaluate mode; switch back.
    model:training()
  end

  if t % decay_interval == 0 then
    config.learningRate = config.learningRate / decay_factor
  end

  --[[
  -- This schedule works well for adam, starting from 1e-3
  if t % 4000 == 0 then
  config.learningRate = config.learningRate / 2.0
  end
  --]]
end

-- Persist the result twice: HDF5 parameter export and native Torch format.
save_model(model, opt.output_h5_file)
torch.save(opt.output_t7_file, model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment