Skip to content

Instantly share code, notes, and snippets.

@jtleek
Created April 21, 2017 01:12
Show Gist options
  • Save jtleek/9257d435c25f6ab2bd536735f13d38fd to your computer and use it in GitHub Desktop.
Save jtleek/9257d435c25f6ab2bd536735f13d38fd to your computer and use it in GitHub Desktop.
Plot net working
# Required packages (uncomment to load; RSkittleBrewer is only used by the example):
#library(ggnet)
#library(RSkittleBrewer)
#library(network)
# Draw a layered network diagram (e.g. a neural-network sketch) with ggnet2.
#
# Nodes are placed in columns, one column per layer, and directed edges are
# added between consecutive layers according to `connection_type`.
# Requires the `network` package (network(), `%v%<-`) and the `ggnet`
# package (ggnet2()) to be attached.
#
# Arguments:
#   nodes_per_layer  Non-negative integer vector: number of nodes per layer.
#   layer_shape      One point shape per layer; each node gets its layer's.
#   connection_type  Character vector; element i describes the edges from
#                    layer i to layer i + 1:
#                      "s"  straight: k-th source -> k-th sink
#                           (both layers must have the same size)
#                      "f"  fully connected: every source -> every sink
#                      "c"  convolution-like: sink j receives edges from the
#                           stride_length[i] consecutive source nodes
#                           starting at source j
#                    May be shorter than length(nodes_per_layer) - 1, in
#                    which case the trailing layers are left unconnected.
#   stride_length    Per-transition window width, used only by "c" entries.
#                    Despite the name, windows always advance by one node.
#   layer_color      One color per layer; each node gets its layer's color.
#   node_values      One value per node (in layer order), mapped to alpha.
#   node_labels      One label per node (in layer order).
#
# Returns the plot object produced by ggnet2().
plot_net <- function(nodes_per_layer,
                     layer_shape,
                     connection_type,
                     stride_length,
                     layer_color,
                     node_values,
                     node_labels) {
  adjacency <- .net_adjacency(nodes_per_layer, connection_type, stride_length)

  net <- network(adjacency, directed = TRUE)
  # Layer number of every node, in node-id order.
  node_index <- rep(seq_along(nodes_per_layer), times = nodes_per_layer)
  net %v% "shape" <- layer_shape[node_index]
  net %v% "x" <- node_index
  # Spread each layer's nodes evenly over (0, 1), first node at the top.
  # lapply (not sapply) so equal-sized layers don't collapse into a matrix.
  net %v% "y" <- unlist(lapply(nodes_per_layer, function(n) {
    rev(seq_len(n)) / (n + 1)
  }))
  net %v% "alpha" <- node_values
  net %v% "color" <- layer_color[node_index]
  ggnet2(net,
         mode = c("x", "y"),
         shape = "shape",
         arrow.size = 12,
         arrow.gap = 0.25,
         alpha = "alpha",
         label = node_labels,
         color = "color")
}

# Build the sum(nodes_per_layer) x sum(nodes_per_layer) 0/1 adjacency matrix
# used by plot_net(); entry [a, b] == 1 means a directed edge from node a to
# node b. Nodes are numbered consecutively, layer by layer.
.net_adjacency <- function(nodes_per_layer, connection_type, stride_length) {
  n_layers <- length(nodes_per_layer)
  if (length(connection_type) > max(n_layers - 1, 0)) {
    stop("connection_type has more entries than there are pairs of ",
         "consecutive layers", call. = FALSE)
  }
  total_nodes <- sum(nodes_per_layer)
  adjacency <- matrix(0, nrow = total_nodes, ncol = total_nodes)
  # Node ids of layer k are offsets[k] + seq_len(nodes_per_layer[k]).
  offsets <- cumsum(c(0, nodes_per_layer))

  for (i in seq_along(connection_type)) {
    source_nodes <- offsets[i] + seq_len(nodes_per_layer[i])
    sink_nodes <- offsets[i + 1] + seq_len(nodes_per_layer[i + 1])
    nsource <- length(source_nodes)
    nsink <- length(sink_nodes)
    type <- connection_type[i]

    if (type == "s") {
      if (nsource != nsink) {
        stop(paste("Can't use straight at layer", i,
                   "because # of nodes doesn't match"), call. = FALSE)
      }
      adjacency <- index_assign(source_nodes, sink_nodes, adjacency)
    } else if (type == "f") {
      adjacency <- index_assign(rep(source_nodes, each = nsink),
                                rep(sink_nodes, nsource), adjacency)
    } else if (type == "c") {
      width <- stride_length[i]
      if (is.na(width) || width < 1) {
        stop(paste("Invalid stride_length at layer", i), call. = FALSE)
      }
      # The last sink's window ends at source nsink + width - 1; beyond the
      # source layer those indices would be NA and the edges silently lost.
      if (nsink > 0 && nsink + width - 1 > nsource) {
        stop(paste("Can't use convolution at layer", i,
                   "because stride_length is too large for the # of nodes"),
             call. = FALSE)
      }
      # Sink j (1-based) reads source positions j, j + 1, ..., j + width - 1.
      index <- seq_len(width) + rep(seq_len(nsink) - 1, each = width)
      adjacency <- index_assign(source_nodes[index],
                                rep(sink_nodes, each = width), adjacency)
    } else {
      stop(paste0("Unknown connection_type '", type, "' at layer ", i,
                  " (expected \"s\", \"f\" or \"c\")"), call. = FALSE)
    }
  }
  adjacency
}
# Set mat[x[k], y[k]] to 1 for every index pair k and return the result.
#
# Arguments:
#   x    Row indices.
#   y    Column indices; must be the same length as x.
#   mat  Matrix to update. R's copy-on-modify semantics leave the caller's
#        matrix untouched, so use the return value.
#
# Pairs where x or y is NA are skipped, matching the element-wise
# `mat[NA, j] <- 1`, which selects nothing.
index_assign <- function(x, y, mat) {
  if (length(x) != length(y)) {
    stop("x and y must have the same length", call. = FALSE)
  }
  keep <- !is.na(x) & !is.na(y)
  if (!any(keep)) {
    return(mat)
  }
  # A two-column index matrix addresses individual (row, col) cells, so all
  # pairs are set in a single vectorized assignment instead of a loop.
  mat[cbind(x[keep], y[keep])] <- 1
  mat
}
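# Example: uncomment and run the code below to draw a sample network
# (5-5-3-3-1 nodes with alternating straight and convolution connections).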
#
# trop = RSkittleBrewer("tropical")
# nodes_per_layer = c(5,5,3,3,1)
# layer_shape = c(19,15,19,15,19)
# connection_type=c("s","c","s","c")
# stride_length=c(3,3,3,3)
# layer_color=trop[c(2,3,2,3,2)]
# node_values = runif(sum(nodes_per_layer))
# node_labels = round(node_values,1)
#
#
#
# plot_net(nodes_per_layer,
# layer_shape,
# connection_type,
# stride_length,
# layer_color,
# node_values,
# node_labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment