@paidi
Last active February 14, 2017 10:29
Batch SparseLinear
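require 'nn'
-- Example sizes (hypothetical values; set these for your data).
local inputSize, outputSize, batchSize = 10000, 100, 32
-- One SparseLinear per sample in the batch; every copy shares the same parameters.
local m = nn.ParallelTable()
local layer = nn.SparseLinear(inputSize, outputSize)
-- Reshape each sample's outputSize vector to 1 x outputSize so the rows can be stacked.
m:add(nn.Sequential():add(layer):add(nn.Reshape(1, outputSize)))
for i = 2, batchSize do
  -- clone() with field names shares those tensors instead of copying them.
  local repLayer = layer:clone('weight', 'bias', 'gradWeight', 'gradBias')
  m:add(nn.Sequential():add(repLayer):add(nn.Reshape(1, outputSize)))
end
-- Input: a table of exactly batchSize sparse samples; output: batchSize x outputSize.
local batchLayer = nn.Sequential():add(m):add(nn.JoinTable(1))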
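A short usage sketch of the module above, assuming the classic Torch7 `nn` package, where `nn.SparseLinear` takes each sample as an nnz x 2 tensor (column 1: 1-based feature index, column 2: value). The helper name `makeBatchSparseLinear` and the toy sizes are hypothetical and exist only to package the same construction and run one forward pass on a batch of two samples.

```lua
require 'nn'

-- Hypothetical helper: the same shared-clone construction, wrapped in a function.
local function makeBatchSparseLinear(inputSize, outputSize, batchSize)
  local m = nn.ParallelTable()
  local layer = nn.SparseLinear(inputSize, outputSize)
  m:add(nn.Sequential():add(layer):add(nn.Reshape(1, outputSize)))
  for i = 2, batchSize do
    -- Share weight/bias and their gradients with the first layer.
    local repLayer = layer:clone('weight', 'bias', 'gradWeight', 'gradBias')
    m:add(nn.Sequential():add(repLayer):add(nn.Reshape(1, outputSize)))
  end
  return nn.Sequential():add(m):add(nn.JoinTable(1))
end

-- Toy sizes (hypothetical): 10 input features, 3 outputs, batch of 2.
local batchLayer = makeBatchSparseLinear(10, 3, 2)

-- One sparse sample per batch slot: rows are {featureIndex, value}.
local input = {
  torch.Tensor{{1, 0.5}, {7, 2.0}},
  torch.Tensor{{3, 1.0}},
}

-- Each branch produces a 1 x 3 row; JoinTable(1) stacks them along dimension 1.
local output = batchLayer:forward(input)
print(output:size())
```

Because `nn.ParallelTable` feeds `input[i]` to its i-th branch, the input table needs one entry per branch; a final batch with fewer samples would need a module built with a smaller `batchSize`.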