Last active
August 29, 2015 14:22
-
-
Save bunelr/11ddcf034b51bb9f65ab to your computer and use it in GitHub Desktop.
Forward mode
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Without the monkeypatching (standard nn)
rudy@rudy-Tower:~workspace/utc-caltech[master]
$ time luajit luatools/prediction.lua
data/res/divers-bignn/set09/V010.txt
real 1m1.162s
user 0m51.132s
sys 0m13.084s
-- With the monkeypatching
rudy@rudy-Tower:~workspace/utc-caltech[master]
$ time luajit luatools/prediction.lua
data/res/divers-bignn/set09/V010.txt
real 0m43.975s
user 0m39.715s
sys 0m7.190s
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- the network that I was using
--- Build the detector network used for the timings above.
-- The first layer takes 3 input planes and the last layer emits 2 output
-- planes; everything in between is described by `recipe`, one entry per
-- module, in the order the modules are added to the container.
-- NOTE(review): `:cuda()` presumably needs cunn loaded by the caller.
-- @return an nn.Sequential moved to CUDA
local get_detector = function()
  local unpack_args = unpack or table.unpack
  local wide, narrow = 25 * 6 * 4, 12 * 6 * 4

  -- Each entry: { nn constructor name, constructor arguments... }
  local recipe = {
    { "SpatialConvolution", 3, 25, 5, 5 },
    { "PReLU" },
    { "SpatialDropout" },
    { "SpatialConvolution", 25, 50, 5, 5 },
    { "PReLU" },
    { "SpatialDropout" },
    { "SpatialMaxPooling", 2, 2, 2, 2 },
    { "SpatialConvolution", 50, 75, 5, 5 },
    { "PReLU" },
    { "SpatialDropout" },
    { "SpatialConvolution", 75, 100, 1, 5 },
    { "PReLU" },
    { "SpatialDropout" },
    { "SpatialMaxPooling", 2, 2, 2, 2 },
    { "SpatialConvolution", 100, wide, 4, 6, 1, 1 },
    { "PReLU" },
    { "Dropout" },
    { "SpatialConvolution", wide, narrow, 1, 1 },
    { "PReLU" },
    { "Dropout" },
    { "SpatialConvolution", narrow, 2, 1, 1 },
  }

  local net = nn.Sequential()
  for _, spec in ipairs(recipe) do
    -- unpack from index 2 so argument-less specs call the constructor with no args
    net:add(nn[spec[1]](unpack_args(spec, 2)))
  end
  net:cuda()
  return net
end
-- function to replace updateOutput for Sequential
--- Memory-frugal drop-in for nn.Sequential's updateOutput (inference use).
-- Runs each child in order. After a child has produced its result, the
-- child's `finput` buffer is dropped and its `output` field is detached.
-- A GC cycle is then forced, so the previous layer's tensor can be reclaimed
-- (if nothing else references it) before the next layer allocates.
-- @param self  the container whose `modules` list is executed
-- @param input input handed to the first child
-- @return the last child's output (also stored in `self.output`)
local updateOutputCleaner = function(self, input)
  local currentOutput = input
  for i = 1, #self.modules do
    local module = self.modules[i]
    currentOutput = module:updateOutput(currentOutput)
    -- `currentOutput` keeps its own reference to the produced tensor, so
    -- detaching the module's fields only releases the module's handles.
    module.finput = nil
    if torch.isTensor(module.output) then
      -- fresh empty tensor of the same type, instead of assuming CUDA
      module.output = module.output.new()
    end
    collectgarbage()
  end
  -- keep the nn.Module contract: the container's output is the last result
  self.output = currentOutput
  return currentOutput
end
-- where i monkeypatch
--- Prepare a trained network for inference.
-- First switches it to evaluation mode (affects modules such as Dropout and
-- SpatialDropout), then clears accumulated gradients. Finally it installs
-- `updateOutputCleaner` as this object's own `updateOutput` field. An
-- instance field takes precedence over the method found via the metatable,
-- so only this container is patched and nested containers are left as-is.
-- @param cnn the network to patch (expected: an nn.Sequential)
-- @return the same network object
local switch_to_test_mode = function(cnn)
  local net = cnn
  net:evaluate()
  net:zeroGradParameters()
  net.updateOutput = updateOutputCleaner
  return net
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment