Skip to content

Instantly share code, notes, and snippets.

@montyhall
Last active February 19, 2017 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save montyhall/ef3890e49a2b8bd5b621107ce795123d to your computer and use it in GitHub Desktop.
Save montyhall/ef3890e49a2b8bd5b621107ce795123d to your computer and use it in GitHub Desktop.
require 'nn'
require 'cunn'
require 'image'
require 'nnx'
require 'nngraph'
require 'Upsample'
require 'util.misc'
model_utils = require 'util.model_utils'
-- Load the test image (keeping only its first three channels) and the two
-- saved models: the pre-LSTM network (fcn_model) and the ConvLSTM table (protos).
print('loading image')
input = image.load('/home/xxx/git/RIS/plants_learning/tests/210_615_16.jpg'):sub(1, 3)

print('loading plants_pre_lstm')
fcn_model = torch.load('/home/xxx/git/RIS/plants_learning/models/plants_pre_lstm.model')

print('loading plants_convlstm')
protos = torch.load('/home/xxx/git/RIS/plants_learning/models/plants_convlstm.model')

-- Only the FCN is moved to the GPU here; protos is used exactly as loaded.
-- NOTE(review): protos is indexed like a table of modules (protos.rnn, ...), so a
-- single protos:cuda() call would likely not apply; presumably its modules were
-- saved as CUDA modules already, since they are fed CUDA tensors - confirm.
print('making cuda')
fcn_model = fcn_model:cuda()
-- ConvLSTM configuration.
rnn_layers = 2          -- number of stacked ConvLSTM layers
nChannels = 30          -- channels in each ConvLSTM state tensor
rnn_size = nChannels
-- Spatial size of every ConvLSTM state tensor.
-- NOTE(review): 106 = 530/5 and 100 = 500/5, i.e. presumably the FCN downsamples
-- by 5; confirm these match the FCN output after the (1,3,2) permute.
xSize, ySize = 106, 100
-- Size of the zero-padded canvas the input image is copied into.
height, width = 530, 500
-- Show the loaded image (spatial axes swapped, matching the permute applied below).
itorch.image(input:sub(1,3):permute(1,3,2))

-- Copy the image into a zero-filled 3 x height x width canvas. Smaller images are
-- zero-padded at the bottom/right; larger ones are rejected with an explicit
-- message rather than a cryptic tensor-size mismatch.
assert(input:size(2) <= height and input:size(3) <= width,
  string.format('input image is %dx%d, larger than the %dx%d canvas',
    input:size(2), input:size(3), height, width))
actual_input = torch.Tensor(3, height, width):fill(0)
actual_input[{ {1,3}, {1,input:size(2)}, {1,input:size(3)} }] = input

-- Reverse the channel order (channel 3 first, channel 1 last), swap the two
-- spatial axes and move the result to the GPU.
-- NOTE(review): presumably this mirrors the preprocessing used at training time - confirm.
actual_input_aux = actual_input:index(1, torch.LongTensor{3, 2, 1})
actual_input_aux = actual_input_aux:permute(1,3,2)
actual_input_aux = actual_input_aux:cuda()

-- Run the pre-LSTM network in inference mode and display its output.
fcn_model:evaluate()
x = fcn_model:forward(actual_input_aux)
itorch.image(x:permute(1,3,2))
-- Number of recurrent iterations run for each image; change as needed.
seq_length = 20

-- Zero initial recurrent state: two CUDA tensors (cell/hidden) per ConvLSTM layer,
-- each of size rnn_size x xSize x ySize. Below, only its length (#init_state)
-- is used, to know how many leading ConvLSTM outputs form the recurrent state.
init_state = {}
for L = 1, rnn_layers do
  h_init = torch.zeros(rnn_size, xSize, ySize):cuda()
  init_state[#init_state + 1] = h_init:clone()
  init_state[#init_state + 1] = h_init:clone()
end
Plot = require 'itorch.Plot'

-- Zero initial state that seeds the recurrence as rnn_state[0]: for each of the
-- rnn_layers layers, two independent CUDA tensors of size rnn_size x xSize x ySize
-- (the cell/hidden states).
init_state_global = {}
for layer = 1, rnn_layers do
  h_init = torch.zeros(rnn_size, xSize, ySize):cuda()
  for _ = 1, 2 do
    table.insert(init_state_global, h_init:clone())
  end
end
-- Run the ConvLSTM for seq_length steps on the FCN features x, feeding each step's
-- recurrent state back in, and display each step's prediction.
--
-- nn modules reuse their output buffers on every forward() call, so each tensor
-- stored in rnn_state / predictions_small / predictions is cloned. Otherwise every
-- stored timestep would alias the same buffers (holding only the last step's
-- values), and each step's input state would alias the module's own outputs.

-- unpack is a global in Lua 5.1/LuaJIT but lives in table.unpack from Lua 5.2 on.
local unpack_list = table.unpack or unpack

-- Deep-copy a module output: tensors are cloned, tables are copied recursively,
-- anything else is returned as is.
local function clone_output(v)
  if torch.isTensor(v) then
    return v:clone()
  end
  if type(v) == 'table' then
    local copy = {}
    for k, e in pairs(v) do copy[k] = clone_output(e) end
    return copy
  end
  return v
end

-- Inference: put the recurrent modules in evaluation mode (affects modules such as
-- Dropout / BatchNormalization), mirroring fcn_model:evaluate() for the FCN.
protos.rnn:evaluate()
protos.post_lstm:evaluate()
protos.shall_we_stop:evaluate()

rnn_state = {[0] = init_state_global}
predictions_small = {}
predictions = {}
prediction = {}   -- NOTE(review): not used in this loop
loss = 0          -- NOTE(review): not used in this loop
gt_indices = {}   -- NOTE(review): not used in this loop
-- Recomputes the FCN features for the current input.
x = fcn_model:forward(actual_input_aux)
for t=1,seq_length do
  lst = protos.rnn:forward{x, unpack_list(rnn_state[t-1])}
  -- The first #init_state outputs are the recurrent state; clone them so the next
  -- forward() cannot overwrite them.
  rnn_state[t] = {}
  for i=1,#init_state do table.insert(rnn_state[t], lst[i]:clone()) end
  predictions_small[t] = lst[#lst]:clone() -- last element is the prediction
  -- Channel-summed squared values of the second-to-last output (not used in this loop).
  state = lst[#lst-1]:clone():pow(2):sum(1):squeeze()
  predictions[t] = clone_output(
    protos.shall_we_stop:forward(protos.post_lstm:forward(predictions_small[t])))
  result = predictions[t][1]
  itorch.image(result:permute(1,3,2))
  print(predictions[t][2])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment