Skip to content

Instantly share code, notes, and snippets.

@kumeS
Last active December 9, 2020 13:03
Show Gist options
  • Save kumeS/41fed511efb45bd55d468d4968b0f157 to your computer and use it in GitHub Desktop.
Save kumeS/41fed511efb45bd55d468d4968b0f157 to your computer and use it in GitHub Desktop.
## This code was modified from the original code of "andrie/deepviz" (https://github.com/andrie/deepviz)
## The colors of the nodes were changed in this version.
## Install (when missing) and attach the packages this script relies on.
## requireNamespace() only probes whether a package can be loaded (no attach,
## no warning on failure); library() then attaches it. After the loop,
## `pkg.name` holds the last package name, as the original stanzas left it.
## NOTE(review): other packages reached via `pkg::` in this script (e.g. glue,
## igraph) are not installed here; presumably they arrive as DiagrammeR
## dependencies -- confirm.
for (pkg.name in c("keras", "DiagrammeR", "assertthat", "purrr")) {
  if (!requireNamespace(pkg.name, quietly = TRUE)) {
    install.packages(pkg.name)
  }
  library(pkg.name, character.only = TRUE)
}
##############################################
## Entry point for plotting a Keras model.
## Dispatches (S3) on the class of `model` to a `plot_model.<class>` method,
## e.g. plot_model.keras.engine.training.Model defined in this file. Any extra
## arguments in `...` are forwarded to the selected method.
plot_model_modi <- function(model, ...) {
  UseMethod(generic = "plot_model", object = model)
}
## Names used only through non-standard evaluation (purrr lambdas, dplyr
## column references); declaring them keeps R CMD check from flagging them as
## undefined globals.
utils::globalVariables(
  c(".", "V1", "V2", "x")
)
## Build the DiagrammeR node data frame for a Keras model: one node per entry
## in `x$get_config()$layers`.
##
## x: a Keras model (checked with is.keras_model()).
## Returns the result of create_node_df() with columns `name`, `type` (layer
## class name), `label` ("name\ntype\nactivation"), `shape` ("rectangle") and
## `activation` ("" when the layer config has no activation).
model_nodes <- function(x) {
  assert_that(is.keras_model(x))
  layers <- x$get_config()$layers

  # In sequential configs this code reads each layer's name from
  # `config$name`; in functional configs from the top-level `name` field.
  layer_names <- if (is.keras_model_sequential(x)) {
    map_chr(layers, function(layer) purrr::pluck(layer, "config", "name"))
  } else {
    map_chr(layers, "name")
  }

  layer_types <- map_chr(layers, "class_name")
  layer_acts <- map_chr(layers, function(layer) {
    purrr::pluck(layer, "config", "activation") %||% ""
  })

  create_node_df(
    n = length(layers),
    name = layer_names,
    type = layer_types,
    label = glue::glue("{layer_names}\n{layer_types}\n{layer_acts}"),
    shape = "rectangle",
    activation = layer_acts
  )
}
## Chain the nodes of a sequential model: one edge from each node id to the
## next one, in row order of `ndf`.
##
## ndf: node data frame with an `id` column (as built by model_nodes()).
## NOTE: stats::embed() stops with "wrong embedding dimension" when `ndf` has
## fewer than 2 rows, so single-node input is not supported here.
model_edges_sequential <- function(ndf) {
  assert_that(is.data.frame(ndf))
  # Each row of `windows` is c(id[t], id[t - 1]) for consecutive nodes.
  windows <- embed(ndf$id, dimension = 2)
  prev_ids <- windows[, 2]
  next_ids <- windows[, 1]
  create_edge_df(from = prev_ids, to = next_ids)
}
## Collect (from, to) layer-name pairs for a functional Keras model.
##
## For every layer in `model$get_config()$layers`, the first inbound-node
## group is read and the first element of each entry is taken as the source
## layer name. Layers without inbound nodes contribute an NA row, which is
## dropped by na.omit(). Returns a data frame with character columns `from`
## and `to`.
inbound_nodes <- function(model) {
  assert_that(is.keras_model_network(model))
  layer_configs <- model$get_config()$layers

  sources <- map(layer_configs, function(layer) {
    if (length(layer$inbound_nodes) == 0) {
      return(NA)
    }
    map_chr(layer$inbound_nodes[[1]], c(1, 1))
  })
  names(sources) <- map(layer_configs, "name")

  edges <- imap_dfr(sources, function(src, dst) {
    data.frame(to = dst, from = src, stringsAsFactors = FALSE)
  })
  na.omit(edges)[, c("from", "to")]
}
# Build the edge data frame for a functional (non-sequential) Keras model.
#
# model: a functional model (checked with is.keras_model_network()).
# ndf:   a node data frame with `id` and `name` columns, as built by
#        model_nodes().
# The layer names returned by inbound_nodes() are translated into node ids via
# `ndf$name`; a name with no match in `ndf$name` becomes NA.
model_edges_network <- function(model, ndf) {
  assert_that(is.keras_model_network(model))
  assert_that(is.data.frame(ndf))
  name_to_id <- function(layer_names) ndf$id[match(layer_names, ndf$name)]
  edges <- inbound_nodes(model)
  edges[["from"]] <- name_to_id(edges[["from"]])
  edges[["to"]] <- name_to_id(edges[["to"]])
  edges
}
## Class predicates for the Keras model objects handled in this file.
## Each one returns a single TRUE/FALSE based on the S3 class vector.

## TRUE when `x` inherits from "keras.engine.training.Model".
is.keras_model <- function(x) {
  inherits(x, what = "keras.engine.training.Model")
}

## TRUE for Keras models that are also "keras.engine.sequential.Sequential".
is.keras_model_sequential <- function(x) {
  if (!is.keras_model(x)) {
    return(FALSE)
  }
  inherits(x, what = "keras.engine.sequential.Sequential")
}

## TRUE for Keras models that are not sequential (functional graphs).
is.keras_model_network <- function(x) {
  if (!is.keras_model(x)) {
    return(FALSE)
  }
  !is.keras_model_sequential(x)
}
## Render a Keras model as a DiagrammeR graph with colour-coded layer nodes.
##
## model:  a Keras model; sequential models are chained node-to-node, other
##         models are wired from each layer's inbound nodes.
## width:  horizontal scale factor applied to the Sugiyama x coordinates.
## height: vertical scale factor applied to the Sugiyama y coordinates.
## ...:    accepted for S3 compatibility; not used.
## Returns the value of DiagrammeR::render_graph().
plot_model.keras.engine.training.Model <- function(model, width = 4.5, height = 1, ...) {
  nodes_df <- model_nodes(model)
  edges_df <- if (is.keras_model_sequential(model)) {
    model_edges_sequential(nodes_df)
  } else {
    model_edges_network(model, nodes_df)
  }

  graph <- DiagrammeR::create_graph(nodes_df, edges_df)
  graph <- DiagrammeR::set_edge_attrs(graph, "arrowhead", "vee")
  graph <- DiagrammeR::set_edge_attrs(graph, "arrowsize", 1)
  graph <- DiagrammeR::set_edge_attrs(graph, "color", "grey30")
  graph <- DiagrammeR::set_node_attrs(graph, "fixedsize", FALSE)
  # NOTE(review): Graphviz defines `nodesep` as a graph-level attribute, so
  # setting it per node is probably ignored by the renderer -- confirm.
  graph <- DiagrammeR::set_node_attrs(graph, "nodesep", 2)
  graph <- DiagrammeR::set_node_attrs(graph, "fontcolor", "black")
  graph <- DiagrammeR::set_node_attrs(graph, "fontsize", 15)

  # Border ("color") and fill ("fillcolor") per Keras layer class name.
  # Colour names must be valid Graphviz/X11 names (no trailing ';').
  # Layer types not listed here get no explicit colour from this function.
  layer_palette <- list(
    Conv2D               = c(color = "blue",       fill = "skyblue"),
    Conv3D               = c(color = "blue",       fill = "skyblue"),
    Activation           = c(color = "red",        fill = "pink"),
    SpatialDropout2D     = c(color = "green",      fill = "aquamarine"),
    SpatialDropout3D     = c(color = "green",      fill = "aquamarine"),
    InputLayer           = c(color = "black",      fill = "azure"),
    MaxPooling2D         = c(color = "coral",      fill = "cornsilk"),
    MaxPooling3D         = c(color = "coral",      fill = "cornsilk"),
    Conv3DTranspose      = c(color = "brown1",     fill = "brown1"),
    Conv2DTranspose      = c(color = "brown1",     fill = "brown1"),
    UpSampling3D         = c(color = "darkorange", fill = "darkorange"),
    UpSampling2D         = c(color = "darkorange", fill = "darkorange"),
    # Older TF builds report "BatchNormalizationV1"; tf.keras 2.x reports
    # "BatchNormalization". Both get the same styling.
    BatchNormalizationV1 = c(color = "gold",       fill = "beige"),
    BatchNormalization   = c(color = "gold",       fill = "beige"),
    Concatenate          = c(color = "cyan",       fill = "cyan")
  )
  for (layer_type in names(layer_palette)) {
    ids <- nodes_df$id[nodes_df$type == layer_type]
    colours <- layer_palette[[layer_type]]
    graph <- DiagrammeR::set_node_attrs(
      graph, "color", colours[["color"]], nodes = ids
    )
    graph <- DiagrammeR::set_node_attrs(
      graph, "fillcolor", colours[["fill"]], nodes = ids
    )
  }

  # Layered (Sugiyama) layout from igraph. `$layout` holds one row per original
  # vertex, in graph order. matrix(..., ncol = 2) keeps a 2-column shape even
  # if igraph drops dimensions for a single vertex.
  sugiyama <- igraph::layout_with_sugiyama(DiagrammeR::to_igraph(graph))
  xy <- matrix(sugiyama$layout, ncol = 2)
  graph$nodes_df$x <- width * xy[, 1]
  graph$nodes_df$y <- height * xy[, 2]

  DiagrammeR::render_graph(graph, layout = "dot")
}
####### ####### ####### ####### ####### ####### ####### #######
####### ####### ####### ####### ####### ####### ####### #######
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment