Last active
March 24, 2019 01:04
Gist: fehiepsi/5ef8e09e61604f10607380467eb82006
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Original\n",
"\n",
"We have\n",
"$$\n",
"\\begin{align*}\n",
"(x - a)^T A^{-1} (x - a) + (x - b)^T B^{-1} (x - b) &= x^T (A^{-1} + B^{-1}) x - 2x^T(A^{-1}a + B^{-1}b) + (a^T A^{-1} a + b^T B^{-1} b) \\\\\n",
"&= x^T C^{-1} x - 2 x^T C^{-1} c + c^T C^{-1} c + \\bbox[yellow]d \\\\\n",
"&= (x - c)^T C^{-1} (x - c) + \\bbox[yellow]d,\n",
"\\end{align*}\n",
"$$\n",
"where\n",
"$$\n",
"C = (A^{-1} + B^{-1})^{-1},\n",
"$$\n",
"$$\n",
"c = C(A^{-1}a + B^{-1}b),\n",
"$$\n",
"$$\n",
"\\bbox[yellow]d = a^T A^{-1} a + b^T B^{-1} b - c^T C^{-1} c.\n",
"$$\n",
"\n",
"On the other hand, we have\n",
"$$\n",
"\\begin{align*}\n",
"c^T C^{-1} c &= c^T (A^{-1}a + B^{-1}b) \\\\\n",
"&= (a^T A^{-1} + b^T B^{-1})(A^{-1} + B^{-1})^{-1} (A^{-1}a + B^{-1}b) \\\\\n",
"&= a^T E a + b^T F b + 2a^T G b,\n",
"\\end{align*}\n",
"$$\n",
"where\n",
"$$\n",
"E = A^{-1}(A^{-1} + B^{-1})^{-1}A^{-1} = (A + AB^{-1}A)^{-1} \\overset{\\text{Woodbury}}{=} A^{-1} - (A + B)^{-1},\n",
"$$\n",
"$$\n",
"F = B^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} = (B + BA^{-1}B)^{-1} \\overset{\\text{Woodbury}}{=} B^{-1} - (A + B)^{-1},\n",
"$$\n",
"$$\n",
"G = A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} = (A + B)^{-1}.\n",
"$$\n",
"\n",
"Hence, from\n",
"$$\n",
"c^T C^{-1} c = a^T A^{-1} a + b^T B^{-1} b - (a - b)^T (A + B)^{-1} (a - b),\n",
"$$\n",
"we obtain other expressions for $\\bbox[yellow]d$:\n",
"$$\n",
"\\bbox[yellow]d = \\bbox[orange]{(a - b)^T (A + B)^{-1} (a - b)} = (a - b)^T A^{-1}(A^{-1} + B^{-1})^{-1}B^{-1} (a - b) = \\bbox[yellow]{(a - b)^T A^{-1}CB^{-1} (a - b)}.\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use Cholesky\n",
"Defining $L_a = Cholesky(A)$, $L_b = Cholesky(B)$, $L_c = Cholesky(C)$, we have\n",
"$$\n",
"\\begin{align*}\n",
"C &= (L_a^{-T} L_a^{-1} + L_b^{-T} L_b^{-1})^{-1} \\\\\n",
"&= L_a(I + L_a^{T}L_b^{-T} L_b^{-1}L_a)^{-1}L_a^T.\n",
"\\end{align*}\n",
"$$\n",
"So it is enough to solve the triangular system $L = L_b \\backslash L_a$, compute $L_d$ such that $L_dL_d^T = (I + L^TL)^{-1}$ (i.e. the scale_tril of the precision matrix $I + L^TL$; see the \"Precision to scale_tril\" section below — taking a Cholesky factor there is numerically safe because the eigenvalues of $I + L^TL$ are larger than $1$), and finally define\n",
"$L_c = L_aL_d$.\n",
"\n",
"Now, we compute $c$:\n",
"$$\n",
"CA^{-1}a = L_c L_d^T L_a^T L_a^{-T} L_a^{-1} a = L_c L_d^T (L_a \\backslash a),\n",
"$$\n",
"$$\n",
"CB^{-1}b = L_c L_d^T L_a^T L_b^{-T} L_b^{-1} b = L_c L_d^T L^T (L_b \\backslash b),\n",
"$$\n",
"$$\n",
"c = L_c L_d^T \\left[ L_a\\backslash a + L^T (L_b\\backslash b) \\right].\n",
"$$\n",
"\n",
"$L_a\\backslash a$ and $L_b\\backslash b$ can be used to compute the corresponding Mahalanobis terms of $\\bbox[yellow]d$. We can also compute\n",
"$$\n",
"L_c \\backslash c = L_d^T \\left[ L_a\\backslash a + L^T (L_b\\backslash b) \\right].\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logdet terms\n",
"\n",
"To make it compatible with the actual log density of a Gaussian, we need to divide $d$ by $2$. It seems that we also need to add the following adjusted normalization constant (scaled by a factor of $2$) to $d$:\n",
"$$\n",
"\\begin{align*}\n",
"\\delta(\\text{normalization}) &= n\\log(2\\pi) + \\log |A| + n\\log(2\\pi) + \\log |B| - n\\log(2\\pi) - \\log |C| \\\\\n",
"&= n\\log(2\\pi) + \\log |AC^{-1}B| = \\bbox[orange]{n\\log(2\\pi) + \\log |A + B|} \\\\\n",
"&= n\\log(2\\pi) + 2 \\sum \\left[ \\log diag(L_a) + \\log diag(L_b) - \\log diag(L_c) \\right].\n",
"\\end{align*}\n",
"$$\n",
"\n",
"From the $\\bbox[orange]{box}$ above, it follows that\n",
"$$\n",
"\\frac{1}{2}(d + \\delta) = -\\log \\mathcal{N}(a - b \\mid 0, A + B).\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision to scale_tril"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reference: https://math.stackexchange.com/questions/1434899/is-there-a-decomposition-u-ut"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Defining `flip(X) = X[::-1, ::-1]`, we have `flip(AB) = flip(A)flip(B)`. Indeed,\n",
"$$flip(AB)[i, j] = (AB)[-i, -j] = A[-i,:] * B[:,-j] = flip(A)[i,:] * flip(B)[:,j].$$\n",
"\n",
"Applying this operator, from\n",
"$$\n",
"C = LL^T,\n",
"$$\n",
"we have\n",
"$$\n",
"P = C^{-1} = L^{-T}L^{-1}.\n",
"$$\n",
"Hence\n",
"$$\n",
"flip(P) = flip(L^{-T})flip(L^{-1}).\n",
"$$\n",
"Since $flip(L^{-T})$ is lower triangular (flipping an upper triangular matrix gives a lower triangular one), it is the Cholesky factor:\n",
"$$\n",
"flip(L^{-T}) = Cholesky(flip(P)),\n",
"$$\n",
"and therefore\n",
"$$\n",
"L = flip(Cholesky(flip(P)))^{-T}.\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def precision_to_scale_tril(P):\n",
"    Lf = torch.cholesky(torch.flip(P, (-1, -2)))\n",
"    L = torch.inverse(torch.transpose(torch.flip(Lf, (-1, -2)), -1, -2))\n",
"    return L.tril()  # torch.inverse of a triangular matrix need not be exactly triangular due to floating-point error"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test for agreement on covariance:\n",
"tensor([[10.3777, 0.4256, -1.2044, -0.3092, -5.3482],\n",
" [ 0.4256, 2.3792, -0.6797, -2.5855, -1.5214],\n",
" [-1.2044, -0.6797, 6.9600, 0.7112, -0.6421],\n",
" [-0.3092, -2.5855, 0.7112, 2.9499, 1.5245],\n",
" [-5.3482, -1.5214, -0.6421, 1.5245, 4.0248]])\n",
"tensor([[10.3777, 0.4256, -1.2044, -0.3092, -5.3482],\n",
" [ 0.4256, 2.3792, -0.6797, -2.5856, -1.5214],\n",
" [-1.2044, -0.6797, 6.9600, 0.7112, -0.6421],\n",
" [-0.3092, -2.5856, 0.7112, 2.9500, 1.5246],\n",
" [-5.3482, -1.5214, -0.6421, 1.5246, 4.0248]])\n",
"===============\n",
"test for agreement on scale_tril:\n",
"tensor([[ 3.2214, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [ 0.1321, 1.5368, 0.0000, 0.0000, 0.0000],\n",
" [-0.3739, -0.4102, 2.5791, 0.0000, 0.0000],\n",
" [-0.0960, -1.6742, -0.0044, 0.3713, 0.0000],\n",
" [-1.6602, -0.8473, -0.6244, -0.1508, 0.3716]])\n",
"tensor([[ 3.2214, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [ 0.1321, 1.5368, 0.0000, 0.0000, 0.0000],\n",
" [-0.3739, -0.4102, 2.5791, 0.0000, 0.0000],\n",
" [-0.0960, -1.6742, -0.0044, 0.3713, 0.0000],\n",
" [-1.6602, -0.8473, -0.6244, -0.1508, 0.3716]])\n"
]
}
],
"source": [
"A = torch.randn(5, 5)\n",
"C = A.matmul(A.t())\n",
"P = torch.inverse(C)\n",
"\n",
"L = precision_to_scale_tril(P)\n",
"print(\"test for agreement on covariance:\")\n",
"print(C)\n",
"print(L.matmul(L.t()))\n",
"print(\"===============\")\n",
"print(\"test for agreement on scale_tril:\")\n",
"print(L)\n",
"print(torch.cholesky(C))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
} |
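The identities in the notebook's first cell can be spot-checked numerically. The following sketch (not part of the original gist; all names are arbitrary) verifies that with C = (A^-1 + B^-1)^-1 and c = C(A^-1 a + B^-1 b), the residual term is d = (a - b)^T (A + B)^-1 (a - b):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4

def random_spd(n):
    """Random symmetric positive definite matrix."""
    M = rng.normal(size=(n, n))
    return M @ M.T + n * np.eye(n)

A, B = random_spd(n), random_spd(n)
a, b, x = (rng.normal(size=n) for _ in range(3))

Ainv, Binv = np.linalg.inv(A), np.linalg.inv(B)
C = np.linalg.inv(Ainv + Binv)           # C = (A^-1 + B^-1)^-1
c = C @ (Ainv @ a + Binv @ b)            # c = C (A^-1 a + B^-1 b)
d = (a - b) @ np.linalg.inv(A + B) @ (a - b)

# Quadratic in x on both sides of the completion-of-squares identity.
lhs = (x - a) @ Ainv @ (x - a) + (x - b) @ Binv @ (x - b)
rhs = (x - c) @ np.linalg.inv(C) @ (x - c) + d
assert np.isclose(lhs, rhs)

# The alternative yellow-box expression d = (a-b)^T A^-1 C B^-1 (a-b).
assert np.isclose(d, (a - b) @ Ainv @ C @ Binv @ (a - b))
```

The asserts pass for any choice of symmetric positive definite A, B, which is a useful regression check before relying on the Cholesky-based implementation.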
Re: normalization: In the current design, `Gaussian`s are essentially numerically stable representations of strictly concave quadratic polynomials, i.e. they are intentionally non-normalized. Thus the `log(2 pi)` and `log det` terms appear only when reducing via `.logaddexp()`. This is a bit tricky since we often reduce only a subset of dims, and only that subset gets the `log(2 pi)` terms. I don't yet know how to compute the `log det` terms of the partial reduction.
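The orange-boxed identity behind the notebook's logdet cell, log|A| + log|B| - log|C| = log|A + B|, can also be checked numerically. A small sketch (not from the thread; names are arbitrary), using `slogdet` for stability:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3
M1, M2 = rng.normal(size=(n, n)), rng.normal(size=(n, n))
A = M1 @ M1.T + n * np.eye(n)            # symmetric positive definite
B = M2 @ M2.T + n * np.eye(n)
C = np.linalg.inv(np.linalg.inv(A) + np.linalg.inv(B))

# slogdet returns (sign, log|det|); all matrices here are SPD so sign is +1.
logdet = lambda M: np.linalg.slogdet(M)[1]
assert np.isclose(logdet(A) + logdet(B) - logdet(C), logdet(A + B))
```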
This is great! We should add it to funsor/derivations/gaussian.ipynb 😄
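As an end-to-end sanity check (not part of the gist; a NumPy sketch rather than the notebook's torch code), the "Use Cholesky" recipe can be combined with the flip trick: solve L = Lb \ La, take Ld as the scale_tril of the precision I + L^T L, and confirm that Lc = La Ld is the Cholesky factor of C:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
M1, M2 = rng.normal(size=(n, n)), rng.normal(size=(n, n))
A = M1 @ M1.T + n * np.eye(n)
B = M2 @ M2.T + n * np.eye(n)
C = np.linalg.inv(np.linalg.inv(A) + np.linalg.inv(B))

La, Lb = np.linalg.cholesky(A), np.linalg.cholesky(B)
L = np.linalg.solve(Lb, La)              # L = Lb \ La

def precision_to_scale_tril(P):
    # Lower triangular M with M M^T = P^{-1}, via the flip trick.
    Lf = np.linalg.cholesky(P[::-1, ::-1])
    return np.tril(np.linalg.inv(Lf[::-1, ::-1].T))

Ld = precision_to_scale_tril(np.eye(n) + L.T @ L)
Lc = La @ Ld                             # product of lower triangulars

assert np.allclose(Lc @ Lc.T, C)         # Lc is a valid factor of C
assert np.allclose(Lc, np.linalg.cholesky(C))  # and matches Cholesky(C)
```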