Skip to content

Instantly share code, notes, and snippets.

@culurciello
Created September 2, 2016 19:32
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save culurciello/1955aa1d6ea8381a9da908f9f3d13228 to your computer and use it in GitHub Desktop.
Save culurciello/1955aa1d6ea8381a9da908f9f3d13228 to your computer and use it in GitHub Desktop.
-- Eugenio Culurciello
-- August 2016
-- a test to learn to code PredNet-like nets in Torch7
require 'nn'
require 'nngraph'
-- Every tensor created from here on defaults to 32-bit float; nngraph debug
-- mode is switched on for the graph built below.
torch.setdefaulttensortype('torch.FloatTensor')
nngraph.setDebug(true)

local nlayers = 4 -- number of stacked layers in the test network

-- Graph input nodes. Slot 1 is the input image x; slots 2 .. nlayers + 1
-- receive each layer's previous output D, which the test loop feeds back in.
local inputs = { nn.Identity()() }
for _ = 1, nlayers do
  inputs[#inputs + 1] = nn.Identity()()
end

-- Graph output nodes, appended as a (D, Df) pair per layer.
local outputs = {}
for layer = 1, nlayers do
  print('Creating layer-test:', layer)

  -- What this layer receives from below:
  --   * the image x for the first layer;
  --   * otherwise the previous layer's D node, stored at 2*layer - 3 in
  --     `outputs` (pairs are appended as D, Df).
  local below
  if layer == 1 then
    below = inputs[1]
  else
    below = outputs[2 * layer - 3]
  end

  -- D: this layer's output, `below` scaled by 2.
  local D = {below} - nn.MulConstant(2)
  D:annotate{graphAttributes = {color = 'green', fontcolor = 'green'}}

  -- G: this layer's fed-back previous output (graph input layer + 1),
  -- scaled by 0.5.
  local G = {inputs[layer + 1]} - nn.MulConstant(0.5)

  -- Df: difference between `below` and G, computed by nn.CSubTable.
  local Df = {below, G} - nn.CSubTable(1)
  Df:annotate{graphAttributes = {color = 'blue', fontcolor = 'blue'}}

  outputs[#outputs + 1] = D
  outputs[#outputs + 1] = Df
end
-- create graph
-- Wrap the input and output nodes built above into a single nngraph module.
print('Creating model-test:')
-- NOTE(review): nngraph.annotateNodes() presumably labels graph nodes from the
-- names of the calling scope's local variables, so keep this call at this
-- scope rather than moving it into a helper. Confirm against the nngraph docs.
nngraph.annotateNodes()
-- `inputs` holds nlayers + 1 nodes; `outputs` holds 2 * nlayers nodes
-- (one D, Df pair per layer).
local model = nn.gModule(inputs, outputs)
-- test:
-- Unroll the model for nT time steps on a constant 2x2 input. At each step the
-- graph receives the image plus every layer's D output from the previous step
-- (outTable[1], outTable[3], ...). Inputs and outputs are each printed as one
-- wide matrix, concatenated along dimension 2.
print('Testing model-test:')
local c = require 'trepl.colorize'
local nT = 3 -- time sequence length

-- Step-0 state: every output slot starts as a 2x2 zero tensor, so the first
-- iteration feeds zeros in as each layer's "previous D".
local outTable = {}
for k = 1, nlayers * 2 do
  outTable[k] = torch.zeros(2, 2)
end

local x = {} -- input sequence: one 2x2 frame per time step (size nT)
for t = 1, nT do x[t] = torch.ones(2, 2) end

for i = 1, nT do
  -- Graph input {x_i, D_1, ..., D_nlayers} (size nlayers + 1). Declared
  -- `local` so it does not leak into the global environment.
  local inTable = {x[i]}
  for j = 1, nlayers do
    -- Clone: outTable is what model:forward returned (the modules' own
    -- output buffers). The next forward call may overwrite those buffers in
    -- place while they are still being read as inputs.
    inTable[j + 1] = outTable[2*j - 1]:clone()
  end
  print(c.red('Input of iteration '.. i ..' is:'))
  -- torch.cat accepts a table of tensors; one call avoids re-copying a
  -- growing result for every element.
  print(torch.cat(inTable, 2))
  outTable = model:forward(inTable) -- size 2*nlayers
  print(c.cyan('Output of iteration '.. i ..' is: '))
  print(torch.cat(outTable, 2))
end
-- NOTE(review): `graph` is assumed to be the global installed by loading
-- nngraph; rendering is believed to require Graphviz. Confirm both.
graph.dot(model.fg, 'test','Model-test') -- graph the model!
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment