@etrulls
Created September 23, 2015 09:05
mwe depthconcat (minimal working example: nn.DepthConcat with cudnn convolution branches)
require 'cutorch'
require 'cunn'
require 'cudnn'
torch.setnumthreads(1)                          -- single CPU thread
cutorch.setDevice(1)                            -- first GPU (cutorch devices are 1-based)
torch.setdefaulttensortype('torch.FloatTensor') -- default CPU tensor type
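-- model: two parallel cudnn convolution branches on a 1-channel input
-- branch 1: 3x3 kernel, 1 -> 10 feature maps (1x25x25 input -> 10x23x23)
c1 = nn.Sequential()
c1:add( cudnn.SpatialConvolution(1,10,3,3) )
-- branch 2: 5x5 kernel, 1 -> 10 feature maps (1x25x25 input -> 10x21x21)
c2 = nn.Sequential()
c2:add( cudnn.SpatialConvolution(1,10,5,5) )
-- concatenate along dimension 1 (the feature maps of a non-batched CxHxW input);
-- DepthConcat pads the smaller 21x21 maps to 23x23, giving 20x23x23
join = nn.DepthConcat(1)
join:add( c1 )
join:add( c2 )
model = nn.Sequential()
model:add( join )
-- 1x1 convolution: 20 -> 2 class scores per pixel (2x23x23)
model:add( nn.SpatialConvolution(20,2,1,1) )
-- flatten to 2 x 529, then transpose to 529 x 2 (one row per pixel)
model:add( nn.View( 2,-1) )
model:add( nn.Transpose( {2,1} ) )
-- per-pixel log-probabilities over the 2 classes
model:add( nn.LogSoftMax() )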
model:cuda()
-- flat views of all parameters and their gradients (the loop never updates weights)
weights, gradients = model:getParameters()
loss_criterion = nn.ClassNLLCriterion():cuda()
-- loop: repeated forward/backward passes on random input and random labels
for epoch=1,1e6 do
  gradients:zero()
  data = torch.rand(1,25,25):cuda()
  pred = model:forward( data )
  -- one random label in {1,2} per pixel (pred:size(1) rows)
  gt_lin = torch.Tensor( pred:size(1) ):random(2):cuda()
  err = loss_criterion:forward( pred, gt_lin )
  grad_crit = loss_criterion:backward( pred, gt_lin )
  model:backward( data, grad_crit )
  print( string.format('Epoch %d, loss %.3f', epoch, err) )
  collectgarbage()
end
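The tensor shapes in the script can be checked by hand. The following pure-Lua sketch (no Torch required) redoes that arithmetic for a 1x25x25 input. It assumes, as nn's DepthConcat source appears to do, that the concatenated output is zero-filled and each smaller branch output is centred inside it. `conv_out` and `center_start` are hypothetical helpers written for this illustration, not nn functions.

```lua
-- Shape walk-through for the DepthConcat model above, for a 1x25x25 input.
-- Pure Lua; conv_out and center_start are hypothetical helpers.

-- Output size of a stride-1, unpadded convolution (as in both cudnn branches).
local function conv_out(size, kernel)
  return size - kernel + 1
end

-- 1-based row/column at which a smaller map is assumed to start inside the
-- larger DepthConcat output (centred, zero border around it).
local function center_start(out_size, cur_size)
  return math.floor((out_size - cur_size) / 2) + 1
end

local input_hw = 25
local hw1 = conv_out(input_hw, 3)          -- 3x3 branch: 10 x 23 x 23
local hw2 = conv_out(input_hw, 5)          -- 5x5 branch: 10 x 21 x 21
local concat_hw = math.max(hw1, hw2)       -- DepthConcat pads to the larger map
local concat_channels = 10 + 10            -- concatenated along dimension 1
local offset2 = center_start(concat_hw, hw2)
local n_pixels = concat_hw * concat_hw     -- View(2,-1) flattens each 23x23 map

print(string.format('concat output: %d x %d x %d', concat_channels, concat_hw, concat_hw))
print(string.format('5x5 branch occupies rows/cols %d..%d', offset2, offset2 + hw2 - 1))
print(string.format('LogSoftMax output: %d x 2, target length: %d', n_pixels, n_pixels))
```

Running it should report a 20 x 23 x 23 concatenated map, the 5x5 branch sitting at rows/columns 2..22 with a one-pixel zero border, and 529 rows reaching ClassNLLCriterion, which matches the `pred:size(1)` used to size `gt_lin` in the loop.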