Skip to content

Instantly share code, notes, and snippets.

@sharanry
Last active August 1, 2019 04:35
Show Gist options
  • Save sharanry/a4b251c59de4f12cf68d81dea8721ec4 to your computer and use it in GitHub Desktop.
Save sharanry/a4b251c59de4f12cf68d81dea8721ec4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Tracker, Bijectors"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PlanarLayer([-1.38262; 0.107765] (tracked), [2.61178; -0.740486] (tracked), [0.659587; -0.588327] (tracked), [0.910486] (tracked))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flow = PlanarLayer(2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×100 Array{Float64,2}:\n",
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n",
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z = randn(2, 100)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"inv (generic function with 1 method)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Roots, LinearAlgebra\n",
"function inv(flow::PlanarLayer, y)\n",
" function f(y) \n",
" return loss(alpha) = (transpose(flow.w.data)*y)[1] - alpha -(transpose(flow.w.data)*flow.u_hat.data)[1]*tanh(alpha+flow.b.data[1]) \n",
" end\n",
" alphas = transpose([find_zero(f(y[:,i:i]), randn(), Order16()) for i in 1:size(y)[2]])\n",
" z_para = (flow.w.data ./norm(flow.w.data,2))*alphas\n",
" z_per = y - z_para - flow.u_hat.data*tanh.(transpose(flow.w.data)*z_para .+ flow.b.data)\n",
" \n",
" return z_para + z_per\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×100 Array{Float64,2}:\n",
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n",
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×100 Array{Float64,2}:\n",
" -0.727291 -0.785757 -0.142205 1.79629 … -1.27419 -0.997874 0.381794\n",
" -1.43388 -0.139674 -0.613618 0.206456 0.236856 1.44647 0.581317"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv(flow, transform(flow, z).data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RadialLayer([0.704708] (tracked), [-0.340588] (tracked), [0.829187; 0.840773] (tracked))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flow2 = RadialLayer(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abs(y - flow2.z_not.data) - r * (1 + flow2.β.data / (flow2.α.data + r)) "
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"inv (generic function with 2 methods)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Roots, LinearAlgebra\n",
"using StatsFuns: softplus\n",
"function inv(flow::RadialLayer, y)\n",
" α = softplus(flow.α_.data[1])\n",
" β_hat = -α + softplus(flow.β.data[1])\n",
" function f(y) \n",
" return loss(r) = norm(y - flow.z_not.data, 2) - r * (1 + β_hat / (α + r)) \n",
" end\n",
" rs = transpose([find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]])\n",
"# print(rs)\n",
" z = (y.-flow.z_not.data) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) )\n",
" return z\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×100 Array{Float64,2}:\n",
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n",
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×100 Array{Float64,2}:\n",
" -2.72452 -1.84197 -0.929259 … -3.06125 -2.22814 -0.120209 \n",
" -4.30542 -1.0546 -1.52326 -0.720227 0.607145 -0.0389007"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv(flow2, transform(flow2, z).data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@xukai92
Copy link

xukai92 commented Jul 31, 2019

Does changing Order16 to other order level improve things?

@sharanry
Copy link
Author

sharanry commented Aug 1, 2019

Does changing Order16 to other order level improve things?

Unfortunately no. Since both inv and logabsdetjac have problems for radial flows, there might be a problem with transformation implementation itself.

In case of radial flows, where the inverse in completely off, one main deviation I had made from the paper, was the usage of softplus to accommodate the constrained nature of the \alpha parameter. Could this be the reason? Is there a better way to accomplish non-negative constraint?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment