Skip to content

Instantly share code, notes, and snippets.

@theogf
Last active April 28, 2020 18:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save theogf/4759d3841565cf1d351056416549d28d to your computer and use it in GitHub Desktop.
Save theogf/4759d3841565cf1d351056416549d28d to your computer and use it in GitHub Desktop.
Animation of normalizing flows
using Distributions
using Bijectors
using Makie, Colors
using Animations
# 2-D base distribution that the flows below transform.
p0 = MvNormal(ones(2))
# Initial point cloud: one column per point (consumed with `eachcol` when plotted).
samples = rand(p0, 1)
# Half-width of the square viewport. It must be positive: the FRect below gets
# origin (-abslim, -abslim) and size 2abslim × 2abslim, and `xrange` must run
# upward from -abslim to +abslim. (A negative value produced a negative-size
# rect and a descending grid.)
abslim = 7
scene = Scene(limits = FRect(-abslim, -abslim, 2abslim, 2abslim))
# Grid on which densities are evaluated for the contour plot.
xrange = range(-abslim, abslim, length = 100)
# Density of `dist` at the 2-D point (x, y). Scalar arguments, so it can be
# broadcast over a grid, e.g. `_pdf.(Ref(dist), xs, ys')`.
function _pdf(dist, x, y)
    point = [x, y]
    return pdf(dist, point)
end

# Density of `p` at (x, y). The second (transform) argument is accepted but ignored.
# NOTE(review): presumably `p` is meant to be the already-transformed distribution.
# An earlier variant evaluated `pdf(p, inv(bij)([x, y]))`. Confirm which is intended.
function _tpdf(p, _, x, y)
    return pdf(p, [x, y])
end
# Density of p0 on the plotting grid: entry [i, j] is the pdf at (xrange[i], xrange[j]).
pdfp0 = _pdf.(Ref(p0), xrange, xrange')
# Animation clock. The recording loop sets `time_node[]` before every frame,
# and the two `lift`s below re-evaluate the current frame generators at that time.
time_node = Node(0.0)
# Frame generators: functions of time returning the scatter points and the density
# grid. Typed `Node{Any}` because later phases assign closures of different types.
# The initial sample generator reads the global `samples` at call time, so points
# appended to it during recording show up on the next frame.
f_samp = Node{Any}(t -> Point2f0.(eachcol(samples)))
f_pdf = Node{Any}(t -> pdfp0)
points = lift((f, t) -> f(t), f_samp, time_node)
valpdf = lift((f, t) -> f(t), f_pdf, time_node)
# Draw into `scene` explicitly instead of relying on the implicit "current scene",
# so the plots are guaranteed to share the axis configured below.
scene = contour!(scene, xrange, xrange, valpdf, fillrange = true, linewidth = 0)
scene = scatter!(scene, points, markersize = 0.2, color = RGBA(1.0, 1.0, 1.0, 0.7))
# Hide axis lines, tick labels and axis names so only the flow is visible.
scene[Axis][:showaxis] = (false, false)
scene[Axis][:ticks][:textsize] = (0.0, 0.0)
scene[Axis][:names][:axisnames] = ("", "")
update!(scene)
"""
    _record_transition!(io, time_node, f_samp, f_pdf,
                        samples_from, samples_to, pdf_from, pdf_to, nframes)

Animate the scatter points from `samples_from` to `samples_to` and the density grid
from `pdf_from` to `pdf_to` over `nframes` frames with sine in/out easing.
Each frame is recorded into `io`.

Both interpolations use the same `nframes`, so points and density always finish
together.
"""
function _record_transition!(io, time_node, f_samp, f_pdf,
                             samples_from, samples_to, pdf_from, pdf_to, nframes)
    # One `Animation` per matrix entry, each running from t = 0 to t = nframes.
    int_samples = Animation.(0.0, samples_from, nframes, samples_to;
                             defaulteasing = sineio())
    int_pdf = Animation.(0.0, pdf_from, nframes, pdf_to; defaulteasing = sineio())
    # Reset the clock before swapping in the new generators, so that the
    # re-evaluation they trigger happens at t = 0 (the starting state).
    time_node[] = 0
    f_samp[] = t -> Point2f0.(eachcol(at.(int_samples, t)))
    f_pdf[] = t -> at.(int_pdf, t)
    for i in 1:nframes
        time_node[] = i
        recordframe!(io)
    end
    return nothing
end

record(scene, joinpath(@__DIR__, "norm_flows.gif"); framerate = 20) do io
    ### Sampling part: add 10 points drawn from p0 on every frame.
    for i in 1:20
        global samples = hcat(samples, rand(p0, 10))
        time_node[] = i
        recordframe!(io)
    end

    ### Movement part: push samples and density through a radial layer.
    N2 = 20
    tr = RadialLayer(-3.0, 2.0, zeros(2))
    p1 = transformed(p0, tr)
    # NOTE(review): the first element of `forward`'s result is used as the
    # transformed points. Confirm against the Bijectors version in use.
    moved_samples = forward(tr, samples)[1]
    moved_pdf = _pdf.(Ref(p1), xrange, xrange')
    _record_transition!(io, time_node, f_samp, f_pdf,
                        samples, moved_samples, pdfp0, moved_pdf, N2)

    ### Sampling again, now from the transformed distribution p1.
    N3 = 20
    # `moved_samples` is rebound inside the loop below. The closure captures the
    # variable, not its value, so each frame shows the newly appended points.
    f_samp[] = t -> Point2f0.(eachcol(moved_samples))
    f_pdf[] = t -> moved_pdf
    for i in 1:N3
        moved_samples = hcat(moved_samples, rand(p1, 10))
        time_node[] = i
        recordframe!(io)
    end

    ### Movement again: stack a planar layer on top of the radial one.
    N4 = 20
    _tr = PlanarLayer([2.0, 2.0], [0.5, 1.0], 4.0)
    tr2 = _tr ∘ tr
    p2 = transformed(p0, tr2)
    # The points already live in p1-space, so only the new planar layer is applied.
    moved_samples_2 = forward(_tr, moved_samples)[1]
    moved_pdf_2 = _pdf.(Ref(p2), xrange, xrange')
    _record_transition!(io, time_node, f_samp, f_pdf,
                        moved_samples, moved_samples_2, moved_pdf, moved_pdf_2, N4)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment