Created
December 7, 2015 10:58
-
-
Save etrulls/06bcb54582c9fa4e8c31 to your computer and use it in GitHub Desktop.
26killer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'lfs' | |
require 'xlua' | |
-- Colour output. `colors` (ansicolors) renders '%{...}' markup tags into ANSI
-- escape sequences; every COL_* global below is such a tag, meant to be
-- prepended to a message before the message is passed through colors().
colors = require 'ansicolors'

-- Base palette. For each hue this defines two globals:
--   COL_LIGHT<HUE> = '%{bright <hue>}'   and   COL_DARK<HUE> = '%{reset <hue>}'
-- i.e. COL_LIGHTRED, COL_DARKRED, COL_LIGHTGREEN, COL_DARKGREEN,
-- COL_LIGHTYELLOW, COL_DARKYELLOW, COL_LIGHTBLUE, COL_DARKBLUE,
-- COL_LIGHTMAGENTA, COL_DARKMAGENTA, COL_LIGHTCYAN and COL_DARKCYAN.
for _, hue in ipairs{ 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan' } do
  local suffix = hue:upper()
  _G['COL_LIGHT' .. suffix] = '%{bright ' .. hue .. '}'
  _G['COL_DARK' .. suffix] = '%{reset ' .. hue .. '}'
end
COL_RESET = '%{reset white}'

-- Semantic aliases used by the logging helpers and the training scripts.
COL_ERROR   = COL_DARKRED
COL_WARNING = COL_LIGHTRED
COL_EPOCH   = COL_DARKMAGENTA
COL_HEADER  = COL_LIGHTCYAN
COL_RESULT  = COL_DARKGREEN
COL_TIMING  = COL_DARKYELLOW
COL_SETTING = COL_DARKBLUE
COL_DEBUG   = COL_LIGHTMAGENTA
-- convenience functions

--- Print one line built with string.format.
-- @tparam string msg format string (string.format syntax)
-- @param ... values substituted into `msg`
function printf(msg, ...)
  local line = string.format(msg, ...)
  print(line)
end
--- Print one formatted line, prefixed with an ansicolors colour tag.
-- @tparam string clr colour tag (e.g. one of the COL_* globals)
-- @tparam string msg format string (string.format syntax)
-- @param ... values substituted into `msg`
function printc(clr, msg, ...)
  local body = string.format(msg, ...)
  print(colors(clr .. body))
end
--- Print `msg` in the error colour and terminate the process via os.exit(-1).
-- The message is handed to printc as an argument rather than as its format
-- string, so messages containing '%' (e.g. "50% done") print verbatim
-- instead of raising a string.format error on the way out.
-- @param msg message to print (converted with tostring)
function quit(msg)
  printc( COL_ERROR, '%s', tostring(msg) )
  os.exit(-1)
end
--- Print a table's key/value pairs, one per line, wrapped in braces.
-- Keys and values are converted with tostring(): under Lua 5.1 (and older
-- LuaJIT) string.format's %s rejects booleans and tables, which made this
-- helper crash on e.g. boolean options or nested tables.
-- Iteration follows pairs(), so the line order is unspecified.
-- @tparam[opt] table t table to print; nil prints just the braces
-- @tparam[opt] string n name printed before the opening brace
function print_table(t, n)
  if n == nil then
    print('{')
  else
    print(n .. ' {')
  end
  if t ~= nil then
    -- render the colour markup once; the result still holds two %s slots
    local line_fmt = colors(' ' .. '%s:' .. COL_DARKCYAN .. ' %s')
    for k, v in pairs(t) do
      print(string.format(line_fmt, tostring(k), tostring(v)))
    end
  end
  print('}')
end
--- Tell whether `name` can be opened for reading.
-- @tparam string name path to test
-- @treturn boolean true if io.open(name, 'r') succeeds
function file_exists(name)
  local handle = io.open(name, 'r')
  if handle == nil then
    return false
  end
  handle:close()
  return true
end
--- Tell whether `name` is a directory, according to LuaFileSystem.
-- A non-table result from lfs.attributes (e.g. for a missing path) counts
-- as "not a folder".
-- @tparam string name path to test
-- @treturn boolean
function is_folder(name)
  local attr = lfs.attributes(name)
  return type(attr) == "table" and attr.mode == "directory"
end
--- Split `str` into the non-empty runs of characters not matched by `sep`.
-- `sep` is spliced into a Lua pattern character class ('[^' .. sep .. ']'),
-- so it may be a class escape such as "%s" (the default: whitespace) or a
-- list of literal characters such as ",;". Characters that are magic inside
-- a class (e.g. "]" or "%") must be escaped by the caller.
-- Consecutive separators yield no empty fields.
-- (The previous version used a global loop counter `i`, clobbering any
-- global of that name; the counter is gone now.)
-- @tparam string str
-- @tparam[opt="%s"] string sep
-- @treturn table sequence of substrings
function explode(str, sep)
  if sep == nil then
    sep = "%s"
  end
  local t = {}
  for s in string.gmatch(str, '([^' .. sep .. ']+)') do
    t[#t + 1] = s
  end
  return t
end
-- hdf5 stuff

--- Read one or more datasets from an HDF5 file.
-- @param file path of the HDF5 file
-- @param vars dataset name, or table of dataset names (without leading '/')
-- @return the datasets' contents (each as returned by `:all()`), in the
--   order given by `vars`
-- @raise if `vars` is missing or contains a non-string name; any error raised
--   while reading is re-raised after the file has been closed
function hdf5_load( file, vars )
  if vars == nil then
    error('Call: hdf5_load( file, {"var1", "var2"} )')
  end
  -- accept a single name as well as a list of names
  if torch.type(vars) ~= 'table' then
    vars = {vars}
  end
  require 'hdf5'
  local h5 = hdf5.open(file, 'r')
  local output = {}
  -- read inside pcall so the handle is closed even when a read fails
  local ok, err = pcall(function()
    for i = 1, #vars do
      if torch.type( vars[i] ) ~= 'string' then
        error('hdf5_load: must pass variable names as strings')
      end
      output[i] = h5:read( '/' .. vars[i] ):all()
    end
  end)
  h5:close()
  if not ok then
    error(err, 0)
  end
  -- `unpack` moved to table.unpack in Lua 5.2+
  return (table.unpack or unpack)( output )
end
--- Write every key/value pair of `data` as dataset '/<key>' of an HDF5 file.
-- The file is opened in 'w' mode, so existing contents are presumably
-- replaced; a warning is printed when the file already exists.
-- @param file path of the HDF5 file
-- @tparam table data map from dataset name to value
-- @raise if `data` is missing; any error raised while writing is re-raised
--   after the file has been closed
function hdf5_save( file, data )
  if data == nil then
    error('Call: hdf5_save( file, {variables=data} )')
  end
  require 'hdf5'
  if file_exists( file ) then
    printc( COL_WARNING, 'Warning: File "%s" exists', file )
  end
  local h5 = hdf5.open(file, 'w')
  -- write inside pcall so the handle is closed even when a write fails
  local ok, err = pcall(function()
    for k, v in pairs(data) do
      h5:write( '/' .. k, v )
    end
  end)
  h5:close()
  if not ok then
    error(err, 0)
  end
end
--- Throttled wrapper around xlua.progress.
-- Redraws the bar only on steps that are multiples of `skip`, and always on
-- the final step.
-- @tparam number curr current step
-- @tparam number total number of steps
-- @tparam[opt=1] number skip redraw interval. nil or a non-positive value
--   means "every step"; previously nil raised an arithmetic error and 0
--   produced a nan/zero modulus.
function progress( curr, total, skip )
  if skip == nil or skip <= 0 then
    skip = 1
  end
  if curr % skip == 0 or curr == total then
    xlua.progress( curr, total )
  end
end
-- temporary fix
--- Rotate a square 3-D image by an exact multiple of 90 degrees.
-- Pixels are moved by index arithmetic only (no interpolation), one
-- (channel-vector) pixel at a time. The output starts as a clone of `im`, so
-- it has the same tensor type and shape; angle 0 returns that clone unchanged.
-- @param im 3-D tensor whose last two dimensions are equal
-- @param angle rotation in degrees: 0, 90, 180 or 270
-- @return new rotated tensor
-- @raise if `im` is not 3-D, not square, or `angle` is not supported
function hard_rotate( im, angle )
  if im:nDimension() ~= 3 then
    error( 'Image must have three dimensions' )
  end
  if im:size(2) ~= im:size(3) then
    error( 'Image must be square' )
  end
  local n = im:size(2)
  local rotated = im:clone()
  if angle == 0 then
    return rotated
  end
  if angle ~= 90 and angle ~= 180 and angle ~= 270 then
    error( 'Valid rotation values: {0, 90, 180, 270}' )
  end
  for r = 1, n do
    for c = 1, n do
      if angle == 90 then
        rotated[{ {}, n - c + 1, r }] = im[{ {}, r, c }]
      elseif angle == 180 then
        rotated[{ {}, r, c }] = im[{ {}, n - r + 1, n - c + 1 }]
      else -- 270
        rotated[{ {}, r, c }] = im[{ {}, n - c + 1, r }]
      end
    end
  end
  return rotated
end
-- replicate function used while rotation was bugged
--- Rotate a square 3-D image by an arbitrary angle (bilinear interpolation)
-- and crop the central patch described by `conf`.
-- Interpolation happens in double precision; the result is converted back to
-- the input's tensor type. Previously the ByteTensor branch rounded and
-- clamped but never converted, so byte inputs came back as DoubleTensors.
-- Relies on the global `image` package being loaded by the caller.
-- @param conf table with aug_patch (presumably the augmented patch size
--   {C, H, W}) and extra_h / extra_w (margins cropped away) — confirm
-- @param im 3-D tensor whose last two dimensions are equal
--   (Float, Double or Byte)
-- @param angle rotation in degrees
-- @return new cropped tensor of the same type as `im`
function soft_rotate( conf, im, angle )
  if im:nDimension() ~= 3 then
    error( 'Image must have three dimensions' )
  end
  if im:size(2) ~= im:size(3) then
    error( 'Image must be square' )
  end
  local out = image.rotate( im:double():clone(), angle * math.pi / 180, 'bilinear' )
  -- convert back to the input type
  local im_type = im:type()
  if im_type == 'torch.FloatTensor' then
    out = out:float()
  elseif im_type == 'torch.ByteTensor' then
    out = out:round():clamp(0, 255):byte()
  elseif im_type ~= 'torch.DoubleTensor' then
    error( 'Unsupported tensor type' )
  end
  -- crop; NOTE(review): the +1 bias shifts the crop window by one pixel,
  -- kept from the original implementation — confirm it is intended
  local bias = 1
  return out[{ {1, conf.aug_patch[1]},
               {1 + conf.extra_h + bias, conf.aug_patch[2] - conf.extra_h + bias},
               {1 + conf.extra_w + bias, conf.aug_patch[3] - conf.extra_w + bias} }]:clone()
end
--- Rotate a 3-D image by `angle` degrees with image.warp (bicubic sampling).
-- Builds a (2, h, w) double field of sampling coordinates:
--   * an identity part mapping the normalised grid [-1, 1] to [0, h] / [0, w]
--   * plus (c - R*c), where c are pixel coordinates centred on ((h-1)/2,
--     (w-1)/2) with their sign flipped and R = rmat(angle).
-- NOTE(review): the identity part spans [0, h] rather than [0, h-1]; confirm
-- this matches image.warp's coordinate convention.
-- @param im 3-D tensor (C, h, w)
-- @param angle rotation in degrees
-- @return DoubleTensor produced by image.warp
function rotate_warp( im, angle )
  assert( im:nDimension() == 3 )
  local dims = im:size()
  local h, w = dims[2], dims[3]
  -- (h, w) grids of normalised coordinates in [-1, 1]
  local ys = torch.ger( torch.linspace(-1, 1, h), torch.ones(w) )
  local xs = torch.ger( torch.ones(h), torch.linspace(-1, 1, w) )
  -- identity sampling positions: [-1, 1] -> [0, 1] -> [0, h] / [0, w]
  local field = torch.DoubleTensor(2, h, w)
  field[1]:copy(ys):add(1):mul(0.5):mul(h)
  field[2]:copy(xs):add(1):mul(0.5):mul(w)
  -- centred pixel coordinates (sign flipped) and their rotated counterpart
  local centred = torch.DoubleTensor(2, h, w)
  centred[1] = ys * ((h - 1) / 2) * -1
  centred[2] = xs * ((w - 1) / 2) * -1
  local rotated = torch.mm( rmat( angle ), centred:reshape(2, h * w) )
  field:add( centred - rotated:reshape(2, h, w) )
  -- modes: simple, bilinear, bicubic, lanczos; the final `false` is
  -- presumably image.warp's offset_mode flag (field = absolute positions)
  return image.warp( im:double(), field, 'bicubic', false )
  --return image.warp(im:double(), field, 'lanczos', false)
end
--- 2x2 rotation matrix {{cos, -sin}, {sin, cos}} for an angle in degrees.
-- @tparam number deg angle in degrees
-- @return torch.DoubleTensor of size 2x2
function rmat( deg )
  local theta = deg / 180 * math.pi
  local c, s = math.cos(theta), math.sin(theta)
  return torch.DoubleTensor{ { c, -s }, { s, c } }
end
-- image padding (from koraykv/fex)

--- Narrow every dimension d > `dim` of `x` to the interior window
-- [pad[d]+1, pad[d]+sz[d]]; dimensions 1..dim are left untouched.
-- Pass dim = -1 (or 0) to narrow all dimensions.
-- @param x tensor to view
-- @param sz interior sizes, indexable per dimension
-- @param pad padding amounts, indexable per dimension
-- @tparam number dim last dimension to leave at full extent
-- @return view of `x` (no copy)
function dimnarrow(x,sz,pad,dim)
  local view = x
  for d = math.max(dim + 1, 1), x:dim() do
    view = view:narrow(d, pad[d] + 1, sz[d])
  end
  return view
end
--- Return a new tensor of `x`'s type, grown by 2*pad[d] along every
-- dimension d, zero-filled, with `x` copied into the centre.
-- @param x tensor to pad
-- @param pad padding per side, indexable 1..x:dim()
-- @return new padded tensor
function padzero(x,pad)
  local grown = x:size()
  for d = 1, x:dim() do
    grown[d] = grown[d] + 2 * pad[d]
  end
  local out = x.new(grown):zero()
  -- interior view of `out` (all dimensions narrowed) receives the data
  dimnarrow(out, x:size(), pad, -1):copy(x)
  return out
end
--- Pad `x` by mirroring its border, edge slice included (so the outermost
-- interior slice appears twice), adding pad[d] slices on each side of every
-- dimension d.
-- Works on a zero-padded copy. For each dimension d in turn, it fills the
-- d-th borders from the interior, through a view spanning the full (already
-- filled) extent of dimensions <= d and the interior of dimensions > d, so
-- corners are covered as well.
-- NOTE(review): assumes pad[d] <= x:size(d); larger pads would read slices
-- outside the interior — confirm callers respect this.
-- @param x tensor to pad
-- @param pad padding per side, indexable 1..x:dim()
-- @return new padded tensor
function padmirror(x,pad)
  local padded = padzero(x, pad)
  local full = padded:size()
  for d = 1, x:dim() do
    local region = dimnarrow(padded, x:size(), pad, d)
    local p, last = pad[d], full[d]
    for j = 1, p do
      -- leading border: slice j <- slice 2p-j+1
      region:select(d, j):copy(region:select(d, 2 * p - j + 1))
      -- trailing border: slice last-j+1 <- slice last-2p+j
      region:select(d, last - j + 1):copy(region:select(d, last - 2 * p + j))
    end
  end
  return padded
end
--- Print a confusion matrix, followed by per-row and overall accuracies.
-- Cells are right-aligned to the width of the largest entry. The first row
-- opens with '[' and continuation rows are indented by the same width, so
-- columns line up. (Previously rows 2+ were shifted one column left of row
-- 1, and the indent grew with the column count.) The default title no
-- longer ends in a double colon.
-- Per-row accuracy is diagonal / row total, which is nan for an empty row.
-- @param mat 2-D matrix (indexable as mat[i][j]; needs max, size, sum)
-- @tparam[opt] string title heading, printed verbatim followed by ':'
-- @tparam[opt=COL_RESET] string color ansicolors tag for all lines
function print_confmat( mat, title, color )
  local digits = string.format( '%d', mat:max() ):len()
  if not title then
    title = string.format( 'Confusion matrix (%d)', digits )
  end
  if not color then
    color = COL_RESET
  end
  printc( color, '%s:', title )
  local nrows, ncols = mat:size(1), mat:size(2)
  local cell_fmt = ' %' .. digits .. 'd'
  for i = 1, nrows do
    local cells = { i == 1 and ' [' or '  ' }
    for j = 1, ncols do
      cells[#cells + 1] = string.format( cell_fmt, mat[i][j] )
    end
    if i == nrows then
      cells[#cells + 1] = ' ]'
    end
    printc( color, '%s', table.concat( cells ) )
  end
  local correct, rate_sum = 0, 0
  for i = 1, nrows do
    local curr = mat[i][i] / mat[i]:sum()
    rate_sum = rate_sum + curr
    printc( color, ' Avg. accuracy, class "%d": %.03f', i, curr )
    correct = correct + mat[i][i]
  end
  printc( color, ' Avg. accuracy: %.03f', correct / mat:sum() )
  printc( color, ' Avg. class accuracy: %.03f', rate_sum / nrows )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'cudnn' | |
--- cudnn.SpatialConvolutionUpsample: a convolution whose extra output
-- channels are reinterpreted as a larger spatial resolution.
--
-- The parent cudnn.SpatialConvolution is built with nOutputPlane*factor^2
-- output planes, stride 1 and padding (k-1)/2, so for odd kernels it keeps
-- the input's h x w. updateOutput then *views* that result as nOutputPlane
-- planes of size (h*factor) x (w*factor). The backward methods apply the
-- inverse view to gradOutput before delegating to the parent.
--
-- NOTE(review): `view` only reinterprets memory. Each upsampled plane is
-- factor^2 consecutive low-resolution planes laid out row-major, not a
-- spatially interleaved "pixel shuffle". Output pixel (r, c) therefore does
-- not sit above input pixel (r/factor, c/factor). Confirm this is intended.
local SpatialConvolutionUpsample, parent = torch.class('cudnn.SpatialConvolutionUpsample','cudnn.SpatialConvolution')

-- Reshape a gradient w.r.t. the upsampled output into the parent's output
-- layout: N x planes x h x w for 4-D input, planes x h x w otherwise.
-- `view` needs contiguous memory, so a non-contiguous gradient is cloned
-- first. (accUpdateGradParameters used to skip this and could fail on
-- non-contiguous gradients.) Uses self.h / self.w recorded by updateOutput.
local function to_parent_layout(self, input, gradOutput)
  if not gradOutput:isContiguous() then
    gradOutput = gradOutput:clone()
  end
  local planes = self.nOutputPlaneU * self.factor * self.factor
  if input:dim() == 4 then
    return gradOutput:view(input:size(1), planes, self.h, self.w)
  end
  return gradOutput:view(planes, self.h, self.w)
end

--- @param nInputPlane number of input planes
-- @param nOutputPlane number of planes after upsampling
-- @param kW odd kernel width
-- @param kH odd kernel height
-- @param[opt=2] factor integer upsampling factor
-- @param[opt] groups forwarded to cudnn.SpatialConvolution
function SpatialConvolutionUpsample:__init(nInputPlane, nOutputPlane, kW, kH, factor, groups)
  factor = factor or 2
  assert(kW and kH and nInputPlane and nOutputPlane)
  assert(kW % 2 == 1, 'kW has to be odd')
  assert(kH % 2 == 1, 'kH has to be odd')
  self.factor = factor
  self.kW = kW
  self.kH = kH
  self.nInputPlaneU = nInputPlane
  self.nOutputPlaneU = nOutputPlane
  parent.__init(self, nInputPlane, nOutputPlane * factor * factor, kW, kH, 1, 1, (kW-1)/2, (kH-1)/2, groups)
end

--- Convolve, then view the result at factor x the input resolution.
-- Records the input's h and w for the backward passes.
function SpatialConvolutionUpsample:updateOutput(input)
  self.output = parent.updateOutput(self, input)
  if input:dim() == 4 then
    self.h = input:size(3)
    self.w = input:size(4)
    self.output = self.output:view(input:size(1), self.nOutputPlaneU, self.h*self.factor, self.w*self.factor)
  else
    self.h = input:size(2)
    self.w = input:size(3)
    self.output = self.output:view(self.nOutputPlaneU, self.h*self.factor, self.w*self.factor)
  end
  return self.output
end

function SpatialConvolutionUpsample:updateGradInput(input, gradOutput)
  self.gradInput = parent.updateGradInput(self, input, to_parent_layout(self, input, gradOutput))
  return self.gradInput
end

function SpatialConvolutionUpsample:accGradParameters(input, gradOutput, scale)
  parent.accGradParameters(self, input, to_parent_layout(self, input, gradOutput), scale)
end

function SpatialConvolutionUpsample:accUpdateGradParameters(input, gradOutput, scale)
  parent.accUpdateGradParameters(self, input, to_parent_layout(self, input, gradOutput), scale)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'paths' | |
require 'optim' | |
require 'lfs' | |
require 'xlua' | |
require 'math' | |
require 'cutorch' | |
require 'cunn' | |
require 'cudnn' | |
require 'image' | |
-- local dependencies | |
dofile( './optimizer.lua' ) | |
dofile( './common.lua' ) | |
dofile( './deconvolution.lua' ) | |
--- Per-row error rates of a 2x2 confusion matrix.
-- Row 1 is treated as the negative class and row 2 as the positive class.
-- Each rate is that row's off-diagonal entry divided by the row total.
-- @param mat 2x2 matrix (indexable as mat[row][col], with numel())
-- @return mean of the two rates, positive-row rate, negative-row rate
function error_rates_from_confmat( mat )
  assert( mat:numel() == 4 )
  local neg_row, pos_row = mat[1], mat[2]
  local r_neg = neg_row[2] / (neg_row[1] + neg_row[2])
  local r_pos = pos_row[1] / (pos_row[1] + pos_row[2])
  return (r_pos + r_neg) / 2, r_pos, r_neg
end
torch.setnumthreads(1)
cutorch.setDevice(1)
torch.setdefaulttensortype('torch.FloatTensor')
dofile( 'model.lua' )
local input_size = get_input_size()

-- load data
-- NOTE(review): placeholder data — a random 1x100x1000x1000 volume with
-- labels drawn from {1, 2}; replace with the real training stack.
local data = {}
data.classes = {'a', 'b'}
data.stack_train = torch.rand(1,100,1000,1000)
data.labels_train = torch.Tensor(1,100,1000,1000):bernoulli():add(1)

-- class weights: each class is weighted by the frequency of the other one
local num_pos = data.labels_train:eq( 2 ):sum()
local num_neg = data.labels_train:lt( 2 ):sum()
local num_pix = data.labels_train:numel()
assert( num_pos + num_neg == num_pix )
local w_pos = num_neg / num_pix
local w_neg = num_pos / num_pix
printc( COL_SETTING, 'Samples: %d positives, %d negatives. Weights: [%.04f,%.04f]', num_pos, num_neg, w_pos, w_neg )
collectgarbage()
torch.manualSeed( 666 )
-- border (per side) by which the labels are cropped to match the prediction;
-- presumably the model's output is 8 pixels narrower than its input — confirm
local buffer = 8 / 2

-- TRAINING ONLY
-- instantiate the model
print( 'Loading model definition' )
local model = build_model():cuda()
collectgarbage()

-- loss function (class-weighted NLL), moved to the GPU
local loss_criterion = nn.ClassNLLCriterion( torch.Tensor{w_neg, w_pos} ):cuda()
local weights, gradients = model:getParameters()
collectgarbage()

-- default optimizer parameters
local optimizer_state = get_initial_optimizer_state()
-- instantiate the optimizer
local optimizer = Optimizer{
  ModeCPU = false,
  Model = model,
  Loss = loss_criterion,
  OptFunction = _G.optim['sgd'],
  OptState = optimizer_state,
  Parameters = {weights, gradients}
}

-- feedback: losses of the last `num_archived` epochs (-1 = not yet available)
local num_archived = 8
local training_loss = torch.Tensor(num_archived):fill(-1)

-- sampled windows: start positions along slices (s), rows (y) and columns (x).
-- With stride 1e6 there is a single y/x start; the loop adds a random offset.
local stride = 1e6
local array_s = torch.range( 1, data.stack_train:size(2)-input_size[1]+1 )
local array_y = torch.range( 1, data.stack_train:size(3)-input_size[2]+1, stride )
local array_x = torch.range( 1, data.stack_train:size(4)-input_size[3]+1, stride )
local ns, ny, nx = array_s:numel(), array_y:numel(), array_x:numel()
local num_windows = ns * ny * nx
-- Cartesian product of the start positions (s slowest, x fastest): window n
-- is (array_s[n], array_y[n], array_x[n]) and every combination occurs once.
-- (Tiling each array independently only paired them up correctly when
-- ny == nx == 1.)
array_s = array_s:view(ns, 1, 1):expand(ns, ny, nx):contiguous():view(num_windows)
array_y = array_y:view(1, ny, 1):expand(ns, ny, nx):contiguous():view(num_windows)
array_x = array_x:view(1, 1, nx):expand(ns, ny, nx):contiguous():view(num_windows)
printc( COL_DARKMAGENTA, 'Iterating over %d windows', num_windows )
collectgarbage()
-- loop | |
--local npos, nneg = 0,0 | |
-- Reverse tensor `t` along dimension `dim` (returns a new tensor).
local function flip_dim(t, dim)
  return t:index(dim, torch.range(t:size(dim), 1, -1):long())
end

-- Rotate a (C, S, S) tensor by `k` quarter turns (k in 0..3) using exact
-- index permutations. An interpolating rotation can introduce values outside
-- the input (e.g. a 0 fill at the borders), which would be invalid
-- ClassNLLCriterion targets when applied to the label map.
local function rot90(t, k)
  if k == 1 then
    return flip_dim(t:transpose(2, 3), 2)
  elseif k == 2 then
    return flip_dim(flip_dim(t, 2), 3)
  elseif k == 3 then
    return flip_dim(t:transpose(2, 3), 3)
  end
  return t
end

local confusion_acc = optim.ConfusionMatrix( data.classes )
local confusion_batch = optim.ConfusionMatrix( data.classes )
for epoch=1,1e6 do
  print('')
  printc( COL_LIGHTMAGENTA, '{{{{{{ EPOCH %d }}}}}}', epoch )
  printc( COL_HEADER, '--==[[ Training ]]==--' )
  local order = torch.randperm( num_windows )
  model:training()
  confusion_acc:zero()
  local t_epoch = 0
  local acc_loss = 0
  for i=1,num_windows do
    sys.tic()
    confusion_batch:zero()
    --progress( i, num_windows, 1 )
    -- window start, jittered by up to stride-1 pixels while staying inside the volume
    local k = array_s[ order[i] ]
    local y = array_y[ order[i] ] + torch.random( 0, torch.Tensor{ stride-1, data.stack_train:size(3) - array_y[ order[i] ] - input_size[2] + 1 }:min() )
    local x = array_x[ order[i] ] + torch.random( 0, torch.Tensor{ stride-1, data.stack_train:size(4) - array_x[ order[i] ] - input_size[3] + 1 }:min() )
    local slice = data.stack_train[{ 1,{k,k+input_size[1]-1},{y,y+input_size[2]-1},{x,x+input_size[3]-1} }]
    -- labels cropped by `buffer` per side, mapped to class ids {1 = neg, 2 = pos}
    local gt = data.labels_train[{ 1,{k,k+input_size[1]-1},{y+buffer,y+input_size[2]-1-buffer},{x+buffer,x+input_size[3]-1-buffer} }]:eq(2):float():add(1)
    -- data augmentation: the same random quarter turn and flip for input and labels
    local quarter_turns = torch.random(0, 3)
    if quarter_turns > 0 then
      slice = rot90( slice, quarter_turns )
      gt = rot90( gt, quarter_turns )
    end
    local flip = torch.bernoulli()
    if flip == 1 then
      slice = image.hflip( slice )
      gt = image.hflip( gt )
    end
    -- adjust criterion weights online (inverse class frequency in this window)
    local npos = gt:eq(2):sum()
    local nneg = gt:eq(1):sum()
    local w_pos, w_neg
    if npos == 0 or nneg == 0 then
      -- a single class is present: inverse-frequency weights would give it
      -- weight 0 (no gradient, or a 0/0 loss depending on normalisation)
      w_pos, w_neg = 0.5, 0.5
    else
      w_pos = nneg / (npos + nneg)
      w_neg = npos / (npos + nneg)
    end
    optimizer.Loss = nn.ClassNLLCriterion( torch.Tensor{w_neg, w_pos} ):cuda()
    collectgarbage()
    slice = slice:cuda()
    gt = gt:reshape( gt:numel() ):cuda()
    local curr_pred, curr_loss = optimizer:optimize(slice, gt)
    acc_loss = acc_loss + curr_loss
    confusion_acc:batchAdd( curr_pred, gt )
    confusion_batch:batchAdd( curr_pred, gt )
    local err, err_pos, err_neg = error_rates_from_confmat( confusion_batch.mat )
    -- feedback
    local t_batch = sys.toc()
    t_epoch = t_epoch + t_batch
    printc( COL_DARKCYAN, 'Window %3d/%3d: [{ {%3d,%3d}, {%3d,%3d}, {%4d,%4d} }], #pos: %5d, avg rate: %.03f, pos rate: %.03f, neg rate: %.03f, loss: %.03f, time: %.02f sec', i, num_windows, k, k+input_size[1]-1, y, y+input_size[2]-1, x, x+input_size[3]-1, npos, 1-err, 1-err_pos, 1-err_neg, curr_loss, t_batch )
  end
  print_confmat( confusion_acc.mat, 'In-training confmat', COL_DARKBLUE )
  -- shift the loss history by one epoch and print the archived entries
  for i=num_archived,2,-1 do
    training_loss[i] = training_loss[i-1]
    if training_loss[i] >= 0 then
      printc( COL_DARKGREEN, 'Epoch %d, loss: %.04f', epoch-i+1, training_loss[i] )
    end
  end
  training_loss[1] = acc_loss
  printc( COL_LIGHTGREEN, 'Epoch %d, loss: %.02f', epoch, training_loss[1])
  printc( COL_TIMING, 'Done in %.02f sec', t_epoch )
  -- (no manual `epoch` increment: the numeric for loop owns its control variable)
  collectgarbage()
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Size of one model input window, ordered [z, y, x].
-- A fresh table is returned on every call.
function get_input_size()
  local depth, side = 1, 572
  return { depth, side, side }
end
--- Initial optim state (SGD hyper-parameters); a new table on every call.
function get_initial_optimizer_state()
  local state = {}
  state.learningRate = 5e-5 -- previously tried: 1e-4
  state.momentum = .99
  state.weightDecay = 5e-6
  state.learningRateDecay = 1e-4
  return state
end
--- Number of filters in the first level; deeper levels use multiples of it.
function get_filter_size()
  local base_filters = 64
  return base_filters
end
--- Build the U-Net-style segmentation network (three pooling levels).
-- Convolutions are 3x3 without padding. Each level runs
--   down block -> DepthConcat(deeper branch, identity skip) -> up block,
-- where cudnn.SpatialConvolutionUpsample doubles the resolution. A 1x1
-- convolution yields 2 class maps, which the classifier turns into
-- (pixels x 2) log-probabilities for a single (C, H, W) input.
-- Modules are created in a fixed order, so weight initialisation depends
-- only on the RNG seed.
function build_model()
  -- TODO: add dropout, crop rather than zero-pad
  local Seq, Identity, Concat = nn.Sequential, nn.Identity, nn.DepthConcat
  local Conv, Pool, ReLU = cudnn.SpatialConvolution, cudnn.SpatialMaxPooling, cudnn.ReLU
  local Up = cudnn.SpatialConvolutionUpsample
  -- starting filter size: must follow f[x] == 2 * f[x-1]
  local fs = get_filter_size()

  -- append (3x3 conv -> ReLU) twice to `net`; returns `net`
  local function conv_relu_x2(net, n_in, n_out)
    net:add( Conv(n_in, n_out, 3, 3) )
    net:add( ReLU() )
    net:add( Conv(n_out, n_out, 3, 3) )
    net:add( ReLU() )
    return net
  end

  -- contracting path
  local d1 = conv_relu_x2( Seq(), 1, fs )
  local d2 = conv_relu_x2( Seq():add( Pool(2, 2) ), fs, 2*fs )
  local d3 = conv_relu_x2( Seq():add( Pool(2, 2) ), 2*fs, 4*fs )

  -- bottom branch, upsampled back to the d3 resolution
  local b = conv_relu_x2( Seq():add( Pool(2, 2) ), 4*fs, 8*fs )
  b:add( Up(8*fs, 4*fs, 1, 1, 2) )

  -- expanding path
  local u3 = conv_relu_x2( Seq(), 8*fs, 4*fs )
  u3:add( Up(4*fs, 2*fs, 1, 1, 2) )
  local u2 = conv_relu_x2( Seq(), 4*fs, 2*fs )
  u2:add( Up(2*fs, fs, 1, 1, 2) )
  local u1 = conv_relu_x2( Seq(), 2*fs, fs )

  -- 1x1 convolutional layer: per-pixel scores for the 2 classes
  local o = Conv(fs, 2, 1, 1)

  -- one level: down -> concat(deeper, skip) along dim 1 -> up
  local function level(down, deeper, up)
    local join = Concat(1)
    join:add( deeper )
    join:add( Identity() )
    return Seq():add( down ):add( join ):add( up )
  end
  local sub3 = level( d3, b, u3 )
  local sub2 = level( d2, sub3, u2 )
  local sub1 = level( d1, sub2, u1 ):add( o )

  -- (2, H, W) -> (2, H*W) -> (H*W, 2) -> log-probabilities
  local classifier = Seq()
  classifier:add( nn.View( 2,-1 ) )
  classifier:add( nn.Transpose( {2,1} ) )
  classifier:add( nn.LogSoftMax() )

  return Seq():add( sub1 ):add( classifier )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local Optimizer = torch.class('Optimizer')

--- Wrap a model, a criterion and an optim-style update function.
-- Takes a single table of named arguments, validated by dok.unpack:
--   ModeCPU (required), Model (required), Loss (required),
--   L1Coeff (default 0), Parameters (optional {weights, gradients}),
--   OptFunction (required, e.g. optim.sgd), OptState (default {}),
--   HookFunction (optional, called as fun(y, yt, err)).
-- When Parameters is omitted, the flattened parameters come from
-- Model:getParameters(). In both cases self.Parameters holds
-- {self.Weights, self.Gradients}; previously it was left as an empty table
-- in the omitted case.
function Optimizer:__init(...)
  --xlua.require('torch',true)
  --xlua.require('cunn',true)
  local args = dok.unpack(
    {...},
    'Optimizer','Initialize an optimizer',
    {arg='ModeCPU', type ='boolean', help='CPU/GPU flag',req=true},
    {arg='Model', type ='table', help='Optimized model',req=true},
    {arg='Loss', type ='function', help='Loss function',req=true},
    {arg='L1Coeff', type ='number', help='L1 Regularization coeff',default=0},
    {arg='Parameters', type = 'table', help='Model parameters - weights and gradients',req=false},
    {arg='OptFunction', type = 'function', help = 'Optimization function' ,req = true},
    {arg='OptState', type = 'table', help='Optimization configuration', default = {}, req=false},
    {arg='HookFunction', type = 'function', help='Hook function of type fun(y,yt,err)', req = false}
  )
  -- dependencies
  if args.ModeCPU then
    require 'nn'
  else
    require 'cunn'
  end
  self.Model = args.Model
  self.Loss = args.Loss
  self.Parameters = args.Parameters
  self.OptFunction = args.OptFunction
  self.OptState = args.OptState
  self.HookFunction = args.HookFunction
  self.L1Coeff = args.L1Coeff
  if self.Parameters == nil then
    self.Weights, self.Gradients = self.Model:getParameters()
    self.Parameters = {self.Weights, self.Gradients}
  else
    self.Weights, self.Gradients = self.Parameters[1], self.Parameters[2]
  end
end

--- Run one OptFunction step on the pair (x, yt).
-- The closure passed to OptFunction performs these steps in order:
--   1. zero the gradients;
--   2. run forward through the model and the loss;
--   3. backpropagate;
--   4. call the optional hook;
--   5. add L1Coeff * sign(weights) to the gradients when L1Coeff > 0.
-- It returns (loss, gradients).
-- @param x model input
-- @param yt target
-- @return model output, loss, hook value (nil without a hook), and the
--   first value returned by OptFunction
function Optimizer:optimize(x,yt)
  local y, err, value
  local f_eval = function()
    self.Gradients:zero()
    y = self.Model:forward(x)
    err = self.Loss:forward(y,yt)
    local dE_dy = self.Loss:backward(y,yt)
    self.Model:backward(x, dE_dy)
    if self.HookFunction then
      value = self.HookFunction(y,yt,err)
    end
    if self.L1Coeff>0 then
      self.Gradients:add(torch.sign(self.Weights):mul(self.L1Coeff))
    end
    return err, self.Gradients
  end
  local opt_value = self.OptFunction(f_eval, self.Weights, self.OptState)
  return y, err,value, opt_value
end
--function Optimizer:optimStates(opts)--opts must be of for {{weight = optimState, bias = optimState} .... } | |
-- for i, optimState in ipairs(opts) do | |
--local weight_size = self.Weights:size(1) | |
--local learningRates = torch.Tensor(weight_size):fill(self.OptState.learningRate) | |
--local weightDecays = torch.Tensor(weight_size):fill(self.OptState.weightDecay) | |
--local counter = 0 | |
--for i, layer in ipairs(model.modules) do | |
-- local weight_size = layer.weight:size(1)*layer.weight:size(2) | |
-- learningRates[{{counter+1, counter+weight_size}}]:fill(1) | |
-- weightDecays[{{counter+1, counter+weight_size}}]:fill(wds) | |
-- counter = counter+weight_size | |
-- local bias_size = layer.bias:size(1) | |
-- learningRates[{{counter+1, counter+bias_size}}]:fill(2) | |
-- weightDecays[{{counter+1, counter+bias_size}}]:fill(0) | |
-- counter = counter+bias_size | |
-- end | |
--end | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment