Skip to content

Instantly share code, notes, and snippets.

@sharanry
Last active August 26, 2019 13:52
Show Gist options
  • Save sharanry/38561f178445eb34c0f35b6137bc87d3 to your computer and use it in GitHub Desktop.
Save sharanry/38561f178445eb34c0f35b6137bc87d3 to your computer and use it in GitHub Desktop.
Normalising Flows VI Prototype
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Recompiling stale cache file /home/sharan/.julia/compiled/v1.1/Bijectors/39uFz.ji for Bijectors [76274a88-744f-5084-9051-94815aaf08c4]\n",
"└ @ Base loading.jl:1184\n",
"WARNING: Method definition forward(Bijectors.TransformedDistribution{D, B, V} where V where B where D) in module Bijectors at /media/sharan/Work/Coding/MachineLearning/Projects/julia/Bijectors.jl/src/interface.jl:680 overwritten at /media/sharan/Work/Coding/MachineLearning/Projects/julia/Bijectors.jl/src/interface.jl:682.\n",
"┌ Info: Recompiling stale cache file /home/sharan/.julia/compiled/v1.1/Turing/gm4QC.ji for Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]\n",
"└ @ Base loading.jl:1184\n"
]
}
],
"source": [
"using Bijectors\n",
"using Distributions\n",
"using Turing\n",
"using TrackedDistributions\n",
"using ForwardDiff\n",
"using Random\n",
"using Tracker\n",
"using Flux"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@model gdemo_d() = begin\n",
" s ~ InverseGamma(2, 3)\n",
" m ~ Normal(0, sqrt(s))\n",
" 1.5 ~ Normal(m, sqrt(s))\n",
" 2.0 ~ Normal(m, sqrt(s))\n",
" return s, m\n",
"end\n",
"model = gdemo_d();"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Composed{Tuple{Bijectors.Scale{TrackedArray{…,Array{Float64,2}}},Bijectors.Shift{TrackedArray{…,Array{Float64,2}}},RadialLayer{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,2}}},PlanarLayer{TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},RadialLayer{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,2}}}}}((Bijectors.Scale{TrackedArray{…,Array{Float64,2}}}([1.0 0.0; 0.0 1.0] (tracked)), Bijectors.Shift{TrackedArray{…,Array{Float64,2}}}([0.0; 0.0] (tracked)), RadialLayer{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,2}}}([0.309296] (tracked), [-0.499293] (tracked), [-0.748744; -0.274501] (tracked)), PlanarLayer{TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}([0.150945; -1.05697] (tracked), [-2.08069; 0.0515419] (tracked), [-0.432142] (tracked)), RadialLayer{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,2}}}([0.642084] (tracked), [1.37536] (tracked), [0.904648; 0.523422] (tracked))))"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"base = TMVDiagonalNormal(zeros(2), ones(2))\n",
"# flow = Bijectors.compose(Bijectors.PlanarLayer(2, param) Bijectors.PlanarLayer(2, param));\n",
"# flow = Bijectors.PlanarLayer(2, param)\n",
"flow = Bijectors.compose(Bijectors.Scale(2, param), Bijectors.Shift(2, param), [ i%2==1 ? Bijectors.RadialLayer(2, param) : Bijectors.PlanarLayer(2, param) for i in 1:3]...)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"trans_base = transformed(base, flow);"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(x = [0.788159, 1.41118], y = [2.8486; 1.17072] (tracked), logabsdetjac = [0.0100974] (tracked), logpdf = [-4.00457] (tracked))"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forward(trans_base)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.15383367574095025 (tracked)"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = MersenneTwister(1234)\n",
"# Stochastic estimate of the objective that the training loop minimises.\n",
"#\n",
"# Draws `num_samples` points from the base distribution `trans_base.dist` (using the\n",
"# global `rng`), pushes them through the flow with `forward`, scores every transformed\n",
"# sample (each column of `y`) under `model`, and returns the scalar\n",
"#     mean((logjac + logpdf) - logp_model)\n",
"# NOTE(review): whether `logjac + logpdf` equals the flow log-density depends on the\n",
"# sign convention of the locally patched `forward(td, x)` -- confirm against Bijectors.\n",
"function F(model, trans_base, num_samples=10)\n",
"    x = rand(rng, trans_base.dist, num_samples)\n",
"    _, y, logjac, logpdf = forward(trans_base, x)\n",
"\n",
"    varinfo = Turing.VarInfo(model)\n",
"\n",
"    # One model log-joint per sample. The comprehension yields a concretely typed\n",
"    # vector instead of growing an untyped `[]` with `append!`.\n",
"    logpdf_p = [\n",
"        Turing.VarInfo(varinfo, Turing.SampleFromUniform(), y[:, i]).logp\n",
"        for i in axes(y, 2)\n",
"    ]\n",
"\n",
"    # logpdf_q = logpdf + logjac\n",
"    return mean((logjac + logpdf) - Tracker.collect(logpdf_p))\n",
"end\n",
"# Sanity Check\n",
"F(model, transformed(base, flow), 100)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10-element Array{Tracker.TrackedReal{Float64},1}:\n",
" 3.0050780322926993 \n",
" 3.5486028826466285 \n",
" 1.0568763041012845 \n",
" 1.2848224461957847 \n",
" 1.086104065266645 \n",
" 1.3335215472759936 \n",
" -0.21375307575439786\n",
" 1.470289223760251 \n",
" 4.078633468643372 \n",
" 0.06802415504225666"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# High variability of F across repeated evaluations: a potential problem.\n",
"[F(model, transformed(base, flow), 1000) for i in 1:10]"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"get_ϕ! (generic function with 1 method)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collect the parameters of every bijector in `flow` into `ϕ` (mutated in place).\n",
"# Nested `Composed` bijectors are traversed recursively; for any other bijector, the\n",
"# value of each property listed by `propertynames(b, false)` is pushed onto `ϕ`, in\n",
"# order. Returns `nothing`.\n",
"function get_ϕ!(flow::Composed, ϕ)\n",
"    for b in flow.ts\n",
"        if b isa Composed\n",
"            get_ϕ!(b, ϕ)\n",
"        else\n",
"            for name in propertynames(b, false)\n",
"                # `push!` avoids allocating a one-element array per property.\n",
"                push!(ϕ, getproperty(b, name))\n",
"            end\n",
"        end\n",
"    end\n",
"    return nothing\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"ϕ = []\n",
"get_ϕ!(flow, ϕ)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11-element Array{Any,1}:\n",
" [1.0 0.0; 0.0 1.0] (tracked) \n",
" [0.0; 0.0] (tracked) \n",
" [0.309296] (tracked) \n",
" [-0.499293] (tracked) \n",
" [-0.748744; -0.274501] (tracked)\n",
" [0.150945; -1.05697] (tracked) \n",
" [-2.08069; 0.0515419] (tracked) \n",
" [-0.432142] (tracked) \n",
" [0.642084] (tracked) \n",
" [1.37536] (tracked) \n",
" [0.904648; 0.523422] (tracked) "
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ϕ = [flow.b, flow.u, flow.w]\n",
"ϕ"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: 1.2741827085508652 (tracked)\n",
"└ @ Main In[68]:7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..........\n",
"(100/2000) done; loss=1.3692395655578096\n",
"..........\n",
"(200/2000) done; loss=-4.235651526416241\n",
"..........\n",
"(300/2000) done; loss=-16.918417146979092\n",
"..........\n",
"(400/2000) done; loss=-15.181841429751076\n",
"..........\n",
"(500/2000) done; loss=-17.839565102646546\n",
"..........\n",
"(600/2000) done; loss=-21.033234580804237\n",
"..........\n",
"(700/2000) done; loss=-18.436042242792695\n",
"..........\n",
"(800/2000) done; loss=-16.768170486650888\n",
"..........\n",
"(900/2000) done; loss=-18.672934678640846\n",
"..........\n",
"(1000/2000) done; loss=-20.279867305877353\n",
"..........\n",
"(1100/2000) done; loss=-21.661934070428607\n",
"..........\n",
"(1200/2000) done; loss=-23.026930269130762\n",
"..........\n",
"(1300/2000) done; loss=-25.01184912695476\n",
"..........\n",
"(1400/2000) done; loss=-27.05704607623275\n",
"..........\n",
"(1500/2000) done; loss=-30.870640116511463\n",
"..........\n",
"(1600/2000) done; loss=-36.09219298596677\n",
"..........\n",
"(1700/2000) done; loss=-30.67153315025095\n",
"..........\n",
"(1800/2000) done; loss=-24.371970056545187\n",
"..........\n",
"(1900/2000) done; loss=-25.47244479986402\n",
"..........\n",
"(2000/2000) done; loss=-26.876981343370787\n",
"done; 28.668578288s used\n"
]
}
],
"source": [
"# Optimise the flow parameters with ADAM on the stochastic objective F.\n",
"Phi = Flux.Params(ϕ)\n",
"opt = ADAM(2e-3)\n",
"niters = 2_000\n",
"losses = Float64[]  # untracked loss value of every iteration\n",
"trans_base = transformed(base, flow)\n",
"# Initial F\n",
"@info F(model, trans_base, 10)\n",
"timeused = @elapsed for iter in 1:niters\n",
"    iter % 10 == 0 && print(\".\")\n",
"\n",
"    loss = F(model, trans_base, 10)\n",
"    # The closure returns the already-computed tracked loss, so F is evaluated\n",
"    # only once per iteration.\n",
"    gs = Tracker.gradient(() -> loss, Phi)\n",
"    for p in ϕ\n",
"        Tracker.update!(opt, p, gs[p])\n",
"    end\n",
"    push!(losses, Tracker.data(loss))\n",
"\n",
"    # Every 100 iterations report the mean loss over the last 100 iterations.\n",
"    if iter % 100 == 0\n",
"        mean_loss = mean(@view losses[end-99:end])\n",
"        println(\"\\n($iter/$niters) done; loss=$mean_loss\")\n",
"    end\n",
"end\n",
"println(\"done; $(timeused)s used\")"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"1-element Array{PyCall.PyObject,1}:\n",
" PyObject <matplotlib.lines.Line2D object at 0x7fb9971054e0>"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss Plot\n",
"using PyPlot\n",
"PyPlot.plot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11-element Array{Any,1}:\n",
" [0.47511 0.638481; 0.482729 0.469116] (tracked)\n",
" [-0.0480375; -0.199192] (tracked) \n",
" [4.84201] (tracked) \n",
" [-3.70387] (tracked) \n",
" [-0.345812; -0.437857] (tracked) \n",
" [1.23621; -1.41586] (tracked) \n",
" [-5.33839; 3.05942] (tracked) \n",
" [-0.185727] (tracked) \n",
" [4.71451] (tracked) \n",
" [-2.72279] (tracked) \n",
" [-0.35254; -0.436752] (tracked) "
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ϕ"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Grads(...)\n"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss = F(model, trans_base, 500)\n",
"gs = Tracker.gradient(() -> loss, Phi)\n",
"# Gradient of Shift bijector \n",
"# gs[ϕ[8]]"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11-element Array{TrackedArray{Float64,N,A} where A<:AbstractArray{Float64,N} where N,1}:\n",
" [-41.6567 27.2195; 59.7536 -20.7477] (tracked)\n",
" [3.32351; -3.19856] (tracked) \n",
" [-1.43589] (tracked) \n",
" [0.189733] (tracked) \n",
" [31.3775; -39.8101] (tracked) \n",
" [-17.4285; -20.4151] (tracked) \n",
" [0.207619; -0.1512] (tracked) \n",
" [28.1562] (tracked) \n",
" [-0.712008] (tracked) \n",
" [0.778539] (tracked) \n",
" [0.106055; 3.1436] (tracked) "
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Other gradients\n",
"[gs[ϕ[i]] for i in 1:length(ϕ)]"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(s = 2.0416666666666665, m = 1.1666666666666667)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(s = 49/24, m = 7/6) # the correct :s and :m"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(s = -0.3151962964746214, m = -0.3885797366600196)"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Posterior sampling: estimate both means from one batch of flow samples, so `s` and\n",
"# `m` use the same draws and the same sample size.\n",
"samples = Tracker.data(flow(rand(rng, base, 10_000)))\n",
"(s = mean(samples[1, :]), m = mean(samples[2, :])) # what we got"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tracked 2×1000 Array{Float64,2}:\n",
" -0.354811 -0.314419 -0.353454 … -0.379264 -0.579326 -0.198162\n",
" -0.438809 -0.403711 -0.437605 -0.460286 -0.633093 -0.302403"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flow(rand(rng, base, 1000))"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"PyPlot.hist(flow(rand(rng, base, 1000))[1,:].data, bins=10, alpha=1)\n",
"PyPlot.hist(flow(rand(rng, base, 1000))[2,:].data, bins=10, alpha=1)\n",
"show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment