Skip to content

Instantly share code, notes, and snippets.

@jtleek
Created April 20, 2017 21:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jtleek/a4491fe9f77ff72a7fe5972196d57a3b to your computer and use it in GitHub Desktop.
Save jtleek/a4491fe9f77ff72a7fe5972196d57a3b to your computer and use it in GitHub Desktop.
#' Build the adjacency matrix of a layered network
#'
#' Nodes are numbered consecutively, layer by layer: layer 1 holds nodes
#' 1..nodes_per_layer[1], layer 2 the next nodes_per_layer[2] nodes, and so
#' on. `connection_type[i]` decides how layer i is wired to layer i + 1, and
#' `adjacency[a, b] == 1` marks an edge from node a to node b.
#'
#' Despite its name, this function does not draw anything; it only builds
#' and returns the adjacency matrix.
#'
#' @param nodes_per_layer Non-negative integer vector: nodes in each layer.
#' @param layer_shape Currently unused; kept for interface compatibility.
#' @param connection_type Character vector with at most
#'   `length(nodes_per_layer) - 1` entries, each one of:
#'   "straight" (node k of layer i -> node k of layer i + 1; both layers
#'   must have the same size), "full" (every node of layer i -> every node
#'   of layer i + 1) or "convolution" (sink k <- the `stride_length[i]`
#'   consecutive source nodes starting at source k, clipped to the source
#'   layer). Any other value leaves that pair of layers unconnected.
#' @param stride_length Numeric vector; `stride_length[i]` is the window
#'   width for a "convolution" at layer i. Only read for "convolution".
#' @return Invisibly, the `sum(nodes_per_layer)` x `sum(nodes_per_layer)`
#'   0/1 adjacency matrix.
plot_net <- function(nodes_per_layer, layer_shape, connection_type,
                     stride_length) {
  n_gaps <- max(length(nodes_per_layer) - 1, 0)
  if (length(connection_type) > n_gaps) {
    stop(paste("connection_type has", length(connection_type),
               "entries but nodes_per_layer only allows", n_gaps))
  }

  total_nodes <- sum(nodes_per_layer)
  adjacency <- matrix(0, nrow = total_nodes, ncol = total_nodes)
  # layer_start[k] = number of nodes in all layers before layer k.
  layer_start <- c(0, cumsum(nodes_per_layer))

  for (i in seq_along(connection_type)) {
    # seq_len() (not 1:n) so an empty layer yields no nodes.
    source_nodes <- layer_start[i] + seq_len(nodes_per_layer[i])
    sink_nodes <- layer_start[i + 1] + seq_len(nodes_per_layer[i + 1])
    nsource <- length(source_nodes)
    nsink <- length(sink_nodes)

    if (connection_type[i] == "straight") {
      if (nsource != nsink) {
        stop(paste("Can't use straight at layer", i,
                   "because # of nodes doesn't match"))
      }
      # Two-column index matrix: sets adjacency[source k, sink k] only.
      adjacency[cbind(source_nodes, sink_nodes)] <- 1
    } else if (connection_type[i] == "full") {
      adjacency[source_nodes, sink_nodes] <- 1
    } else if (connection_type[i] == "convolution") {
      stride <- stride_length[i]
      if (stride > nodes_per_layer[(i + 1)]) {
        stop(paste("Can't use stride length", stride, "at layer", i,
                   "because not enough nodes."))
      }
      # Source window (k, k + 1, ..., k + stride - 1) feeds sink k.
      src <- rep(source_nodes, each = stride) + (seq_len(stride) - 1)
      snk <- rep(sink_nodes, each = stride)[seq_along(src)]
      # Windows near the end of the layer run past the last source node
      # (into the sink layer), and there may be fewer sinks than windows;
      # drop those pairs instead of adding intra-layer or NA edges.
      keep <- !is.na(snk) & src <= layer_start[i] + nsource
      adjacency[cbind(src[keep], snk[keep])] <- 1
    }
  }

  invisible(adjacency)
}
# Mark mat[x[k], y[k]] with 1 for each position k of x, one pair at a time
# in order, and return the modified matrix.
#
# x, y: row and column indices, paired by position. The number of pairs is
#       taken from x alone, so if y is shorter, the trailing pairs read
#       y[k] as NA.
# mat:  matrix to mark; R's copy-on-modify means the caller's object is
#       left unchanged and the marked copy is returned.
index_assign <- function(x, y, mat) {
  n_pairs <- length(x)
  k <- 0L
  while (k < n_pairs) {
    k <- k + 1L
    mat[x[k], y[k]] <- 1
  }
  mat
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment