Last active: February 19, 2017 21:03
Save montyhall/ef3890e49a2b8bd5b621107ce795123d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Dependencies for the RIS (Recurrent Instance Segmentation) plants demo:
-- Torch nn stack with CUDA support, plus project-local modules
-- ('Upsample', 'util.misc', 'util.model_utils').
-- NOTE: the original paste had a trailing " | |" scrape artifact on every
-- line, which is a Lua syntax error; it has been removed.
require 'nn'
require 'cunn'
require 'image'
require 'nnx'
require 'nngraph'
require 'Upsample'      -- project-local layer
require 'util.misc'
model_utils = require 'util.model_utils'
-- Load the test image and the two pre-trained model halves (convolutional
-- front end + ConvLSTM protos), move the front end to the GPU, preprocess
-- the image, and run the front end once to obtain the feature map `x`.
print('loading image')
-- :sub(1,3) keeps only the first three planes (RGB), dropping any alpha.
input = image.load('/home/xxx/git/RIS/plants_learning/tests/210_615_16.jpg'):sub(1,3)
print('loading plants_pre_lstm')
fcn_model = torch.load('/home/xxx/git/RIS/plants_learning/models/plants_pre_lstm.model')
print('loading plants_convlstm')
protos = torch.load('/home/xxx/git/RIS/plants_learning/models/plants_convlstm.model')
print('making cuda')
fcn_model = fcn_model:cuda()
--protos=protos:cuda()
rnn_layers = 2   -- Number of layers of the ConvLSTM
nChannels  = 30  -- Number of channels in the state of the ConvLSTM
rnn_size   = nChannels
xSize  = 106     -- ConvLSTM state spatial dims (after the front end's downsampling)
ySize  = 100
height = 530     -- input buffer size in pixels; height/width = 5 * xSize/ySize
width  = 500
itorch.image(input:sub(1,3):permute(1,3,2))
-- Copy the image into a fixed-size zero-filled buffer (pads if smaller).
actual_input = torch.Tensor(3, height, width):fill(0)
actual_input[{ {1,3}, {1,height}, {1,width} }] = input
-- Reverse the channel order (plane 1 <-> plane 3).
-- NOTE(review): presumably the model was trained on BGR input — confirm.
actual_input_aux = torch.Tensor(#actual_input)
actual_input_aux[{{1},{},{}}] = actual_input[{{3},{},{}}]
actual_input_aux[{{2},{},{}}] = actual_input[{{2},{},{}}]
actual_input_aux[{{3},{},{}}] = actual_input[{{1},{},{}}]
-- Swap the two spatial axes, then move to GPU.
actual_input_aux = actual_input_aux:permute(1,3,2)
actual_input_aux = actual_input_aux:cuda()
fcn_model:evaluate()  -- inference mode (e.g. disables dropout)
x = fcn_model:forward(actual_input_aux)
itorch.image(x:permute(1,3,2))
-- Build the zero-filled initial cell/hidden states for the ConvLSTM.
-- Two tables are built: `init_state` (used below only for its length, i.e.
-- the number of state tensors per step) and `init_state_global` (the actual
-- state fed into timestep 0).
--a bunch of clones after flattening, as that reallocates memory
seq_length = 20 -- This is the total number of iterations we run the model for each image.
                -- You can change it at your convenience.
-- the initial state of the cell/hidden states
init_state = {}
for L = 1, rnn_layers do
  -- one zero tensor per layer, cloned twice: cell state + hidden state
  local h_init = torch.zeros(rnn_size, xSize, ySize):cuda()
  table.insert(init_state, h_init:clone())
  table.insert(init_state, h_init:clone())
end
Plot = require 'itorch.Plot'
-- the initial state of the cell/hidden states
init_state_global = {}
for L = 1, rnn_layers do
  local h_init = torch.zeros(rnn_size, xSize, ySize):cuda()
  table.insert(init_state_global, h_init:clone())
  table.insert(init_state_global, h_init:clone())
end
-- Recurrent inference: run the ConvLSTM `seq_length` times over the same
-- feature map, producing one instance mask (and a stop score) per step.
-- `rnn_state` is indexed by timestep; [0] holds the zero initial state.
rnn_state = {[0] = init_state_global}
predictions_small = {}
predictions = {}
prediction = {}
loss = 0
gt_indices = {}
-- Recompute the feature map. NOTE(review): redundant with the forward call
-- in the setup section above, but kept for cell-by-cell (notebook) execution.
x = fcn_model:forward(actual_input_aux)
for t = 1, seq_length do
  -- One step: input features plus the previous step's cell/hidden states.
  lst = protos.rnn:forward{x, unpack(rnn_state[t-1])}
  rnn_state[t] = {}
  for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
  predictions_small[t] = lst[#lst] -- last element is the prediction
  -- Squared-magnitude summary of the last hidden state (visualization aid).
  state = lst[#lst-1]:clone():pow(2):sum(1):squeeze()
  -- Upsample/post-process the small prediction, then score whether to stop.
  predictions[t] = protos.shall_we_stop:forward(protos.post_lstm:forward(predictions_small[t]))
  result = predictions[t][1]  -- [1] = mask, [2] = stop score
  itorch.image(result:permute(1,3,2))
  print(predictions[t][2])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.