Skip to content

Instantly share code, notes, and snippets.

@tbenst
Created June 10, 2021 17:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tbenst/f193cafe7855744e1d50ce56228dbbe2 to your computer and use it in GitHub Desktop.
Save tbenst/f193cafe7855744e1d50ce56228dbbe2 to your computer and use it in GitHub Desktop.
import IterTools, Cairo
using Colors, Compose, Fontconfig, PyCall, StatsBase, Glob,
FileIO, Measures, Format, Unitful
pickle = pyimport("pickle")
##
@__DIR__
## https://github.com/julia-vscode/julia-vscode/issues/2104
# Earlier output location, kept for reference only (it used to be assigned and then
# immediately overwritten by the line below):
# base_dir = joinpath(@__DIR__, "2020-reconstructions/final-outputs")
base_dir = joinpath(@__DIR__, "final-outputs")
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
NUM_IMAGES_PER_MODEL = 10
order = ["originals", "linear_model", "64x64-mlp",
    "64x64-mlp-small", "64x64-resnet-mlp", "64x64-convnet", "perceptual_model"]
# Look each folder up by its exact name. `searchsortedfirst` returns an insertion index
# when the name is absent, silently selecting a neighbouring folder (or running past the
# end) instead of failing.
idxs = map(order) do o
    i = findfirst(==(o), model_names)
    isnothing(i) && error("no folder named \"$o\" in $base_dir")
    i
end
model_names = model_names[idxs]
model_folders = model_folders[idxs]
# load perceptual losses: a pickled mapping from model name to a collection of losses
# whose values are averaged below (presumably one loss per image, per the file name —
# TODO confirm). The file is resolved next to this script, like the other paths here,
# and is closed explicitly so the Python file handle is not leaked.
percept_losses = let f = pybuiltin("open")(joinpath(@__DIR__, "each_image_pl.pickle"), "rb")
    try
        pickle.load(f)
    finally
        f.close()
    end
end
avg_percept_losses = Dict(k => mean(values(v)) for (k, v) in percept_losses)
# (model name => average loss) pairs, lowest loss first
percept_ranking = sort(collect(avg_percept_losses), by=last)
println("perceptual model ranking")
@show percept_ranking
## load MSE
"""
    se(x, y)

Return the summed squared error `Σ |xᵢ - yᵢ|²` between `x` and `y`. The difference is
taken with broadcasting (`x .- y`), so shapes must be broadcast-compatible. `abs2`
avoids allocating a second temporary array and yields a real result for complex input.
"""
se(x, y) = sum(abs2, x .- y)
"""
    calc_MSE(model_name; original="originals", data_folder=joinpath(@__DIR__, "final-outputs"))

Compare every `test*.png` image in `data_folder/model_name` with the same-named image in
`data_folder/original`. Return the mean over images of the per-image summed squared
error (`se`), after converting pixels to `Float64`.

Note: despite the name, squared errors are summed (not averaged) over the pixels of each
image; only the average across images is taken.

Throws an `ArgumentError` when the model folder contains no `test*.png` images. Such a
folder previously produced a silent `NaN` from `0.0 / 0`.
"""
function calc_MSE(model_name; original="originals",
                  data_folder=joinpath(@__DIR__, "final-outputs"))
    model_path = joinpath(data_folder, model_name)
    orig_path = joinpath(data_folder, original)
    img_names = basename.(glob("test*.png", model_path))
    isempty(img_names) &&
        throw(ArgumentError("no test*.png images found in $model_path"))
    cum_se = 0.0
    for name in img_names
        # NOTE(review): Float64.(...) presumably expects grayscale images; colour
        # images would need an explicit channel conversion — confirm.
        x = Float64.(load(joinpath(model_path, name)))
        y = Float64.(load(joinpath(orig_path, name)))
        cum_se += se(x, y)
    end
    return cum_se / length(img_names)
end
# Hard-coded per-model MSE values used in the figure tables below. The commented-out
# recomputation suggests they are a snapshot of calc_MSE output — TODO confirm. To
# recompute:
#   avg_mse = Dict(k => calc_MSE(k) for k in keys(percept_losses))
#   @show sort(collect(avg_mse), by=last)
# Trailing "alt:" comments record alternative values that were commented out in the
# original table.
avg_mse = Dict{String,Float64}(
    # NxN convnets
    "8x8-convnet" => 79.3,
    "16x16-convnet" => 52.8,      # alt: 55.3
    "32x32-convnet" => 38.2,      # alt: 48.3
    "64x64-convnet" => 36.6,      # alt: 38.0
    # other architectures
    "linear_model" => 42.3,
    "64x64-mlp" => 41.3,
    "64x64-mlp-small" => 38.4,
    "64x64-resnet-mlp" => 33.0,
    # scrambling controls (used by the supplementary figure)
    "32x32-convnet-descrambling" => 10.3,
    "32x32-convnet-scrambled-targets" => 72.6,
    "32x32-convnet-scrambled-both" => 88.8,
)
##
"""
    drawMEA(N, everyN)

Draw an `N`×`N` grid of small squares (electrodes) as a Compose context. Electrode
`(i, j)` (1-based) is drawn dark red when both `i - 1` and `j - 1` are multiples of
`everyN`, and dark gray otherwise. Squares are positioned at `(i, j) / (N + 1)` in
context units and have side `1 / (1.5N)`. The gray layer is omitted when every
electrode is red.
"""
function drawMEA(N, everyN)
    is_sampled(i, j) = (i - 1) % everyN == 0 && (j - 1) % everyN == 0
    cells = [(i, j) for i in 1:N for j in 1:N]
    reds = filter(c -> is_sampled(c...), cells)
    grays = filter(c -> !is_sampled(c...), cells)
    sz = 1 / (1.5 * N)
    # one vectorised rectangle form covering all cells of a colour group
    squares(cs) = rectangle(first.(cs) ./ (N + 1), last.(cs) ./ (N + 1), [sz], [sz])
    red_layer = (context(), squares(reds), fill("darkred"))
    isempty(grays) && return compose(context(), red_layer)
    return compose(context(),
        (context(), squares(grays), fill("darkgray")),
        red_layer)
end
"""
    centered_text(the_text, fs=7pt)

Return a Compose context that draws `the_text` at the middle of its cell, aligned
horizontally and vertically on that point, at font size `fs`.
"""
centered_text(the_text, fs=7pt) =
    compose(context(), text(0.5, 0.5, the_text, hcenter, vcenter), fontsize(fs))
"""
    get_images_for_model(model_folder)

Return the raw file bytes (one `Vector{UInt8}` per file, via `read`) of the test images
in `model_folder`, in `readdir` (sorted) order. Only names starting with `test<digit>-`
are kept. The callers embed the bytes with `bitmap("image/png", ...)` and index the
first `NUM_IMAGES_PER_MODEL` entries.
"""
function get_images_for_model(model_folder)
    # The single-digit pattern below can only select test0- … test9-, i.e. 10 images.
    # An explicit check is used instead of @assert, which may be compiled out.
    NUM_IMAGES_PER_MODEL == 10 ||
        error("get_images_for_model supports NUM_IMAGES_PER_MODEL == 10 only, " *
              "got $NUM_IMAGES_PER_MODEL")
    # Anchored at the start, matching the "test*.png" naming used by calc_MSE, so names
    # that merely contain "testN-" are not picked up.
    img_names = filter(n -> occursin(r"^test[0-9]-", n), readdir(model_folder))
    length(img_names) >= NUM_IMAGES_PER_MODEL ||
        error("expected at least $NUM_IMAGES_PER_MODEL test images in $model_folder, " *
              "found $(length(img_names))")
    return read.(joinpath.(model_folder, img_names))
end
## FIGURE 1 MEA SIZE
# Table layout. Columns: 1 = MEA cutout, 2 = channel-sampling pattern, then one column
# per test image, then an MSE column and a perceptual-loss column. Row 1 holds the
# originals; rows 2:nrow hold the reconstructions of the models in `model_names`.
# `ncol` is derived from that layout. The former hard-coded `ncol = 13` left no spare
# column for "Percept", so the MSE column (ncol-1 = 12) overwrote the 10th image.
nrow = 5
im_start_col = 3
img_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
ncol = last(img_cols) + 2  # + MSE column + perceptual-loss column
W = 183mm
H = W/ncol * nrow
model_names = ["$(n)x$(n)-convnet" for n in [8, 16, 32, 64]]
tab = table(nrow, ncol, 1:nrow, 1:ncol)
tab[1,2] = [centered_text("channel\nsampling\n per 8x8")]
# One row per sampling stride within an 8x8 patch (everyN = 8, 4, 2, 1), paired in
# order with `model_names`.
for (i, everyN, mn) in zip(2:nrow, [8, 4, 2, 1], model_names)
    tab[i,2] = [compose(context(), drawMEA(8, everyN))]
    model_images = get_images_for_model(joinpath(@__DIR__, "final-outputs", mn))
    for (idx, j) in enumerate(img_cols)
        tab[i,j] = [compose(context(), bitmap("image/png", model_images[idx], 0, 0, 1, 1))]
    end
end
# original images
model_folder = joinpath(@__DIR__, "final-outputs", "originals")
images = get_images_for_model(model_folder)
for (idx, j) in enumerate(img_cols)
    tab[1,j] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
end
# MSE and perceptual-loss columns
tab[1, ncol-1] = [centered_text("MSE")]
tab[1, ncol] = [centered_text("Percept")]
for (i, mn) in zip(2:nrow, model_names)
    tab[i, ncol-1] = [centered_text("$(round(avg_mse[mn], digits=2))")]
    tab[i, ncol] = [centered_text("$(round(avg_percept_losses[mn], digits=2))")]
end
# add MEA cutout
# Cell (3,1) shows a schematic array: a black circular outline, a gray square, and a
# red box 1/8 the side of that square at the centre. Dashed red lines run from the box
# to the cell's right-hand corners. Cell (4,1) holds the caption.
circ_lw = 3pt
mea_rect = 0.6
zoom_rect = mea_rect/8
zoom_lo = 0.5 - zoom_rect/2
zoom_hi = 0.5 + zoom_rect/2
zoom_box = (context(), rectangle(zoom_lo, zoom_lo, zoom_rect, zoom_rect),
    stroke("red"), fill(nothing))
upper_guide = (context(), line([(1,0), (zoom_hi, zoom_lo)]),
    strokedash([1.2mm, 1.2mm]), stroke("red"))
lower_guide = (context(), line([(1,1), (zoom_hi, zoom_hi)]),
    strokedash([1.2mm, 1.2mm]), stroke("red"))
mea_outline = (context(), circle(0.5cx, 0.5cy, (1cx-circ_lw)/2),
    fill(nothing), stroke("black"), linewidth(circ_lw))
mea_array = (context(), rectangle(0.5 - mea_rect/2, 0.5 - mea_rect/2, mea_rect, mea_rect),
    fill("gray80"))
# child order is kept exactly as before, since it controls drawing order
mea = compose(context(), zoom_box, upper_guide, lower_guide, mea_outline, mea_array)
tab[3,1] = [mea]
mea_caption = compose(context(), text(0.5, 0.5, "64x64\nHD-MEA", hcenter), fontsize(7pt))
tab[4,1] = [mea_caption]
# Render the Figure 1 table to an SVG next to this script and print an Inkscape command
# that rasterises it to PNG.
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figure1_channel_comparison"
svg_path = joinpath(@__DIR__, "$fn.svg")
png_path = joinpath(@__DIR__, "$fn.png")
fig |> SVG(svg_path)
@show H,W
# figure height as a fraction of a 247 mm page height
println("$(round(H/247mm,digits=2)) of a page")
# Rasterise at `dpi` dots per inch (mm -> inch -> pixels). The factor has always been
# 200; an earlier comment claiming 300 dpi did not match the code.
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch",Quantity(x.value,u"mm"))*dpi/1u"inch",digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# Absolute paths, because the SVG is written under @__DIR__, not the working directory.
# A Cmd quotes paths containing spaces and can be passed straight to `run`.
cmd = `inkscape -w $px_w -h $px_h $svg_path --export-filename $png_path`
println("to make PNG: $cmd")
fig
## FIGURE 2 model architecture
base_dir = joinpath(@__DIR__, "final-outputs")
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
NUM_IMAGES_PER_MODEL = 10
order = ["originals", "linear_model", "64x64-mlp",
    "64x64-resnet-mlp", "64x64-convnet"]
# Exact-name lookup. `searchsortedfirst` silently picks a neighbouring folder (or runs
# past the end) when a name is missing.
idxs = map(order) do o
    i = findfirst(==(o), model_names)
    isnothing(i) && error("no folder named \"$o\" in $base_dir")
    i
end
model_names = model_names[idxs]
model_folders = model_folders[idxs]
pretty_names = ["", "linear", "MLP",
    "resMLP", "resUNet"]
# Table layout. Columns: 1 = model label, then one column per test image, then MSE and
# perceptual-loss columns. The former hard-coded `ncol = 12` put the MSE column
# (ncol-1 = 11) on top of the 10th image, so `ncol` is now derived from the layout.
nrow = length(order)
im_start_col = 2
img_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
ncol = last(img_cols) + 2  # + MSE column + perceptual-loss column
W = 183mm
H = W/ncol * nrow
tab = table(nrow, ncol, 1:nrow, 1:ncol)
# model names
for (i, mn) in enumerate(pretty_names)
    tab[i,1] = [centered_text(mn)]
end
# render images (row 1 = originals)
for (i, model_folder) in zip(1:nrow, model_folders)
    images = get_images_for_model(model_folder)
    for (idx, j) in enumerate(img_cols)
        tab[i,j] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
    end
end
# MSE and perceptual-loss columns (none for the originals row)
tab[1, ncol-1] = [centered_text("MSE")]
tab[1, ncol] = [centered_text("Percept")]
for (i, mn) in zip(2:nrow, model_names[2:end])
    tab[i, ncol-1] = [centered_text("$(round(avg_mse[mn], digits=2))")]
    tab[i, ncol] = [centered_text("$(round(avg_percept_losses[mn], digits=2))")]
end
# Render the Figure 2 table to an SVG next to this script and print an Inkscape command
# that rasterises it to PNG.
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figure2_model_architecture"
svg_path = joinpath(@__DIR__, "$fn.svg")
png_path = joinpath(@__DIR__, "$fn.png")
fig |> SVG(svg_path)
@show H,W
# figure height as a fraction of a 247 mm page height
println("$(round(H/247mm,digits=2)) of a page")
# Rasterise at `dpi` dots per inch (mm -> inch -> pixels). The factor has always been
# 200; an earlier comment claiming 300 dpi did not match the code.
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch",Quantity(x.value,u"mm"))*dpi/1u"inch",digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# Absolute paths, because the SVG is written under @__DIR__, not the working directory.
# A Cmd quotes paths containing spaces and can be passed straight to `run`.
cmd = `inkscape -w $px_w -h $px_h $svg_path --export-filename $png_path`
println("to make PNG: $cmd")
fig
## supp figure scramble
# NOTE(review): machine-specific absolute path; switch to the commented
# @__DIR__-relative path when running on another machine.
# base_dir = joinpath(@__DIR__, "final-outputs")
base_dir = "/mnt/dropbox/Dropbox/Science/manuscripts/2019_acuity_paper/2020-reconstructions/final-outputs"
model_folders = filter(isdir, joinpath.(base_dir, readdir(base_dir)))
model_names = basename.(model_folders)
NUM_IMAGES_PER_MODEL = 10
order = ["originals", "scrambled-targets", "32x32-convnet-descrambling",
    "32x32-convnet-scrambled-targets", "32x32-convnet-scrambled-both"]
# Exact-name lookup. `searchsortedfirst` silently picks a neighbouring folder (or runs
# past the end) when a name is missing.
idxs = map(order) do o
    i = findfirst(==(o), model_names)
    isnothing(i) && error("no folder named \"$o\" in $base_dir")
    i
end
model_names = model_names[idxs]
model_folders = model_folders[idxs]
pretty_names = ["", "scrambled", "de-\nscrambling",
    "targets\nscrambled", "both\nscrambled"]
# Table layout. Columns: 1 = row label, then one column per test image, then one MSE
# column.
nrow = length(order)
im_start_col = 2
img_cols = im_start_col:(im_start_col + NUM_IMAGES_PER_MODEL - 1)
ncol = last(img_cols) + 1  # + MSE column
W = 183mm
H = W/ncol * nrow
tab = table(nrow, ncol, 1:nrow, 1:ncol)
# row labels
for (i, mn) in enumerate(pretty_names)
    tab[i,1] = [centered_text(mn)]
end
# render images
for (i, model_folder) in zip(1:nrow, model_folders)
    images = get_images_for_model(model_folder)
    for (idx, j) in enumerate(img_cols)
        tab[i,j] = [compose(context(), bitmap("image/png", images[idx], 0, 0, 1, 1))]
    end
end
# MSE column: rows 1-2 ("originals", "scrambled-targets") have no avg_mse entry
tab[1, ncol] = [centered_text("MSE")]
for (i, mn) in zip(3:nrow, model_names[3:end])
    tab[i, ncol] = [centered_text("$(round(avg_mse[mn], digits=2))")]
end
# Render the supplementary scrambling table to an SVG next to this script and print an
# Inkscape command that rasterises it to PNG.
set_default_graphic_size(W, H)
fig = compose(context(), tab)
fn = "figureSup_scrambling"
svg_path = joinpath(@__DIR__, "$fn.svg")
png_path = joinpath(@__DIR__, "$fn.png")
fig |> SVG(svg_path)
@show H,W
# figure height as a fraction of a 247 mm page height
println("$(round(H/247mm,digits=2)) of a page")
# Rasterise at `dpi` dots per inch (mm -> inch -> pixels). The factor has always been
# 200; an earlier comment claiming 300 dpi did not match the code.
dpi = 200
mm2px = x -> Int(round(uconvert(u"inch",Quantity(x.value,u"mm"))*dpi/1u"inch",digits=0))
px_w = mm2px(W)
px_h = mm2px(H)
# Absolute paths, because the SVG is written under @__DIR__, not the working directory.
# A Cmd quotes paths containing spaces and can be passed straight to `run`.
cmd = `inkscape -w $px_w -h $px_h $svg_path --export-filename $png_path`
println("to make PNG: $cmd")
fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment