@Lirimy · Last active December 3, 2019 05:41
Show Flux.jl model as a tree structure
# Show Flux.jl model as a tree structure
#
# Copyright (c) 2019 Lirimy
# Released under the MIT license
# https://opensource.org/licenses/MIT
# Requires GraphRecipes v0.5.0
using Flux
using Plots
using GraphRecipes, AbstractTrees
# Treat a `name => object` pair as a tree node whose children are the
# object's properties, themselves wrapped as pairs
function AbstractTrees.children(pair::Pair)
    obj = pair.second
    props = propertynames(obj)
    [prop => getproperty(obj, prop) for prop in props]
end
# Label functions and Dense layers with their printed form;
# everything else with its property name
function AbstractTrees.printnode(io::IO, pair::Pair)
    name, obj = pair
    if obj isa Function || obj isa Dense
        print(io, obj)
    else
        print(io, name)
    end
end
# https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
model = Chain(
    Dense(28^2, 32, relu),
    Dense(32, 10),
    softmax)
# Wrap the model as a Pair so the methods above apply, then print it
tree = :Chain => model
print_tree(tree)
plt = plot(TreePlot(tree),
    nodesize = 0.12,
    nodeshape = :ellipse,
    nodecolor = :lightskyblue1,
    fontsize = 12,
    curves = false,
    linewidth = 3.0,
    linecolor = :darkseagreen3,
    background_color = :lightyellow,
    background_color_outside = :white,
    title = "MNIST multi-layer-perceptron model")
plt |> display
savefig(plt, "modeltree.png")
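The two `AbstractTrees` methods in the script are not specific to Flux: any `name => object` pair can be traversed this way, since `children` controls descent and `printnode` controls labeling. A minimal, self-contained sketch, with the Flux-specific labeling dropped (the `Point` struct is hypothetical, used only for illustration):

```julia
using AbstractTrees

# Same idea as the gist's methods, minus the Flux-specific label logic:
# descend into an object's properties, label each node by its name.
AbstractTrees.children(pair::Pair) =
    [prop => getproperty(pair.second, prop) for prop in propertynames(pair.second)]
AbstractTrees.printnode(io::IO, pair::Pair) = print(io, pair.first)

# Hypothetical struct for demonstration; its fields become leaf nodes
# because a Float64 has no properties of its own
struct Point
    x::Float64
    y::Float64
end

print_tree(:origin => Point(1.0, 2.0))
# origin
# ├─ x
# └─ y
```

Recursion terminates naturally: `propertynames` on a field-less value such as a `Float64` returns an empty tuple, so `children` yields no further nodes.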
Lirimy commented Dec 2, 2019

Output

Chain
└─ layers
   ├─ Dense(784, 32, relu)
   │  ├─ W
   │  ├─ b
   │  └─ relu
   ├─ Dense(32, 10)
   │  ├─ W
   │  ├─ b
   │  └─ identity
   └─ softmax

[image: modeltree.png]
