Skip to content

Instantly share code, notes, and snippets.

@soumith
Created June 22, 2015 21:35
Show Gist options
  • Save soumith/3f87517ad69e1620c70d to your computer and use it in GitHub Desktop.
Save soumith/3f87517ad69e1620c70d to your computer and use it in GitHub Desktop.
LSUN eyescream
--- Generator network (LSUN eyescream)
--- Input:  a table of tensors concatenated by nn.JoinTable(2, 2); the first
---         convolution expects 3+1 = 4 planes after the join.
---         NOTE(review): presumably a 3-plane image plus a 1-plane noise map
---         (batch x C x H x W) — confirm against the caller.
--- Output: reshaped by nn.View to opt.geometry (presumably {3, H, W} per
---         sample — confirm).
--- Plane counts through the stack: 4 -> 64 -> 368 -> 128 -> 64 -> 224 -> 3.
model_G = nn.Sequential()
model_G:add(nn.JoinTable(2, 2))
-- NOTE(review): the 5th/6th arguments of cudnn.SpatialConvolutionUpsample are
-- defined outside this file (stride / upsampling factor?) — confirm before editing.
model_G:add(cudnn.SpatialConvolutionUpsample(3+1, 64, 7, 7, 1, 1)):add(cudnn.ReLU(true))
-- 4th argument `false`: batch normalization without learnable affine parameters.
model_G:add(nn.SpatialBatchNormalization(64, nil, nil, false))
model_G:add(cudnn.SpatialConvolutionUpsample(64, 368, 7, 7, 1, 4)):add(cudnn.ReLU(true))
model_G:add(nn.SpatialBatchNormalization(368, nil, nil, false))
model_G:add(nn.SpatialDropout(0.5))
model_G:add(cudnn.SpatialConvolutionUpsample(368, 128, 7, 7, 1, 4)):add(cudnn.ReLU(true))
model_G:add(nn.SpatialBatchNormalization(128, nil, nil, false))
-- Pool across feature planes (width 2, stride 2, power 2, batch mode):
-- 128 planes -> 64, which is what the next convolution takes as input.
model_G:add(nn.FeatureLPPooling(2,2,2,true))
model_G:add(cudnn.SpatialConvolutionUpsample(64, 224, 5, 5, 1, 2)):add(cudnn.ReLU(true))
model_G:add(nn.SpatialBatchNormalization(224, nil, nil, false))
model_G:add(nn.SpatialDropout(0.5))
model_G:add(cudnn.SpatialConvolutionUpsample(224, 3, 7, 7, 1, 1))
model_G:add(nn.SpatialBatchNormalization(3, nil, nil, false))
model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3]))
-- desc_G is not defined in this snippet; fall back to printing the network
-- itself so the output is never just "nil".
print(desc_G or model_G)
--- Discriminator network
--- Input:  {h, l} — two (3 x j x j) images per sample: h_k / \tilde{h}_k and l_k,
---         summed elementwise by nn.CAddTable before the convolutions
---         (j = opt.fineSize, matching the dummy input below).
--- Output: batch x 1 scalar probability (nn.Sigmoid over a single Linear unit).
model_D = nn.Sequential()
model_D:add(nn.CAddTable())
model_D:add(cudnn.SpatialConvolution(3, 48, 3, 3))
model_D:add(cudnn.ReLU(true))
-- Last argument is the number of groups: 48 -> 448 planes in 4 groups.
model_D:add(cudnn.SpatialConvolution(48, 448, 5, 5, 1, 1, 0, 0, 4))
model_D:add(cudnn.ReLU(true))
-- 448 -> 416 planes in 16 groups.
model_D:add(cudnn.SpatialConvolution(448, 416, 7, 7, 1, 1, 0, 0, 16))
model_D:add(cudnn.ReLU())
model_D:cuda()

-- The flattened size of the convolutional features depends on opt.fineSize,
-- so measure it with one dummy forward pass instead of computing it by hand.
local dummy_input = torch.zeros(opt.batchSize, 3, opt.fineSize, opt.fineSize):cuda()
local out = model_D:forward({dummy_input, dummy_input})
local nElem = out:nElement() / opt.batchSize
model_D:add(nn.View(nElem):setNumInputDims(3))
model_D:add(nn.Linear(nElem, 1))
model_D:add(nn.Sigmoid())
-- Move the newly added head layers to the GPU as well.
model_D:cuda()
-- desc_D is not defined in this snippet; fall back to printing the network
-- itself so the output is never just "nil".
print(desc_D or model_D)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment