
@farrajota
Last active April 15, 2017 00:10
Convert cudnn batch normalization modules to the nn backend (cycles through all modules of a network).
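require 'nn'
require 'cudnn'

-- Replace every cudnn.BatchNormalization module in `net` with its nn
-- counterpart and return the converted network.
-- nn.Module:replace applies the callback to `net` itself and then recurses
-- into all nested containers, so one call covers the whole network.
local function ConvertBNcudnn2nn(net)
  return net:replace(function(x)
    if torch.type(x) == 'cudnn.BatchNormalization' then
      return cudnn.convert(x, nn)  -- cudnn -> nn for this module only
    else
      return x                     -- leave every other module unchanged
    end
  end)
end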
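To make the traversal concrete, here is a minimal pure-Lua sketch of the recursion that, as we understand it, nn.Module:replace performs: the callback runs on the node itself, then on every child listed in `modules`, so a single call on the root reaches nested containers too. It runs without Torch because the "modules" are hypothetical plain tables with a `type` field (standing in for torch.type(x)) and an optional `modules` array; the returned table stands in for cudnn.convert(x, nn).

```lua
-- Sketch of a recursive replace over a module tree.
-- Nodes are hypothetical mocks: { type = 'name', modules = { ... } }.
local function replace(node, callback)
  local out = callback(node)           -- the node itself may be swapped out
  if node.modules then
    for i, child in ipairs(node.modules) do
      node.modules[i] = replace(child, callback)  -- recurse into children
    end
  end
  return out
end

-- Usage: a container holding a conv layer, a nested container, and a BN layer.
local net = { type = 'nn.Sequential', modules = {
  { type = 'cudnn.SpatialConvolution' },
  { type = 'nn.Sequential', modules = { { type = 'cudnn.BatchNormalization' } } },
  { type = 'cudnn.BatchNormalization' },
} }

local converted = 0
local result = replace(net, function(node)
  if node.type == 'cudnn.BatchNormalization' then
    converted = converted + 1
    return { type = 'nn.BatchNormalization' }  -- stand-in for cudnn.convert(x, nn)
  end
  return node                                  -- everything else is kept as-is
end)
```

Both batch normalization layers are replaced, including the one inside the nested container, while the root container is returned unchanged.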
@arunmallya

To be more comprehensive, it would be nice to include spatial batch normalization as well; the current version does not replace cudnn.SpatialBatchNormalization modules. Use this condition instead:

if torch.type(x) == 'cudnn.BatchNormalization' or torch.type(x) == 'cudnn.SpatialBatchNormalization' then
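Chaining `or` comparisons grows awkward if more module types need converting; a lookup table keeps the check to a single call. The sketch below is plain Lua so it runs without Torch; `typeMatcher` and `isCudnnBN` are hypothetical helper names, not part of nn or cudnn, and in real use the argument would be torch.type(x).

```lua
-- Hypothetical helper: build a predicate from a list of module type names.
local function typeMatcher(names)
  local set = {}
  for _, name in ipairs(names) do set[name] = true end
  return function(typeName) return set[typeName] == true end
end

-- The two cudnn batch normalization types discussed above.
local isCudnnBN = typeMatcher({
  'cudnn.BatchNormalization',
  'cudnn.SpatialBatchNormalization',
})

-- Inside the gist's callback this would read (Torch code, not run here):
--   if isCudnnBN(torch.type(x)) then return cudnn.convert(x, nn) end
```

Further type names can be appended to the list without touching the callback.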
