{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDAdrv.name(CuDevice(0)) = \"GeForce RTX 2060\"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: using CUDA.CuDevice in module Main conflicts with an existing identifier.\n"
]
}
],
"source": [
"using CUDAdrv; @show CUDAdrv.name(CuDevice(0))\n",
"using CUDA\n",
"using AdvancedHMC\n",
"using Zygote\n",
"using Distributions\n",
"using Functors\n",
"using Flux\n",
"\n",
"import Random\n",
"Random.seed!(123);\n",
"CUDA.seed!(123);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: CUDA is on\n",
"└ @ Main In[2]:2\n"
]
}
],
"source": [
"if has_cuda()\t\t# Check if CUDA is available\n",
" @info \"CUDA is on\"\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# toy data generating function f: R^3 -> R^5\n",
"function f(x)\n",
" @assert length(x) == 3\n",
" y = softmax([sin(x[1] +x[3]), cos(x[2] + x[1]), sin(x[1] + x[2]), cos(x[2] + x[3]), cos(x[3] + x[1])])\n",
" argmax(y)\n",
"end\n",
"\n",
"# toy data generation\n",
"N = 1000\n",
"x = [rand(Normal(0, 4), 3) for i in 1:N]\n",
"y = f.(x);\n",
"x = gpu(hcat(x...));\n",
"y = map(x -> Flux.onehot(x, 1:5), y);\n",
"y = gpu(Float32.(hcat(y...)));"
]
},
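{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added sketch, not part of the original run): `x` should be a 3×N `CuArray` of inputs and `y` a 5×N `CuArray` of one-hot targets. `Flux.onecold` and `cpu` are the standard Flux helpers assumed here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sketch: verify shapes and inspect the label balance\n",
"@assert size(x) == (3, N)\n",
"@assert size(y) == (5, N)\n",
"labels = Flux.onecold(cpu(y))   # recover integer labels from one-hot columns\n",
"[count(==(k), labels) for k in 1:5]"
]
},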
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom Dense Dropout Layer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"toy_model (generic function with 1 method)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"abstract type ProbablisticLayer end\n",
"\n",
"struct DenseProbDropout{F,S,T,P} <: ProbablisticLayer\n",
" σ::F\n",
" W::S\n",
" b::T\n",
" p::P # inferred using MCMC(posterior) / SGD(MAP est)\n",
"end\n",
"\n",
"function DenseProbDropout(in::Integer, out::Integer, σ = identity;\n",
" initW = Flux.glorot_uniform, initb = zeros)\n",
" return DenseProbDropout(σ, CUDA.cu(initW(out, in)), CUDA.cu(initb(out)), CUDA.randn(out))\n",
"end\n",
"\n",
"Functors.functor(a::DenseProbDropout) = ((σ=a.σ, W=a.W, b=a.b), x -> DenseProbDropout(x.σ, x.W, x.b, a.p))\n",
"\n",
"function replace_probs(a::DenseProbDropout, p)\n",
" @assert length(p) == length(a.p)\n",
" return DenseProbDropout(a.σ, a.W, a.b, p)\n",
"end\n",
"\n",
"function replace_probs(c::Chain, probs)\n",
" i = 0\n",
" layers = [\n",
" (layer isa DenseProbDropout) ? \n",
" begin\n",
" i += 1\n",
" replace_probs(layer, probs[i])\n",
" end : layer\n",
" \n",
" for layer in c.layers\n",
" ]\n",
" \n",
" return Chain(layers...)\n",
"end\n",
"\n",
"function (a::DenseProbDropout)(x)\n",
" W, b, σ, p = a.W, a.b, a.σ, a.p\n",
" return σ.((W*x .+ b) .* (p .+ 1))\n",
"end\n",
"\n",
"function dropout_params(a::DenseProbDropout)\n",
" return [a.p]\n",
"end\n",
"\n",
"function dropout_params(model::Chain)\n",
" inf_params = Any[]\n",
" for layer in model\n",
" if layer isa ProbablisticLayer\n",
" append!(inf_params, dropout_params(layer))\n",
" end\n",
" end\n",
" return inf_params\n",
"end\n",
"\n",
"function toy_model(;inp=3, out=5)\n",
" return Chain(\n",
" DenseProbDropout(inp, 10, relu),\n",
" Dense(10, out)\n",
" )\n",
"end"
]
},
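{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal CPU sketch (added for illustration) of what the layer computes: the pre-activation `W*x .+ b` is scaled elementwise by `(p .+ 1)`, so replacing `p` with zeros recovers an ordinary dense layer. The instance below is built through the default struct constructor to avoid the CUDA-backed outer constructor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sketch: small CPU instance; p = 0 reduces the layer to W*x .+ b\n",
"layer = DenseProbDropout(identity, ones(Float32, 2, 3), zeros(Float32, 2), randn(Float32, 2))\n",
"x0 = ones(Float32, 3)\n",
"@show layer(x0)                                      # noisy pre-activation\n",
"@show replace_probs(layer, zeros(Float32, 2))(x0);   # == W*x0 .+ b"
]
},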
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×10 CuArray{Float32,2}:\n",
" -0.235227 -0.228232 -0.0799891 … -0.257696 -0.133126 -0.0745518\n",
" -0.161473 -0.156672 -0.0569255 -0.176898 -0.104453 -0.160184\n",
" 0.0903603 0.0876732 0.0228563 0.0989917 0.0391316 -0.0821547\n",
" -0.23681 -0.229768 -0.087452 -0.259431 -0.141727 -0.134847\n",
" 0.281871 0.273489 0.101331 0.308797 0.157763 0.00874894"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = toy_model()\n",
"m = gpu(m)\n",
"m(CUDA.rand(3, 10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Zygote gradient check w.r.t Dropout parameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 21.660557 seconds (52.09 M allocations: 2.601 GiB, 3.28% gc time)\n"
]
},
{
"data": {
"text/plain": [
"IdDict{Any,Any} with 4 entries:\n",
" :(Main.y) => Float32[0.000918661 0.000467054 … 0.000817337 0.…\n",
" Float32[-0.568664, 1.061… => Float32[0.130434, 0.271732, 0.0962617, 0.0610819…\n",
" :(Main.m) => (layers = ((σ = nothing, W = Float32[-0.10411 0.…\n",
" :(Main.x) => Float32[-1.23266f-5 -0.000423586 … 8.00842f-5 -9…"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@time grad = gradient(() -> Flux.logitcrossentropy(m(x), y), Flux.Params(dropout_params(m)))\n",
"grad.grads"
]
},
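{
"cell_type": "markdown",
"metadata": {},
"source": [
"Only the entries keyed by the dropout-parameter arrays matter here; the other keys are Zygote's implicit gradients w.r.t. the globals `m`, `x`, and `y`. A short added sketch for pulling out just the dropout gradients by indexing the `Grads` object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sketch: index the Grads object by each tracked parameter array\n",
"for p in dropout_params(m)\n",
"    @show size(grad[p])   # gradient has the same shape as p\n",
"end"
]
},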
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Probabilistic Inference"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_inf_params = sum(length.(dropout_params(m)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ZeroMeanIsoNormal(\n",
"dim: 10\n",
"μ: 10-element Zeros{Float64}\n",
"Σ: [1.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 0.0; … ; 0.0 0.0 … 1.0 0.0; 0.0 0.0 … 0.0 1.0]\n",
")\n"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prior = MvNormal(n_inf_params, 1)"
]
},
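{
"cell_type": "markdown",
"metadata": {},
"source": [
"For orientation (added check): under this 10-dimensional standard normal prior, the log-density at the origin is $-\\tfrac{10}{2}\\log(2\\pi) \\approx -9.19$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sketch: logpdf at the origin matches the closed form -d/2 * log(2π)\n",
"@show logpdf(prior, zeros(n_inf_params))\n",
"@show -n_inf_params / 2 * log(2π);"
]
},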
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Float32[-0.83967954 1.2163887 … -1.1735133 3.871691; 0.012412363 -6.3201947 … 5.328995 -3.1478426; 5.438018 2.4701521 … 2.3253121 1.6540126], Float32[0.0 0.0 … 1.0 0.0; 1.0 0.0 … 0.0 1.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# let us consider a single minibatch of train_data\n",
"d = (x, y)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-14.486293887900626"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lengths = length.(dropout_params(m))\n",
"cumsum_lengths = cumsum(lengths)\n",
"log_pdf(params) = begin\n",
" -Flux.logitcrossentropy(\n",
" replace_probs(\n",
" m, \n",
" [params[s:e] for (s, e) in zip(cumsum_lengths .- lengths .+ 1, cumsum_lengths)]\n",
" )(first(d)),\n",
" last(d)\n",
" ) + logpdf(prior, params)\n",
"end\n",
"log_pdf(gpu(randn(10)))"
]
},
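{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flat-vector-to-slices trick inside `log_pdf` is easy to get wrong by one, so here is a standalone check (added, with made-up lengths):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sketch: split a flat vector into consecutive slices of given lengths\n",
"lens = [3, 2, 5]\n",
"cs = cumsum(lens)\n",
"v = collect(1:sum(lens))\n",
"slices = [v[s:e] for (s, e) in zip(cs .- lens .+ 1, cs)]\n",
"@assert slices == [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]"
]
},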
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"Zygote.@adjoint function Iterators.Zip(xs)\n",
" back(dy::NamedTuple{(:is,)}) = tuple(dy.is)\n",
" back(dy::AbstractArray) = ntuple(length(xs)) do d\n",
" dx = map(y->y[d], dy)\n",
" length(dx) == length(xs[d]) ? dx : vcat(dx, falses(length(xs[d])-length(dx)))\n",
" end |> tuple\n",
" back(::AbstractArray{Nothing}) = nothing\n",
" Iterators.Zip(xs), back\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup single step HMC aka Metropolis-adjusted Langevin algorithm (MALA)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"StaticTrajectory{EndPointTS}(integrator=Leapfrog(ϵ=0.1), λ=1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"initial_θ = vcat(dropout_params(m)...)\n",
"metric = DiagEuclideanMetric(n_inf_params)\n",
"hamiltonian = Hamiltonian(metric, log_pdf, Zygote)\n",
"initial_ϵ = 0.1 #find_good_stepsize(hamiltonian, initial_θ)\n",
"integrator = Leapfrog(initial_ϵ)\n",
"proposal = StaticTrajectory(integrator, 1)"
]
},
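{
"cell_type": "markdown",
"metadata": {},
"source": [
"With a single leapfrog step of size $\\epsilon$ and momentum resampled each iteration, the HMC proposal reduces to the MALA proposal (a standard identity, stated here for reference):\n",
"\n",
"$$\\theta' = \\theta + \\frac{\\epsilon^2}{2} \\nabla_\\theta \\log \\pi(\\theta) + \\epsilon \\, \\xi, \\qquad \\xi \\sim \\mathcal{N}(0, I),$$\n",
"\n",
"followed by a Metropolis accept/reject correction."
]
},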
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"ename": "LoadError",
"evalue": "GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float64,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}}, Int64) failed\nKernelError: passing and using non-bitstype argument\n\nArgument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}}, which is not isbits:\n .args is of type Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}} which is not isbits.\n .2 is of type Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}} which is not isbits.\n .x is of type Array{Float64,1} which is not isbits.\n\n",
"output_type": "error",
"traceback": [
"GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float64,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}}, Int64) failed\nKernelError: passing and using non-bitstype argument\n\nArgument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}}, which is not isbits:\n .args is of type Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}} which is not isbits.\n .2 is of type Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}} which is not isbits.\n .x is of type Array{Float64,1} which is not isbits.\n\n",
"",
"Stacktrace:",
" [1] check_invocation(::GPUCompiler.CompilerJob, ::LLVM.Function) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\validation.jl:68",
" [2] macro expansion at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\driver.jl:238 [inlined]",
" [3] macro expansion at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\TimerOutputs\\ZmKD7\\src\\TimerOutput.jl:206 [inlined]",
" [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\driver.jl:237",
" [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\driver.jl:39",
" [6] compile at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\driver.jl:35 [inlined]",
" [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\compiler\\execution.jl:310",
" [8] cufunction_compile(::GPUCompiler.FunctionSpec) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\compiler\\execution.jl:305",
" [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var\"#broadcast_kernel#12\",Tuple{CUDA.CuKernelContext,CuDeviceArray{Float64,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\cache.jl:40",
" [10] broadcast_kernel at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUArrays\\jhRU7\\src\\host\\broadcast.jl:60 [inlined]",
" [11] cached_compilation at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUCompiler\\uTpNx\\src\\cache.jl:65 [inlined]",
" [12] cufunction(::GPUArrays.var\"#broadcast_kernel#12\", ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Float64,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(*),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(abs2),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Base.Broadcast.Extruded{Array{Float64,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\compiler\\execution.jl:297",
" [13] cufunction at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\compiler\\execution.jl:294 [inlined]",
" [14] #launch_heuristic#853 at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\gpuarrays.jl:19 [inlined]",
" [15] launch_heuristic at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\CUDA\\YeS8q\\src\\gpuarrays.jl:17 [inlined]",
" [16] copyto! at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\GPUArrays\\jhRU7\\src\\host\\broadcast.jl:66 [inlined]",
" [17] copyto! at .\\broadcast.jl:886 [inlined]",
" [18] copy at .\\broadcast.jl:862 [inlined]",
" [19] materialize at .\\broadcast.jl:837 [inlined]",
" [20] neg_energy(::Hamiltonian{DiagEuclideanMetric{Float64,Array{Float64,1}},typeof(log_pdf),AdvancedHMC.var\"#∂ℓπ∂θ#55\"{typeof(log_pdf)}}, ::CuArray{Float32,1}, ::CuArray{Float32,1}) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\hamiltonian.jl:134",
" [21] phasepoint(::Hamiltonian{DiagEuclideanMetric{Float64,Array{Float64,1}},typeof(log_pdf),AdvancedHMC.var\"#∂ℓπ∂θ#55\"{typeof(log_pdf)}}, ::CuArray{Float32,1}, ::Array{Float64,1}) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\hamiltonian.jl:87",
" [22] phasepoint at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\hamiltonian.jl:161 [inlined]",
" [23] sample_init at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\sampler.jl:41 [inlined]",
" [24] sample(::Random._GLOBAL_RNG, ::Hamiltonian{DiagEuclideanMetric{Float64,Array{Float64,1}},typeof(log_pdf),AdvancedHMC.var\"#∂ℓπ∂θ#55\"{typeof(log_pdf)}}, ::StaticTrajectory{EndPointTS,Leapfrog{Float64}}, ::CuArray{Float32,1}, ::Int64, ::AdvancedHMC.Adaptation.NoAdaptation, ::Int64; drop_warmup::Bool, verbose::Bool, progress::Bool, pm_next!::typeof(AdvancedHMC.pm_next!)) at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\sampler.jl:172",
" [25] #sample#28 at C:\\Users\\Sharan Yalburgi\\.julia\\packages\\AdvancedHMC\\LcyEE\\src\\sampler.jl:116 [inlined]",
" [26] top-level scope at In[13]:2",
" [27] include_string(::Function, ::Module, ::String, ::String) at .\\loading.jl:1091"
]
}
],
"source": [
"CUDA.allowscalar(false)\n",
"samples = sample(hamiltonian, proposal, initial_θ, 100; progress=true)"
]
},
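{
"cell_type": "markdown",
"metadata": {},
"source": [
"The stacktrace points at `neg_energy`: the metric's inverse mass matrix is a CPU `Array{Float64,1}` (see `DiagEuclideanMetric{Float64,Array{Float64,1}}` in frame [20]) while the position and momentum are `CuArray{Float32,1}`s, so the fused broadcast kernel tries to capture a non-isbits CPU array. A possible workaround (untested sketch, and it assumes `DiagEuclideanMetric` accepts a user-supplied inverse-metric vector) is to build the metric from a `CuArray` so every array inside the Hamiltonian lives on the GPU:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# untested sketch: keep the metric on the GPU so neg_energy's broadcast\n",
"# only ever mixes CuArrays (assumes the AbstractVector constructor exists)\n",
"metric_gpu = DiagEuclideanMetric(CUDA.ones(Float32, n_inf_params))\n",
"hamiltonian_gpu = Hamiltonian(metric_gpu, log_pdf, Zygote)\n",
"# samples = sample(hamiltonian_gpu, proposal, initial_θ, 100; progress=true)"
]
},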
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# CUDA.allowscalar(true)\n",
"# samples = sample(hamiltonian, proposal, initial_θ, 100; progress=true)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.5.3",
"language": "julia",
"name": "julia-1.5"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}