Skip to content

Instantly share code, notes, and snippets.

@farrajota
Last active May 1, 2017 14:48
Show Gist options
  • Save farrajota/358de9ab06ebd8542f23f2102fb9c45a to your computer and use it in GitHub Desktop.
Save farrajota/358de9ab06ebd8542f23f2102fb9c45a to your computer and use it in GitHub Desktop.
Small example of how to create a binary tree in Torch7 using NN containers and nngraph.
--[[
Create a binary tree in two ways: with NN containers and with an nn.gModule (nngraph).
This example is fairly simple: after a 10 -> 100 root layer, every fully-connected
(fc) layer in the tree has 100 inputs and 100 outputs. It should be easy to modify
it to use fc layers with varying input/output sizes if desired (for example, by
passing in a table that stores the input/output sizes for each level of the sub-branches).
]]
require 'nn'
require 'nngraph'
-- (1) Example of an N-level binary tree built from NN containers.

--- Recursively append a binary tree of fully-connected layers to a container.
-- Every call appends one "pair" to `network`. A pair is an nn.Sequential that wraps an
-- nn.ConcatTable with two branches, and each branch is an nn.Sequential holding one
-- nn.Linear(ninputs, noutputs). While more levels remain (n_sub_trees > 1), the
-- next level is appended inside each branch Sequential. The tree therefore has
-- n_sub_trees levels of pairs; a value of 1 or less still yields a single pair.
-- @param network container with an `add` method that receives the pair
-- @tparam number n_sub_trees number of levels to build
-- @tparam number ninputs input size of every Linear layer
-- @tparam number noutputs output size of every Linear layer
local function recursiveAddSubtrees(network, n_sub_trees, ninputs, noutputs)
  local has_more_levels = n_sub_trees > 1

  -- Both Linear layers are constructed before any container, left first.
  local left_fc = nn.Linear(ninputs, noutputs)
  local right_fc = nn.Linear(ninputs, noutputs)

  local left_branch = nn.Sequential():add(left_fc)
  local right_branch = nn.Sequential():add(right_fc)
  local pair = nn.Sequential()
  pair:add(nn.ConcatTable():add(left_branch):add(right_branch))
  network:add(pair)

  if has_more_levels then
    -- Grow the next level at the end of each branch, left subtree first.
    recursiveAddSubtrees(left_branch, n_sub_trees - 1, ninputs, noutputs)
    recursiveAddSubtrees(right_branch, n_sub_trees - 1, ninputs, noutputs)
  end
end
-- Build the container-based tree: a root Linear(10 -> 100) followed by
-- n_sub_trees levels of paired Linear(100 -> 100) branches.
local n_sub_trees = 3 -- number of levels of branching in the tree
local n_root_inputs = 10 -- input size of the root fully-connected layer
local fc_size = 100 -- input/output size of every fc layer inside the tree
local branches = nn.Sequential() -- container that receives the sub-trees
recursiveAddSubtrees(branches, n_sub_trees, fc_size, fc_size) -- recursively add pairs of fc layers
local bin_tree_model = nn.Sequential() -- main model container
bin_tree_model:add(nn.Linear(n_root_inputs, fc_size)) -- 'root' fully-connected layer
bin_tree_model:add(branches) -- attach the branches after the root
print(bin_tree_model) -- print the binary tree
-- (2) Example of an N-level binary tree built with nngraph.
-- Warning: displaying the resulting graph requires qlua.

--- Recursively grow a binary tree of nngraph nodes below `root_fc`.
-- At every level, two nn.Linear(ninputs, noutputs) modules are created. Each one is
-- called on `{root_fc}`, which produces two child nodes. When no levels remain
-- (n_sub_trees <= 1), the children are appended to `networkGraphTable`. That table
-- therefore collects the leaf nodes in left-to-right order: 2^n_sub_trees of them
-- when n_sub_trees >= 1, and 2 otherwise.
-- @tparam table networkGraphTable array that receives the leaf nodes
-- @param root_fc node that both children of this level are connected to
-- @tparam number n_sub_trees number of levels still to build
-- @tparam number ninputs input size of every Linear layer
-- @tparam number noutputs output size of every Linear layer
local function recursiveAddSubtreesGraph(networkGraphTable, root_fc, n_sub_trees, ninputs, noutputs)
  local is_last_level = not (n_sub_trees > 1)

  -- Both modules are created before any wiring or recursion, left first.
  local left_fc = nn.Linear(ninputs, noutputs)
  local right_fc = nn.Linear(ninputs, noutputs)

  for _, fc in ipairs({ left_fc, right_fc }) do
    local child = fc({ root_fc })
    if is_last_level then
      table.insert(networkGraphTable, child)
    else
      recursiveAddSubtreesGraph(networkGraphTable, child, n_sub_trees - 1, ninputs, noutputs)
    end
  end
end
-- Build the nngraph version: the root Linear(10 -> 100) node feeds
-- n_sub_trees levels of paired Linear(100 -> 100) nodes. The leaf nodes
-- become the outputs of the nn.gModule.
local networkGraphTable = {} -- this table will store all the outputs needed to define the nn.gModule
local root_fc = nn.Linear(10,100)() -- root fc layer; the trailing () registers the module as an nngraph.Node
recursiveAddSubtreesGraph(networkGraphTable, root_fc, n_sub_trees, 100, 100) -- recursively add pairs of fc layers
-- define the nn.gModule (nngraph)
local bin_tree_modelGraph = nn.gModule(
  {root_fc}, -- input of the model
  networkGraphTable -- outputs of the model (the leaf nodes)
)
-- Plotting needs the Qt bindings, so only attempt it when 'qt' can be loaded
-- (i.e. when this script is started with qlua).
local has_qt = pcall(require, 'qt')
if has_qt then
  graph.dot(bin_tree_modelGraph.fg, 'binary tree') -- display the forward node graph of the binary tree
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment