Skip to content

Instantly share code, notes, and snippets.

@sharanry
Last active August 29, 2019 08:50
Show Gist options
  • Save sharanry/2d4e824f27ed01e4f7a1d6fb6f02305a to your computer and use it in GitHub Desktop.
Save sharanry/2d4e824f27ed01e4f7a1d6fb6f02305a to your computer and use it in GitHub Desktop.
Simple NFVI
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 320,
"metadata": {},
"outputs": [],
"source": [
"using Debugger"
]
},
{
"cell_type": "code",
"execution_count": 321,
"metadata": {},
"outputs": [],
"source": [
"using Bijectors\n",
"using Distributions\n",
"using Turing\n",
"using TrackedDistributions\n",
"using ForwardDiff\n",
"using Random\n",
"using Tracker\n",
"using Flux\n",
"using Distances"
]
},
{
"cell_type": "code",
"execution_count": 335,
"metadata": {},
"outputs": [],
"source": [
"# Toy model with no observed data: a ~ Normal(10, 5) and b ~ Normal(a, 5).\n",
"@model simple() = begin\n",
" a ~ Normal(10, 5) \n",
" b ~ Normal(a, 5) \n",
" return a, b\n",
"end\n",
"model = simple();\n",
"# Alternative model conditioned on observations 1.5 and 2.0 (currently disabled):\n",
"# @model gdemo_d() = begin\n",
"# s ~ InverseGamma(2, 3)\n",
"# m ~ Normal(0, sqrt(s))\n",
"# 1.5 ~ Normal(m, sqrt(s))\n",
"# 2.0 ~ Normal(m, sqrt(s))\n",
"# return s, m\n",
"# end\n",
"# model = gdemo_d();"
]
},
{
"cell_type": "code",
"execution_count": 362,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"get_transforms (generic function with 1 method)"
]
},
"execution_count": 362,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    get_transforms(model::Turing.Model; num_layers::Integer=10)\n",
"\n",
"Build a normalizing-flow variational family sized to `model`'s parameter count.\n",
"\n",
"The base distribution is an `MvNormal` with zero mean and unit scale whose dimension is\n",
"the total size of all variables stored in `model`'s `VarInfo`. The flow is `Scale`, then\n",
"`Shift`, then `num_layers` layers alternating `RadialLayer` (odd index) and `PlanarLayer`\n",
"(even index); every bijector is constructed with `param` as its parameter wrapper.\n",
"\n",
"Returns the named tuple `(base, flow, trans_base)` with `trans_base = transformed(base, flow)`.\n",
"\"\"\"\n",
"function get_transforms(model::Turing.Model; num_layers::Integer=10)\n",
"    varinfo = Turing.VarInfo(model)\n",
"    # Sum of the leading dimension of each variable's stored values\n",
"    # (assumes each `vals` is a flat vector, so this is the total parameter count).\n",
"    num_params = sum(size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata))\n",
"\n",
"    base = MvNormal(zeros(num_params), ones(num_params))\n",
"    layers = [isodd(i) ? Bijectors.RadialLayer(num_params, param) : Bijectors.PlanarLayer(num_params, param) for i in 1:num_layers]\n",
"    flow = Bijectors.compose(Bijectors.Scale(num_params, param),\n",
"                             Bijectors.Shift(num_params, param),\n",
"                             layers...)\n",
"    trans_base = transformed(base, flow)\n",
"    return (base=base, flow=flow, trans_base=trans_base)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 363,
"metadata": {},
"outputs": [],
"source": [
"# Unpack the base distribution, the flow, and the flow-transformed base.\n",
"(base, flow, trans_base) = get_transforms(model);"
]
},
{
"cell_type": "code",
"execution_count": 364,
"metadata": {},
"outputs": [],
"source": [
"# Seeded RNG shared by the sampling cells below.\n",
"rng = Random.MersenneTwister(1234);"
]
},
{
"cell_type": "code",
"execution_count": 365,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.4086145128384295 (tracked)"
]
},
"execution_count": 365,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    F(rng, model, trans_base; num_samples=10, x=nothing, print_info=false)\n",
"\n",
"Loss used to fit the flow `trans_base` to `model`.\n",
"\n",
"If `x` is `nothing`, `num_samples` columns are drawn from `trans_base.dist` using `rng`;\n",
"otherwise the columns of `x` are used as given. The columns are pushed through the flow\n",
"with `forward`, `model` is evaluated at each pushed-forward column `y[:, i]`, and the\n",
"resulting `varinfo_new.logp` values form `logpdf_p`. Returns\n",
"`kl_divergence(exp.(logpdf_q), exp.(logpdf_p))` with `logpdf_q = logpdf - logjac`.\n",
"Set `print_info=true` to log the intermediate quantities.\n",
"\"\"\"\n",
"function F(rng, model, trans_base; num_samples=10, x=nothing, print_info=false)\n",
"    if x === nothing\n",
"        x = rand(rng, trans_base.dist, num_samples)\n",
"    end\n",
"\n",
"    _x, y, logjac, logpdf = forward(trans_base, x)\n",
"    if print_info\n",
"        @info \"x\" x\n",
"        @info \"y\" y\n",
"        @info \"logjac\" logjac\n",
"        @info \"logpdf\" logpdf\n",
"    end\n",
"\n",
"    # Untyped on purpose: the model log densities may be tracked reals.\n",
"    logpdf_p = []\n",
"    varinfo = Turing.VarInfo(model)\n",
"    for i in 1:size(y, 2)\n",
"        varinfo_new = Turing.VarInfo(varinfo, Turing.SampleFromUniform(), y[:, i])\n",
"        model(varinfo_new)\n",
"        if print_info\n",
"            @info \"varinfo_new\" varinfo_new.logp\n",
"        end\n",
"        push!(logpdf_p, varinfo_new.logp)\n",
"    end\n",
"\n",
"    # NOTE(review): depending on the Bijectors version, the `logpdf` returned by `forward`\n",
"    # may already include the `-logjac` correction; confirm it is not subtracted twice here.\n",
"    logpdf_q = logpdf - logjac\n",
"    if print_info\n",
"        @info \"logpdf_p\" logpdf_p\n",
"        @info \"logpdf_q\" logpdf_q\n",
"    end\n",
"    logpdf_p = Tracker.collect(logpdf_p)\n",
"    # NOTE(review): Distances.jl `kl_divergence(a, b)` treats `a` and `b` as discrete\n",
"    # distributions (sum of a .* log.(a ./ b)). On density values at samples this is not a\n",
"    # Monte Carlo estimate of KL(q || p); the standard estimator is mean(logpdf_q - logpdf_p).\n",
"    return kl_divergence(exp.(logpdf_q), exp.(logpdf_p))\n",
"end\n",
"\n",
"# Sanity Check\n",
"out = F(rng, model, trans_base, print_info=false, num_samples=10)"
]
},
{
"cell_type": "code",
"execution_count": 366,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# F varies a lot across repeated evaluations, which is a potential problem.\n",
"using PyPlot\n",
"# 1000 independent evaluations of F, each using 100 Monte Carlo samples.\n",
"PyPlot.hist([F(rng, model, trans_base; num_samples=100).data for _ in 1:1000]; bins=50);"
]
},
{
"cell_type": "code",
"execution_count": 367,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"get_ϕ! (generic function with 1 method)"
]
},
"execution_count": 367,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# To extract all tracked params\n",
"# Leaf bijector: push the value of each of its public fields onto ϕ.\n",
"function _push_fields!(b, ϕ)\n",
"    for name in propertynames(b, false)\n",
"        push!(ϕ, getproperty(b, name))\n",
"    end\n",
"    return ϕ\n",
"end\n",
"\n",
"# Nested composition: flatten it by recursing.\n",
"_push_fields!(b::Composed, ϕ) = get_ϕ!(b, ϕ)\n",
"\n",
"\"\"\"\n",
"    get_ϕ!(flow::Composed, ϕ)\n",
"\n",
"Walk the bijectors in `flow.ts`, recursing into nested `Composed` bijectors, and push the\n",
"value of every public field of each leaf bijector onto `ϕ`. Returns `ϕ`.\n",
"\n",
"Field values are collected regardless of their type; whether they are tracked parameter\n",
"arrays depends on how the bijectors were constructed.\n",
"\"\"\"\n",
"function get_ϕ!(flow::Composed, ϕ)\n",
"    for b in flow.ts\n",
"        _push_fields!(b, ϕ)\n",
"    end\n",
"    return ϕ\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 368,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32-element Array{Any,1}:\n",
" [1.0 0.0; 0.0 1.0] (tracked) \n",
" [0.0; 0.0] (tracked) \n",
" [-0.517413] (tracked) \n",
" [0.718511] (tracked) \n",
" [1.77198; 1.4844] (tracked) \n",
" [1.55561; -2.24004] (tracked) \n",
" [-0.417338; -3.55184] (tracked) \n",
" [-1.46857] (tracked) \n",
" [-0.427506] (tracked) \n",
" [-0.425289] (tracked) \n",
" [-0.33839; -0.133261] (tracked) \n",
" [-1.87512; -0.371103] (tracked) \n",
" [1.91111; 0.645955] (tracked) \n",
" ⋮ \n",
" [0.987035] (tracked) \n",
" [0.842437] (tracked) \n",
" [-0.604559; -0.282666] (tracked)\n",
" [0.524474; 1.61454] (tracked) \n",
" [-1.82578; 2.07189] (tracked) \n",
" [0.209868] (tracked) \n",
" [-0.257118] (tracked) \n",
" [1.20459] (tracked) \n",
" [0.279818; 2.293] (tracked) \n",
" [0.366493; -1.01509] (tracked) \n",
" [0.575756; -1.35863] (tracked) \n",
" [-0.210145] (tracked) "
]
},
"execution_count": 368,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gather every field of every layer in `flow` into a flat, untyped vector.\n",
"ϕ = Any[]\n",
"get_ϕ!(flow, ϕ)\n",
"ϕ"
]
},
{
"cell_type": "code",
"execution_count": 369,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: 5.687605605915064 (tracked)\n",
"└ @ Main In[369]:6\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..........\n",
"(100/10000) done; loss=5.798888461987018\n",
"..........\n",
"(200/10000) done; loss=5.532302103696383\n",
"..........\n",
"(300/10000) done; loss=5.40636217290014\n",
"..........\n",
"(400/10000) done; loss=5.277394195225437\n",
"..........\n",
"(500/10000) done; loss=5.089418426079885\n",
"..........\n",
"(600/10000) done; loss=5.062693849123746\n",
"..........\n",
"(700/10000) done; loss=5.07495216807752\n",
"..........\n",
"(800/10000) done; loss=4.935018449783298\n",
"..........\n",
"(900/10000) done; loss=4.838258708632543\n",
"..........\n",
"(1000/10000) done; loss=4.6412013565065005\n",
"..........\n",
"(1100/10000) done; loss=4.622866199425547\n",
"..........\n",
"(1200/10000) done; loss=4.322743314518878\n",
"..........\n",
"(1300/10000) done; loss=4.480712672678272\n",
"..........\n",
"(1400/10000) done; loss=4.158097191002094\n",
"..........\n",
"(1500/10000) done; loss=4.264347791127247\n",
"..........\n",
"(1600/10000) done; loss=4.173307206543895\n",
"..........\n",
"(1700/10000) done; loss=3.9286591644533093\n",
"..........\n",
"(1800/10000) done; loss=4.002175241609857\n",
"..........\n",
"(1900/10000) done; loss=3.93243579345907\n",
"..........\n",
"(2000/10000) done; loss=3.886448858843637\n",
"..........\n",
"(2100/10000) done; loss=3.8873126379114797\n",
"..........\n",
"(2200/10000) done; loss=3.931958139182028\n",
"..........\n",
"(2300/10000) done; loss=3.7904366437152492\n",
"..........\n",
"(2400/10000) done; loss=3.751041661311117\n",
"..........\n",
"(2500/10000) done; loss=3.716294025864434\n",
"..........\n",
"(2600/10000) done; loss=3.697364805947052\n",
"..........\n",
"(2700/10000) done; loss=3.646594445889464\n",
"..........\n",
"(2800/10000) done; loss=3.575102239215171\n",
"..........\n",
"(2900/10000) done; loss=3.655719770510173\n",
"..........\n",
"(3000/10000) done; loss=3.5701629347308574\n",
"..........\n",
"(3100/10000) done; loss=3.568918757494621\n",
"..........\n",
"(3200/10000) done; loss=3.6122952747296955\n",
"..........\n",
"(3300/10000) done; loss=3.384335376480747\n",
"..........\n",
"(3400/10000) done; loss=3.345501497890491\n",
"..........\n",
"(3500/10000) done; loss=3.5238478717671415\n",
"..........\n",
"(3600/10000) done; loss=3.410132750398626\n",
"..........\n",
"(3700/10000) done; loss=3.3041982877108067\n",
"..........\n",
"(3800/10000) done; loss=3.2532203455570756\n",
"..........\n",
"(3900/10000) done; loss=3.3720866060014174\n",
"..........\n",
"(4000/10000) done; loss=3.2459250924003067\n",
"..........\n",
"(4100/10000) done; loss=3.214724416423852\n",
"..........\n",
"(4200/10000) done; loss=3.174708212806242\n",
"..........\n",
"(4300/10000) done; loss=3.289272111654915\n",
"..........\n",
"(4400/10000) done; loss=3.250804309497583\n",
"..........\n",
"(4500/10000) done; loss=3.2941254904724366\n",
"..........\n",
"(4600/10000) done; loss=3.178987992552885\n",
"..........\n",
"(4700/10000) done; loss=3.125103800561713\n",
"..........\n",
"(4800/10000) done; loss=3.228475500384154\n",
"..........\n",
"(4900/10000) done; loss=3.0838138888630042\n",
"..........\n",
"(5000/10000) done; loss=3.136007874265254\n",
"..........\n",
"(5100/10000) done; loss=3.0248921318124626\n",
"..........\n",
"(5200/10000) done; loss=3.085563156977406\n",
"..........\n",
"(5300/10000) done; loss=3.0645348100492162\n",
"..........\n",
"(5400/10000) done; loss=2.958291727211016\n",
"..........\n",
"(5500/10000) done; loss=2.9304965829886176\n",
"..........\n",
"(5600/10000) done; loss=2.796315805425349\n",
"..........\n",
"(5700/10000) done; loss=2.180155893884919\n",
"..........\n",
"(5800/10000) done; loss=2.0358903633729097\n",
"..........\n",
"(5900/10000) done; loss=2.144586747322932\n",
"..........\n",
"(6000/10000) done; loss=2.1564737450051426\n",
"..........\n",
"(6100/10000) done; loss=2.167814701455386\n",
"..........\n",
"(6200/10000) done; loss=2.105650026763533\n",
"..........\n",
"(6300/10000) done; loss=2.2037724907621925\n",
"..........\n",
"(6400/10000) done; loss=2.0986737473657593\n",
"..........\n",
"(6500/10000) done; loss=2.139643804825428\n",
"..........\n",
"(6600/10000) done; loss=2.0869740899137\n",
"..........\n",
"(6700/10000) done; loss=2.144057399045887\n",
"..........\n",
"(6800/10000) done; loss=2.140419313082119\n",
"..........\n",
"(6900/10000) done; loss=2.123319137390585\n",
"..........\n",
"(7000/10000) done; loss=2.272158844173498\n",
"..........\n",
"(7100/10000) done; loss=2.077289583190817\n",
"..........\n",
"(7200/10000) done; loss=2.131396233779441\n",
"..........\n",
"(7300/10000) done; loss=2.1276211355646013\n",
"...."
]
},
{
"ename": "InterruptException",
"evalue": "InterruptException:",
"output_type": "error",
"traceback": [
"InterruptException:",
"",
"Stacktrace:",
" [1] gradient_(::getfield(Main, Symbol(\"##396#397\")), ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:7",
" [2] #gradient#24(::Bool, ::Function, ::Function, ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:164",
" [3] gradient(::Function, ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:164",
" [4] top-level scope at In[369]:11",
" [5] top-level scope at util.jl:213",
" [6] top-level scope at In[369]:7"
]
}
],
"source": [
"# Optimise the flow parameters ϕ with ADAM, minimising F estimated from 10 samples per step.\n",
"Phi = Flux.Params(ϕ)\n",
"opt = ADAM(2e-4)\n",
"niters = 10_000\n",
"report_every = 100  # iterations between progress reports; also the loss-averaging window\n",
"losses = Float64[]\n",
"# Initial F\n",
"@info F(rng, model, trans_base, num_samples=10)\n",
"timeused = @elapsed for iter = 1:niters\n",
"    iter % 10 == 0 && print(\".\")\n",
"\n",
"    loss = F(rng, model, trans_base, num_samples=10)\n",
"    gs = Tracker.gradient(() -> loss, Phi)\n",
"    for p in ϕ\n",
"        Tracker.update!(opt, p, gs[p])\n",
"    end\n",
"    # `loss.data` is the untracked scalar value of the loss.\n",
"    push!(losses, loss.data)\n",
"\n",
"    if iter % report_every == 0\n",
"        mean_loss = mean(losses[end-report_every+1:end])\n",
"        println(\"\\n($iter/$niters) done; loss=$mean_loss\")\n",
"    end\n",
"end\n",
"println(\"done; $(timeused)s used\")"
]
},
{
"cell_type": "code",
"execution_count": 371,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"1-element Array{PyCall.PyObject,1}:\n",
" PyObject <matplotlib.lines.Line2D object at 0x7f648f7ace80>"
]
},
"execution_count": 371,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Moving Mean Loss Plot\n",
"using PyPlot\n",
"# Trailing 100-iteration moving average over however many iterations were recorded\n",
"# (the training loop may have been interrupted before reaching niters).\n",
"PyPlot.plot([mean(losses[i-99:i]) for i in 100:length(losses)])\n",
"# PyPlot.plot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 372,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32-element Array{TrackedArray{Float64,N,A} where A<:AbstractArray{Float64,N} where N,1}:\n",
" [-0.00539625 0.00162034; 0.00945505 -0.000361126] (tracked)\n",
" [-0.00731297; 0.00805593] (tracked) \n",
" [0.00160014] (tracked) \n",
" [-0.00177673] (tracked) \n",
" [0.000831536; -0.000520562] (tracked) \n",
" [-0.00307323; -0.00107566] (tracked) \n",
" [0.00131467; -0.00104153] (tracked) \n",
" [-0.00300304] (tracked) \n",
" [-0.00100636] (tracked) \n",
" [0.000581572] (tracked) \n",
" [-0.000491236; 2.52041e-5] (tracked) \n",
" [0.000513647; -0.00079435] (tracked) \n",
" [0.00075561; -0.00257502] (tracked) \n",
" ⋮ \n",
" [0.0001627] (tracked) \n",
" [-0.000169308] (tracked) \n",
" [-0.000202415; 0.000218317] (tracked) \n",
" [0.00332348; 0.000783873] (tracked) \n",
" [0.00102101; 0.000384939] (tracked) \n",
" [0.00212457] (tracked) \n",
" [0.000564511] (tracked) \n",
" [-0.000680569] (tracked) \n",
" [7.86032e-5; 5.91268e-5] (tracked) \n",
" [0.339627; -1.19814] (tracked) \n",
" [-0.000927946; 0.000120444] (tracked) \n",
" [-0.0133198] (tracked) "
]
},
"execution_count": 372,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient of a fresh loss evaluation with respect to each collected parameter.\n",
"loss = F(rng, model, trans_base; num_samples=10)\n",
"gs = Tracker.gradient(() -> loss, Phi)\n",
"[gs[p] for p in ϕ]"
]
},
{
"cell_type": "code",
"execution_count": 373,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.890143206195424"
]
},
"execution_count": 373,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean of coordinate 1 over 10_000 flow samples; compare with E[a] = 10 (assuming coordinate 1 is `a`).\n",
"mean(flow(rand(rng, base, 10000)).data[1, :])"
]
},
{
"cell_type": "code",
"execution_count": 374,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.979383902117426"
]
},
"execution_count": 374,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean of coordinate 2 over 10_000 flow samples; compare with E[b] = E[a] = 10 (assuming coordinate 2 is `b`).\n",
"mean(flow(rand(rng, base, 10000)).data[2, :])"
]
},
{
"cell_type": "code",
"execution_count": 375,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"Figure(PyObject <Figure size 640x480 with 1 Axes>)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Histogram of coordinate 1 over 10_000 flow samples.\n",
"PyPlot.hist(flow(rand(rng, base, 10000)).data[1, :]);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment