Skip to content

Instantly share code, notes, and snippets.

@h-spiess
Created November 11, 2020 10:46
Show Gist options
  • Save h-spiess/6b539451d1e871c70a235a125e0b3597 to your computer and use it in GitHub Desktop.
Save h-spiess/6b539451d1e871c70a235a125e0b3597 to your computer and use it in GitHub Desktop.
Pluto Notebook
### A Pluto.jl notebook ###
# v0.12.7
using Markdown
using InteractiveUtils
# This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error).
# Stand-in for Pluto's `@bind` when the notebook is run as a plain script.
# It evaluates the widget expression and assigns the global named by `def` the
# widget's `Base.get` value when such a method applies (otherwise `missing`).
# The widget itself is the value of the macro call.
macro bind(def, element)
    return quote
        local widget = $(esc(element))
        global $(esc(def)) = if Core.applicable(Base.get, widget)
            Base.get(widget)
        else
            missing
        end
        widget
    end
end
# ╔═╡ beeac00a-2277-11eb-18b5-83f8e74dcb9e
begin
    # Plotting, statistics, running means, file globbing, Pluto widgets and Python
    # interop, loaded in the original order.
    using PyCall, Plots, Statistics, RollingFunctions, Glob, PlutoUI
    # PyTorch is reached through PyCall; its `load` reads the checkpoint files below.
    torch = pyimport("torch")
end
# ╔═╡ 207197de-2360-11eb-0837-111b4c20d3b0
# Author's to-do note, rendered as Markdown in the notebook.
# NOTE(review): presumably asks for more `retrain_last_n_blocks` conditions between
# 3 and 6 in `lab` — confirm.
md"""
TODO: more resolution between 3 and 6 blocks
"""
# ╔═╡ 05fad28e-229e-11eb-17ed-d16f9563e62d
# Drop-down choosing which metric to plot; every option is one of the `normals_*`
# names in `metrics`, so `findfirst` in the plotting cells always succeeds.
@bind m_l Select(["normals_loss",
"normals_angle_dist_mean",
"normals_angle_dist_median",
"normals_within_11.5",
"normals_within_22.5",
"normals_within_30"])
# ╔═╡ 45497150-2282-11eb-2e13-0f0f03738cd5
# Window width (1–100) of the `runmean` smoothing applied to the training curves.
@bind rw NumberField(1:100)
# ╔═╡ 8cd83902-2282-11eb-3efa-a5831429d532
# Show the currently selected running-mean window width.
rw
# ╔═╡ 7ffae768-2280-11eb-2836-cb08d333369d
# Choose whether the train or the test metric columns are plotted.
@bind tr_or_test_str Select(["train", "test"])
# ╔═╡ 5ffaf812-2281-11eb-0036-1f6eb275deff
# Offset into the metric columns of `avg_cost`: the 12 train metrics come first,
# the matching test metrics start 12 columns later.
tr_or_test = if tr_or_test_str == "train"
    0
else
    12
end
# ╔═╡ 1a365e9a-227a-11eb-0d39-3124dc0f5ef5
# 1×10 row of conditions: the "orig" baseline, then the values 1–9 that `l2path`
# substitutes into the `retrain_last_n_blocks_<l>` run names.
lab = hcat("orig", 1:9...)
# ╔═╡ d8b1c4a8-229e-11eb-350e-e7ea1794ff69
# Checkbox binding the Bool `finetuning`, which selects the strings
# `finetuning_str` / `finetuning_path` used for titles and checkpoint globs.
md"""
Finetuning: $(@bind finetuning CheckBox())
"""
# ╔═╡ 44267792-229f-11eb-0f41-b1dbc0672caa
begin
    # Title suffix and run-name infix; both are empty unless finetuning is ticked.
    finetuning_str, finetuning_path = if finetuning
        (" (with_finetuning)", "with_finetuning_")
    else
        ("", "")
    end
    # Keep the cell's displayed value identical to before.
    finetuning_path
end
# ╔═╡ 18fcc3ca-227a-11eb-2860-1dffce2995b2
"""
    l2path(l; prefix, suffix)

Return the checkpoint paths of every run of condition `l`, found with `glob` under
`prefix` (an empty vector when nothing matches).

`l == "orig"` selects the baseline `mtan_segnet_without_attention_equal_adam_run_*`
runs. Any other `l` selects the runs named
`..._single_task_2_retrain_last_n_blocks_<l>_<finetuning_path>run_*`. These depend on
the global `finetuning_path`.

`prefix` (log root) and `suffix` (checkpoint file inside a run directory) default to
the original hard-coded locations. Override them to read logs stored elsewhere.
"""
function l2path(l;
                prefix = "/mnt/antares_raid/home/spiess/thesis/src/logs/",
                suffix = "/model_checkpoints/checkpoint.chk")
    if l == "orig"
        return glob("mtan_segnet_without_attention_equal_adam_run_*$(suffix)", prefix)
    else
        return glob("mtan_segnet_without_attention_equal_adam_single_task_2_retrain_last_n_blocks_$(l)_$(finetuning_path)run_*$(suffix)", prefix)
    end
end
# ╔═╡ 20060f0a-227a-11eb-183b-a36e933f46e6
begin
    # Standard error of the mean over the first dimension (the runs of one condition).
    stderror(x) = std(x) / sqrt(size(x, 1))

    # One vector of checkpoint paths per condition in `lab`. Conditions without any
    # checkpoint are dropped here, so later cells must re-derive which ones survived.
    all_paths = filter(c -> !isempty(c), [l2path(l) for l in lab])

    # Per condition, per run: the checkpoint's "avg_cost" entry, loaded onto the CPU.
    # NOTE(review): assumed to arrive via PyCall as a (rows × metric columns) matrix — confirm.
    avg_cost = [[torch.load(p, map_location="cpu")["avg_cost"] for p in paths] for paths in all_paths]

    # Pad every run with `missing` rows up to the longest run so runs can be averaged.
    max_length = maximum(maximum.([size.(avg_c, 1) for avg_c in avg_cost]))
    avg_cost = [[mapslices(x -> vcat(x, repeat([missing], max_length - size(avg_cc, 1))), avg_cc, dims=1) for avg_cc in avg_c] for avg_c in avg_cost]

    # Mean and standard error over the runs of each condition.
    avg_cost = [[mean(avg_c), stderror(avg_c)] for avg_c in avg_cost]

    # 4-d array indexed (row — presumably epoch, metric column, statistic, condition).
    # Statistic 1 is the mean and 2 is the standard error.
    avg_cost = cat([cat(avg_c..., dims=3) for avg_c in avg_cost]..., dims=4)

    # Keep at most the first 100 rows. Clamping avoids a BoundsError when every run is
    # shorter than 100 rows.
    avg_cost = avg_cost[1:min(100, size(avg_cost, 1)), :, :, :]

    # Metric column names: the first 12 columns are train, the rest is test.
    metrics = [
        "segmentation_loss",
        "segmentation_miou",
        "segmentation_pix_acc",
        "depth_loss",
        "depth_abs_err",
        "depth_rel_err",
        "normals_loss",
        "normals_angle_dist_mean",
        "normals_angle_dist_median",
        "normals_within_11.5",
        "normals_within_22.5",
        "normals_within_30"
    ]
end
# ╔═╡ 5aea2982-2291-11eb-3c0e-875bf9b93fb2
begin
    # Direction of "best": accuracy-style metrics are maximised, losses and errors
    # are minimised.
    extremafn = m_l ∈ ["normals_within_11.5",
                       "normals_within_22.5",
                       "normals_within_30",
                       "segmentation_miou",
                       "segmentation_pix_acc"] ? maximum : minimum
    let col = tr_or_test + findfirst(==(m_l), metrics)
        # Best mean value along dimension 1, ignoring the `missing` padding added for
        # shorter runs (otherwise one short run turns the whole condition `missing`).
        best = v -> begin
            vals = collect(skipmissing(v))
            isempty(vals) ? missing : extremafn(vals)
        end
        # Retrained conditions that have checkpoints. In order, they correspond to
        # columns 2:end of avg_cost (column 1 is assumed to be the "orig" baseline).
        present = [!isempty(l2path(l)) for l in lab[2:end]]
        plot(lab[2:end][present],
             [best(v) for v in eachcol(avg_cost[:, col, 1, 2:end])],
             title="Best $(m_l) in $(tr_or_test_str)$(finetuning_str)", label = "Retrained")
        hline!([best(avg_cost[:, col, 1, 1])], label="Baseline")
    end
end
# ╔═╡ 093d12f6-2280-11eb-0669-5947419b7623
# Training curves of the selected metric, one series per condition, smoothed with a
# running mean of width `rw`. The ribbon is the equally smoothed standard error.
let col = tr_or_test + findfirst(==(m_l), metrics)
    # avg_cost only has columns for conditions whose checkpoints exist, so label only
    # those. Using the full `lab` would shift names onto the wrong curves.
    present = [!isempty(l2path(l)) for l in lab]
    series_labels = permutedims(["$(x)_$(tr_or_test_str)" for (x, ok) in zip(lab, present) if ok])
    smooth = x -> runmean(x, rw)
    plot(mapslices(smooth, avg_cost[:, col, 1, :], dims=1),
         ribbon=mapslices(smooth, avg_cost[:, col, 2, :], dims=1),
         label=series_labels,
         title="$(m_l)$(finetuning_str)")
end
# ╔═╡ Cell order:
# ╟─207197de-2360-11eb-0837-111b4c20d3b0
# ╟─5aea2982-2291-11eb-3c0e-875bf9b93fb2
# ╟─093d12f6-2280-11eb-0669-5947419b7623
# ╟─05fad28e-229e-11eb-17ed-d16f9563e62d
# ╟─45497150-2282-11eb-2e13-0f0f03738cd5
# ╠═8cd83902-2282-11eb-3efa-a5831429d532
# ╟─7ffae768-2280-11eb-2836-cb08d333369d
# ╟─5ffaf812-2281-11eb-0036-1f6eb275deff
# ╟─20060f0a-227a-11eb-183b-a36e933f46e6
# ╠═1a365e9a-227a-11eb-0d39-3124dc0f5ef5
# ╟─d8b1c4a8-229e-11eb-350e-e7ea1794ff69
# ╟─44267792-229f-11eb-0f41-b1dbc0672caa
# ╟─18fcc3ca-227a-11eb-2860-1dffce2995b2
# ╟─beeac00a-2277-11eb-18b5-83f8e74dcb9e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment