@fehiepsi
Last active March 24, 2019 01:04
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Original\n",
"\n",
"We have\n",
"$$\n",
"\\begin{align*}\n",
"(x - a)^T A^{-1} (x - a) + (x - b)^T B^{-1} (x - b) &= x^T (A^{-1} + B^{-1}) x - 2x^T(A^{-1}a + B^{-1}b) + (a^T A^{-1} a + b^T B^{-1} b) \\\\\n",
"&= x^T C^{-1} x - 2 x^T C^{-1} c + c^T C^{-1} c + \\bbox[yellow]d \\\\\n",
"&= (x - c)^T C^{-1} (x - c) + \\bbox[yellow]d,\n",
"\\end{align*}\n",
"$$\n",
"where\n",
"$$\n",
"C = (A^{-1} + B^{-1})^{-1},\n",
"$$\n",
"$$\n",
"c = C(A^{-1}a + B^{-1}b),\n",
"$$\n",
"$$\n",
"\\bbox[yellow]d = a^T A^{-1} a + b^T B^{-1} b - c^T C^{-1} c.\n",
"$$\n",
"\n",
"On the other hand, we have:\n",
"$$\n",
"\\begin{align*}\n",
"c^T C^{-1} c &= c^T (A^{-1}a + B^{-1}b) \\\\\n",
"&= (a^T A^{-1} + b^T B^{-1})(A^{-1} + B^{-1})^{-1} (A^{-1}a + B^{-1}b), \\\\\n",
"&= a^T E a + b^T F b + 2a^T G b,\n",
"\\end{align*}\n",
"$$\n",
"where\n",
"$$\n",
"E = A^{-1}(A^{-1} + B^{-1})^{-1}A^{-1} = (A + AB^{-1}A)^{-1} \\overset{Woodbury}{=} A^{-1} - (A + B)^{-1},\n",
"$$\n",
"$$\n",
"F = (similarly) = B^{-1} - (A + B)^{-1},\n",
"$$\n",
"$$\n",
"G = A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} = (A + B)^{-1}.\n",
"$$\n",
"\n",
"Hence from\n",
"$$\n",
"c^T C^{-1} c = a^T A^{-1} a + b^T B^{-1} b - (a - b)^T (A + B)^{-1} (a - b),\n",
"$$\n",
"we come up with other derivations of $\\bbox[yellow]d$:\n",
"$$\n",
"\\bbox[yellow]d = \\bbox[orange]{(a - b)^T (A + B)^{-1} (a - b)} = (a - b)^T A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} (a - b) = \\bbox[yellow]{(a - b)^T A^{-1}CB^{-1} (a - b)}.\n",
"$$"
]
},
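{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check of the identities above (a sketch: `A` and `B` are arbitrary symmetric positive definite matrices, and `maha` is a hypothetical dense helper for the Mahalanobis term $v^T M^{-1} v$, with no attempt at numerical stability):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"n = 4\n",
"A = torch.randn(n, n); A = A @ A.t() + n * torch.eye(n)  # symmetric positive definite\n",
"B = torch.randn(n, n); B = B @ B.t() + n * torch.eye(n)\n",
"a, b, x = torch.randn(n), torch.randn(n), torch.randn(n)\n",
"\n",
"C = torch.inverse(torch.inverse(A) + torch.inverse(B))\n",
"c = C @ (torch.inverse(A) @ a + torch.inverse(B) @ b)\n",
"\n",
"def maha(v, M):\n",
"    # v^T M^{-1} v, computed densely\n",
"    return v @ torch.inverse(M) @ v\n",
"\n",
"# completing the square: both sides should agree\n",
"d = maha(a - b, A + B)\n",
"print((maha(x - a, A) + maha(x - b, B)).item())\n",
"print((maha(x - c, C) + d).item())\n",
"# two equivalent expressions for d\n",
"print(d.item())\n",
"print(((a - b) @ (torch.inverse(A) @ C @ torch.inverse(B)) @ (a - b)).item())"
]
},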
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use Cholesky\n",
"Define $L_a = Cholesky(A)$, $L_b = Cholesky(B)$, $L_c = Cholesky(C)$, we have\n",
"$$\n",
"\\begin{align*}\n",
"C &= (L_a^{-T} L_a^{-1} + L_b^{-T} L_b^{-1})^{-1} \\\\\n",
"&= L_a(I + L_a^{T}L_b^{-T} L_b^{-1}L_a)^{-1}L_a^T.\n",
"\\end{align*}\n",
"$$\n",
"So it is enough to solve $L = L_b \\backslash L_a$, compute $L_d = Cholesky(I + L^TL)$ (computing Cholesky here is good because eigen values of $I + L^TL$ are larger than $1$), and finally define\n",
"$L_c = L_aL_d$.\n",
"\n",
"Now, we compute $c$:\n",
"$$\n",
"CA^{-1}a = L_c L_d^T L_a^T L_a^{-T} L_a^{-1} a = L_c L_d^T (L_a \\backslash a),\n",
"$$\n",
"$$\n",
"CB^{-1}b = L_c L_d^T L_a^T L_b^{-T} L_b^{-1} b = L_c L_d^T L^T (L_b \\backslash b),\n",
"$$\n",
"$$\n",
"c = L_c L_d^T \\left[ L_a\\backslash a + L^T (L_b\\backslash b) \\right].\n",
"$$\n",
"\n",
"$L_a\\backslash a$ and $L_b\\backslash b$ can be used to compute the corresponding Mahalanobis terms of $\\bbox[yellow]d$. We can also compute\n",
"$$\n",
"L_c \\backslash c = L_d^T \\left[ L_a\\backslash a + L^T (L_b\\backslash b) \\right].\n",
"$$"
]
},
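{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the recipe above, reusing `A`, `B`, `a`, `b`, `C`, `c` from the previous check. For brevity, `torch.inverse` stands in for the triangular solves one would use in practice:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"La, Lb = torch.cholesky(A), torch.cholesky(B)\n",
"L = torch.inverse(Lb) @ La  # L = Lb^{-1} La\n",
"Ld = torch.cholesky(torch.eye(n) + L.t() @ L)  # Ld Ld^T = I + L^T L\n",
"Lc = La @ torch.inverse(Ld).t()  # Lc = La Ld^{-T}, so C = Lc Lc^T\n",
"print(torch.allclose(Lc @ Lc.t(), C, atol=1e-4))\n",
"\n",
"# c = Lc (Ld^{-1} [La^{-1} a + L^T Lb^{-1} b])\n",
"rhs = torch.inverse(La) @ a + L.t() @ (torch.inverse(Lb) @ b)\n",
"c2 = Lc @ (torch.inverse(Ld) @ rhs)\n",
"print(torch.allclose(c2, c, atol=1e-4))"
]
},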
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logdet terms\n",
"\n",
"To make it compatiable with the real log density of Gaussian, we need to divide $d$ by $2$. It seems that we need to add the following adjusted normalization constant (scaled by a factor of $2$) to $d$ too\n",
"$$\n",
"\\begin{align*}\n",
"\\delta(normalization) &= n\\log(2\\pi) + \\log |A| + n\\log(2\\pi) + \\log |B| - n\\log(2\\pi) - \\log |C| \\\\\n",
"&= n\\log(2\\pi) + \\log |AC^{-1}B| = \\bbox[orange]{n\\log(2\\pi) + \\log |A + B|} \\\\\n",
"&= n\\log(2\\pi) + 2 \\sum \\left[ \\log diag(L_a)) + \\log diag(L_b) - \\log diag(L_c) \\right].\n",
"\\end{align*}\n",
"$$\n",
"\n",
"From $\\bbox[orange]{box}$, it looks like that\n",
"$$\n",
"d = -\\log \\mathcal{N}(a - b \\mid 0, A + B).\n",
"$$"
]
},
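{
"cell_type": "markdown",
"metadata": {},
"source": [
"The boxed claim can be checked directly with `torch.distributions.MultivariateNormal`, reusing `A`, `B`, `a`, `b`, `c`, `C`, `x` from above: the product of the two normalized densities should factor as $\\mathcal{N}(x \\mid c, C) \\, \\mathcal{N}(a - b \\mid 0, A + B)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.distributions import MultivariateNormal\n",
"\n",
"lhs = MultivariateNormal(a, A).log_prob(x) + MultivariateNormal(b, B).log_prob(x)\n",
"rhs = MultivariateNormal(c, C).log_prob(x) + MultivariateNormal(torch.zeros(n), A + B).log_prob(a - b)\n",
"print(lhs.item(), rhs.item())"
]
},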
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision to scale_tril"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reference: https://math.stackexchange.com/questions/1434899/is-there-a-decomposition-u-ut"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define `flip(X) = X[::-1, ::-1]`, we have `flip(AB) = flip(A)flip(B)`. Indeed,\n",
"$$flip(AB)[i, j] = (AB)[-i, -j] = A[-i,:] * B[:,-j] = flip(A)[i,:] * flip(B)[:,j].$$\n",
"\n",
"Apply this operator, from\n",
"$$\n",
"C = LL^T,\n",
"$$\n",
"we have\n",
"$$\n",
"P = C^{-1} = L^{-T}L^{-1}.\n",
"$$\n",
"Hence\n",
"$$\n",
"flip(P) = flip(L^{-T})flip(L^{-1}).\n",
"$$\n",
"So\n",
"$$\n",
"flip(L^{-T}) = Cholesky(flip(P))\n",
"$$\n",
"and\n",
"$$\n",
"L = flip(Cholesky(flip(P)))^{-T}\n",
"$$"
]
},
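{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the flip identity (a sketch; `torch.flip` over dims `(-2, -1)` reverses both matrix dimensions):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X, Y = torch.randn(5, 5), torch.randn(5, 5)\n",
"flip = lambda M: torch.flip(M, (-2, -1))\n",
"print(torch.allclose(flip(X @ Y), flip(X) @ flip(Y), atol=1e-5))"
]
},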
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def precision_to_scale_tril(P):\n",
" Lf = torch.cholesky(torch.flip(P, (-1, -2)))\n",
" L = torch.inverse(torch.transpose(torch.flip(Lf, (-1, -2)), -1, -2))\n",
" return L.tril() # torch.inverse of a triangular is not a triangular due to precision"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test for agreement on covariance:\n",
"tensor([[10.3777, 0.4256, -1.2044, -0.3092, -5.3482],\n",
" [ 0.4256, 2.3792, -0.6797, -2.5855, -1.5214],\n",
" [-1.2044, -0.6797, 6.9600, 0.7112, -0.6421],\n",
" [-0.3092, -2.5855, 0.7112, 2.9499, 1.5245],\n",
" [-5.3482, -1.5214, -0.6421, 1.5245, 4.0248]])\n",
"tensor([[10.3777, 0.4256, -1.2044, -0.3092, -5.3482],\n",
" [ 0.4256, 2.3792, -0.6797, -2.5856, -1.5214],\n",
" [-1.2044, -0.6797, 6.9600, 0.7112, -0.6421],\n",
" [-0.3092, -2.5856, 0.7112, 2.9500, 1.5246],\n",
" [-5.3482, -1.5214, -0.6421, 1.5246, 4.0248]])\n",
"===============\n",
"test for agreement on scale_tril:\n",
"tensor([[ 3.2214, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [ 0.1321, 1.5368, 0.0000, 0.0000, 0.0000],\n",
" [-0.3739, -0.4102, 2.5791, 0.0000, 0.0000],\n",
" [-0.0960, -1.6742, -0.0044, 0.3713, 0.0000],\n",
" [-1.6602, -0.8473, -0.6244, -0.1508, 0.3716]])\n",
"tensor([[ 3.2214, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [ 0.1321, 1.5368, 0.0000, 0.0000, 0.0000],\n",
" [-0.3739, -0.4102, 2.5791, 0.0000, 0.0000],\n",
" [-0.0960, -1.6742, -0.0044, 0.3713, 0.0000],\n",
" [-1.6602, -0.8473, -0.6244, -0.1508, 0.3716]])\n"
]
}
],
"source": [
"A = torch.randn(5, 5)\n",
"C = A.matmul(A.t())\n",
"P = torch.inverse(C)\n",
"\n",
"L = precision_to_scale_tril(P)\n",
"print(\"test for agreement on covariance:\")\n",
"print(C)\n",
"print(L.matmul(L.t()))\n",
"print(\"===============\")\n",
"print(\"test for agreement on scale_tril:\")\n",
"print(L)\n",
"print(torch.cholesky(C))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@fritzo

fritzo commented Mar 1, 2019

This is great! We should add it to funsor/derivations/gaussian.ipynb 😄

@fritzo

fritzo commented Mar 1, 2019

Re: normalization: In the current design, Gaussians are essentially numerically stable representations of strictly concave quadratic polynomials, i.e. they are intentionally non-normalized. Thus the log(2 pi) and log det terms appear only when reducing via .logaddexp(). This is a bit tricky since we often reduce only a subset of dims, and only that subset gets the log(2 pi) terms. I don't yet know how to compute the log det terms of the partial reduction.
