Created
November 11, 2020 10:46
-
-
Save h-spiess/6b539451d1e871c70a235a125e0b3597 to your computer and use it in GitHub Desktop.
Pluto Notebook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### A Pluto.jl notebook ### | |
# v0.12.7 | |
using Markdown | |
using InteractiveUtils | |
# This Pluto notebook uses @bind for interactivity. Outside of Pluto this mock
# version of @bind is used instead, so that bound variables get a value rather
# than raising an error.
#
# `@bind name element` evaluates `element` and assigns the global `name`:
# - if a one-argument `Base.get(element)` method is applicable, `name` gets its result;
# - otherwise `name` is set to `missing`.
# The macro expression itself evaluates to `element`.
macro bind(def, element)
    widget_expr = esc(element)
    target = esc(def)
    return quote
        local widget = $widget_expr
        global $target = if Core.applicable(Base.get, widget)
            Base.get(widget)
        else
            missing
        end
        widget
    end
end
# ╔═╡ beeac00a-2277-11eb-18b5-83f8e74dcb9e | |
begin | |
using PyCall | |
using Plots | |
using Statistics | |
using RollingFunctions | |
using Glob | |
using PlutoUI | |
torch = pyimport("torch") | |
end | |
# ╔═╡ 207197de-2360-11eb-0837-111b4c20d3b0
md"""
TODO: more resolution between 3 and 6 blocks
"""

# ╔═╡ 05fad28e-229e-11eb-17ed-d16f9563e62d
@bind m_l Select([
    "normals_loss",
    "normals_angle_dist_mean",
    "normals_angle_dist_median",
    "normals_within_11.5",
    "normals_within_22.5",
    "normals_within_30",
])

# ╔═╡ 45497150-2282-11eb-2e13-0f0f03738cd5
@bind rw NumberField(1:100)

# ╔═╡ 8cd83902-2282-11eb-3efa-a5831429d532
rw

# ╔═╡ 7ffae768-2280-11eb-2836-cb08d333369d
@bind tr_or_test_str Select(["train", "test"])

# ╔═╡ 5ffaf812-2281-11eb-0036-1f6eb275deff
# Offset into the metric axis of `avg_cost`: the first 12 metric columns are the
# train metrics and the following ones are the test metrics, so test lookups are
# shifted by 12.
tr_or_test = if tr_or_test_str == "train"
    0
else
    12
end
# ╔═╡ 1a365e9a-227a-11eb-0d39-3124dc0f5ef5
# 1×10 row of configuration labels: the "orig" baseline followed by 1…9, the value
# substituted for "last_n_blocks" in `l2path`. Same call as the literal
# ["orig" 1 2 3 4 5 6 7 8 9].
lab = hcat("orig", 1:9...)

# ╔═╡ d8b1c4a8-229e-11eb-350e-e7ea1794ff69
md"""
Finetuning: $(@bind finetuning CheckBox())
"""

# ╔═╡ 44267792-229f-11eb-0f41-b1dbc0672caa
# Title suffix and log-directory fragment for the finetuning checkbox state.
# The cell evaluates to `finetuning_path`.
if finetuning
    finetuning_str = " (with_finetuning)"
    finetuning_path = "with_finetuning_"
else
    finetuning_str = ""
    finetuning_path = ""
end
# ╔═╡ 18fcc3ca-227a-11eb-2860-1dffce2995b2
"""
    l2path(l)

Glob the checkpoint files for configuration label `l` under the logs directory.

- `l == "orig"` selects the baseline runs
  (`mtan_segnet_without_attention_equal_adam_run_*`).
- Any other label selects the
  `..._retrain_last_n_blocks_<l>_<finetuning_path>run_*` runs.

Reads the global `finetuning_path` (set by the finetuning checkbox cell).
Returns whatever `glob` returns for the pattern; callers test it with `isempty`.
"""
function l2path(l)
    logs_dir = "/mnt/antares_raid/home/spiess/thesis/src/logs/"
    ckpt_suffix = "/model_checkpoints/checkpoint.chk"
    pattern = if l == "orig"
        "mtan_segnet_without_attention_equal_adam_run_*$(ckpt_suffix)"
    else
        "mtan_segnet_without_attention_equal_adam_single_task_2_retrain_last_n_blocks_$(l)_$(finetuning_path)run_*$(ckpt_suffix)"
    end
    return glob(pattern, logs_dir)
end
# ╔═╡ 20060f0a-227a-11eb-183b-a36e933f46e6
begin
    # Standard error of the mean over runs: `x` holds one entry per run along its
    # first dimension.
    stderror(x) = std(x) / sqrt(size(x, 1))

    # Checkpoint paths grouped per configuration in `lab`. Configurations with no
    # checkpoint on disk are dropped, so the last axis of `avg_cost` below only
    # contains configurations that actually have runs (in `lab` order).
    all_paths = filter(!isempty, [l2path(l) for l in lab])

    # avg_cost[config][run] = the "avg_cost" entry stored in that run's checkpoint.
    # NOTE(review): rows are treated as epochs below — confirm against the
    # training script that wrote the checkpoints.
    avg_cost = [[torch.load("$(p)", map_location="cpu")["avg_cost"] for p in paths]
                for paths in all_paths]

    # Runs can differ in length: pad each run's columns with `missing` up to the
    # longest run so all runs can be averaged element-wise.
    max_length = maximum(maximum(size.(avg_c, 1)) for avg_c in avg_cost)
    avg_cost = [[mapslices(x -> vcat(x, repeat([missing], max_length - size(avg_cc, 1))),
                           avg_cc, dims=1)
                 for avg_cc in avg_c]
                for avg_c in avg_cost]

    # Per configuration: element-wise [mean, standard error] across its runs.
    avg_cost = [[mean(avg_c), stderror(avg_c)] for avg_c in avg_cost]

    # Stack into a 4-d array indexed (row, metric column, 1=mean/2=stderr, config).
    avg_cost = cat([cat(avg_c..., dims=3) for avg_c in avg_cost]..., dims=4)

    # Keep at most the first 100 rows; clamp to the available length so that
    # shorter histories do not raise a BoundsError.
    avg_cost = avg_cost[1:min(100, size(avg_cost, 1)), :, :, :]

    # Names of the metric columns: the first 12 are train, the rest are test
    # (the same 12 metrics, offset by `tr_or_test`).
    metrics = [
        "segmentation_loss",
        "segmentation_miou",
        "segmentation_pix_acc",
        "depth_loss",
        "depth_abs_err",
        "depth_rel_err",
        "normals_loss",
        "normals_angle_dist_mean",
        "normals_angle_dist_median",
        "normals_within_11.5",
        "normals_within_22.5",
        "normals_within_30"
    ]
end
# ╔═╡ 5aea2982-2291-11eb-3c0e-875bf9b93fb2
begin
    # Metrics where a larger value is better are reduced with `maximum`; every
    # other metric (losses, errors, angle distances) with `minimum`.
    extremafn = m_l ∈ ["normals_within_11.5",
                       "normals_within_22.5",
                       "normals_within_30",
                       "segmentation_miou",
                       "segmentation_pix_acc"] ? maximum : minimum
    let col = tr_or_test + findfirst(==(m_l), metrics)
        retrained = lab[2:end]
        # avg_cost only holds configurations that have checkpoints, so the x-axis
        # labels are filtered the same way. Config slot 1 is used as the baseline.
        # NOTE(review): that assumes the "orig" runs exist.
        has_runs = [!isempty(l2path(l)) for l in retrained]
        # Padding rows are `missing`; skip them so a shorter configuration does not
        # turn its best value into `missing`.
        best = [extremafn(skipmissing(c)) for c in eachcol(avg_cost[:, col, 1, 2:end])]
        plot(retrained[has_runs], best,
             title="Best $(m_l) in $(tr_or_test_str)$(finetuning_str)", label = "Retrained")
        hline!([extremafn(skipmissing(avg_cost[:, col, 1, 1]))], label="Baseline")
    end
end
# ╔═╡ 093d12f6-2280-11eb-0669-5947419b7623
let col = tr_or_test + findfirst(==(m_l), metrics)
    # One series per configuration that has checkpoints, matching the last axis of
    # avg_cost (built by filtering `lab` with the same emptiness test). Labelling
    # with all of `lab` would shift labels onto the wrong series whenever a
    # configuration has no runs.
    present = [l for l in lab if !isempty(l2path(l))]
    series_labels = permutedims(["$(l)_$(tr_or_test_str)" for l in present])
    # Rolling mean with the user-selected window `rw`, applied per column.
    smooth(x) = runmean(x, rw)
    plot(mapslices(smooth, avg_cost[:, col, 1, :], dims=1),
         ribbon=mapslices(smooth, avg_cost[:, col, 2, :], dims=1),
         lab=series_labels,
         title="$(m_l)$(finetuning_str)")
end
# ╔═╡ Cell order: | |
# ╟─207197de-2360-11eb-0837-111b4c20d3b0 | |
# ╟─5aea2982-2291-11eb-3c0e-875bf9b93fb2 | |
# ╟─093d12f6-2280-11eb-0669-5947419b7623 | |
# ╟─05fad28e-229e-11eb-17ed-d16f9563e62d | |
# ╟─45497150-2282-11eb-2e13-0f0f03738cd5 | |
# ╠═8cd83902-2282-11eb-3efa-a5831429d532 | |
# ╟─7ffae768-2280-11eb-2836-cb08d333369d | |
# ╟─5ffaf812-2281-11eb-0036-1f6eb275deff | |
# ╟─20060f0a-227a-11eb-183b-a36e933f46e6 | |
# ╠═1a365e9a-227a-11eb-0d39-3124dc0f5ef5 | |
# ╟─d8b1c4a8-229e-11eb-350e-e7ea1794ff69 | |
# ╟─44267792-229f-11eb-0f41-b1dbc0672caa | |
# ╟─18fcc3ca-227a-11eb-2860-1dffce2995b2 | |
# ╟─beeac00a-2277-11eb-18b5-83f8e74dcb9e |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment