@mattjj
Created February 8, 2014 19:55
{
"worksheets": [
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Setting MNIW Priors #\n\nFor an autoregressive model with $k$ lags and parameters $(A,\\Sigma)$, for each time $t$ we write the vector of lagged observations as $\\tilde{y} \\triangleq (y_{t-1}, y_{t-2}, \\ldots, y_{t-k})$. If each observation is $n$-dimensional, then $\\tilde{y}$ stacks the $k$ lags into a vector of length $nk$ and $A$ is an $n \\times nk$ matrix. The generative model for the current observation $y = y_t$ is\n$y | \\tilde{y},A,\\Sigma \\sim \\mathcal{N}(A \\tilde{y}, \\Sigma)$\nso that the likelihood is\n\n$$ p(y|\\tilde{y},A,\\Sigma) \\propto \\exp \\left\\{ -\\frac{1}{2}(y-A\\tilde{y})^T \\Sigma^{-1} (y-A\\tilde{y}) \\right\\}. $$\n\nThe natural conjugate prior over $(A,\\Sigma)$ is the Matrix-Normal-Inverse-Wishart (MNIW) with hyperparameters $(\\nu_0,S_0,M_0,K_0)$, where $(\\Sigma,A) \\sim \\text{MNIW}(\\nu_0,S_0,M_0,K_0)$ if\n\n$$ \\Sigma \\sim \\text{InvWishart}(\\nu_0,S_0) $$\n$$ A | \\Sigma \\sim \\text{MN}(M_0,\\Sigma,K_0). $$\n\nTherefore, to set a prior over $(\\Sigma,A)$ we must choose the hyperparameters $(\\nu_0,S_0,M_0,K_0)$. This document explores how each of these hyperparameters affects the prior."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Prior over $\\Sigma$ ##\n\nAs described on [the Wikipedia page for the Inverse Wishart](https://en.wikipedia.org/wiki/Inverse-Wishart_distribution), we can think of the hyperparameters of $\\Sigma \\sim \\text{IW}(\\nu_0,S_0)$ as summarizing a \"prior dataset\" of $\\nu_0$ pseudo-observations whose scatter matrix (sum of outer products) is $S_0$. Heuristically, $\\Sigma$ behaves like the sample covariance of $\\nu_0$ samples from a Gaussian distribution with covariance $S_0 / \\nu_0$, i.e.\n\n$$ v_i \\sim \\mathcal{N}(0,S_0/\\nu_0) \\quad i=1,2,\\ldots,\\nu_0 $$\n$$ \\Sigma \\approx \\frac{1}{\\nu_0} \\sum_i v_i v_i^T $$\n\nThis picture is only approximate: the construction above actually produces a Wishart-distributed matrix, and the inverse Wishart's mean is $S_0/(\\nu_0-n-1)$ (defined for $\\nu_0 > n+1$), which approaches $S_0/\\nu_0$ as $\\nu_0$ grows.\n\nTherefore, we can interpret $\\nu_0$ as the size of the prior dataset and $S_0/\\nu_0$ as its sample covariance. Setting $\\nu_0$ larger makes the prior stronger (more concentrated), while $S_0/\\nu_0$ sets the prior's \"center\"; to center the prior near a covariance $C$, pass $S_0 = \\nu_0 C$.\n\nBelow are some examples using the MNIW object. The settings for $M_0$ and $K_0$ don't affect $\\Sigma$; we discuss them in the next section."
},
{
"metadata": {},
"input": "import numpy as np\nimport distributions as d",
"cell_type": "code",
"prompt_number": 9,
"outputs": [],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"input": "# set the nominal covariance to be 2*np.eye(2) with a strong prior (nu_0=1000);\n# S_0 is nu_0 times the nominal covariance\na = d.MNIW(nu_0=1000,S_0=1000*2*np.eye(2),M_0=np.zeros((2,4)),Kinv_0=np.eye(4))",
"cell_type": "code",
"prompt_number": 12,
"outputs": [],
"language": "python",
"collapsed": false
},
{
"metadata": {},
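"input": "# generate samples from the prior by calling a.resample()\n# since the prior is strong with nu_0=1000, we expect the samples of Sigma\n# to be concentrated around 2*np.eye(2)\n# NOTE: the execution counts (12, 14, 16) suggest the stored output below was\n# produced after the next cell redefined `a` with nu_0=3, which would explain its spread\nfor itr in range(5):\n    a.resample()\n    print(a.sigma)\n    print('')",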
"cell_type": "code",
"prompt_number": 16,
"outputs": [
{
"output_type": "stream",
"text": "[[ 3.4182838 -1.24441315]\n [-1.24441315 2.6343571 ]]\n\n[[ 3.24937813 0.8742203 ]\n [ 0.8742203 4.61084024]]\n\n[[ 3.73923242 1.4215106 ]\n [ 1.4215106 1.48087701]]\n\n[[ 5.23582888 2.82447629]\n [ 2.82447629 2.16336364]]\n\n[[ 0.55187846 0.05826167]\n [ 0.05826167 1.38899071]]\n\n",
"stream": "stdout"
}
],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"input": "# let's turn down nu_0 to make a weaker prior, keeping S_0 = nu_0 * 2*np.eye(2)\n# NOTE: we must always have nu_0 > n, where n = 2 is the dimension of y, to avoid a rank-deficient prior\na = d.MNIW(nu_0=3,S_0=3*2*np.eye(2),M_0=np.zeros((2,4)),Kinv_0=np.eye(4))",
"cell_type": "code",
"prompt_number": 14,
"outputs": [],
"language": "python",
"collapsed": false
},
{
"metadata": {},
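"input": "# we expect these samples of Sigma to be much less concentrated; with nu_0=3 and n=2\n# the inverse-Wishart mean S_0/(nu_0-n-1) does not even exist, so draws can be very large\nfor itr in range(5):\n    a.resample()\n    print(a.sigma)\n    print('')",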
"cell_type": "code",
"prompt_number": 17,
"outputs": [
{
"output_type": "stream",
"text": "[[ 3.05083822 2.12643776]\n [ 2.12643776 2.82605039]]\n\n[[ 91.49295688 27.8840124 ]\n [ 27.8840124 12.91677358]]\n\n[[ 16.8297896 -8.50800728]\n [ -8.50800728 5.70767819]]\n\n[[ 6.28822604 -0.96009229]\n [ -0.96009229 33.55133718]]\n\n[[ 50.14242647 -7.30162145]\n [ -7.30162145 4.78211535]]\n\n",
"stream": "stdout"
}
],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Prior over $A$ ##\n\nAs described on [the Wikipedia page for the Matrix Normal distribution](https://en.wikipedia.org/wiki/Matrix_normal_distribution), when we write $A \\sim \\text{MN}(M,U,V)$ where $M$ is an $n \\times p$ matrix, then $A$ is also an $n \\times p$ matrix and $\\mathbb{E}[A] = M$. Furthermore, $U$ is $n \\times n$ and describes the covariance of each column of $A$ (the covariance among its rows), while $V$ is $p \\times p$ and describes the covariance of each row of $A$ (the covariance among its columns). That is, we can generate $A \\sim \\text{MN}(M,U,V)$ by first generating a matrix $G$ with i.i.d. standard Gaussian entries, then applying $U^{\\frac{1}{2}}$ on the left and $V^{\\frac{1}{2}}$ on the right (symmetric square roots), and finally adding the mean:\n\n$$ (G)_{ij} \\sim \\mathcal{N}(0,1) \\quad i=1,2,\\ldots,n,~j=1,2,\\ldots,p $$\n$$ A = M + U^{\\frac{1}{2}} G V^{\\frac{1}{2}} $$\n\nEquivalently, $\\text{vec}(A) \\sim \\mathcal{N}(\\text{vec}(M), V \\otimes U)$, where $\\text{vec}$ stacks the columns of $A$. Therefore, $M$ is always the prior mean for $A$, and the two covariance matrices control how concentrated the prior is around that mean.\n\nIn the MNIW, the sampled $\\Sigma$ plays the role of $U$, the covariance of each column of $A$, which is a bit weird: it means the uncertainty over $A$ is tied to the noise covariance. That's a consequence of conjugacy, and it's a good reason to consider a non-conjugate prior that treats $A$ and $\\Sigma$ separately.\n\nHowever, we can independently choose $V$, the covariance of each row of $A$, which is $K_0$. For the internals of the code, it happens to be easier to pass in $K_0^{-1}$, so keep in mind that passing `Kinv_0=10*np.eye(4)` is like passing `K_0=0.1*np.eye(4)`; in other words, large values of `Kinv_0` mean small covariance and a very concentrated prior over $A$."
},
{
"metadata": {},
"input": "# set nu_0 large so we know what Sigma is going to be: roughly S_0/nu_0 = 0.001*np.eye(2)\n# setting Kinv_0 large (so K_0 = 0.01*np.eye(4) is small) should concentrate the prior over A around its mean\na = d.MNIW(nu_0=1000,S_0=np.eye(2),M_0=np.zeros((2,4)),Kinv_0=100*np.eye(4))",
"cell_type": "code",
"prompt_number": 18,
"outputs": [],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"input": "for itr in range(5):\n    a.resample()\n    print(a.A)\n    print('')",
"cell_type": "code",
"prompt_number": 19,
"outputs": [
{
"output_type": "stream",
"text": "[[ 0.00110849 -0.00010401 -0.00161664 0.00330576]\n [ 0.00045621 0.00086333 0.00118468 0.00168913]]\n\n[[-0.00538275 0.00149272 0.002627 0.0032859 ]\n [ 0.00150561 -0.01069009 0.00119642 0.00398182]]\n\n[[ 0.00681789 -0.00603051 0.00162239 -0.00023028]\n [ 0.00147648 0.00235844 0.00421152 0.00217429]]\n\n[[-0.00455712 -0.00298411 -0.0045271 -0.00390966]\n [ 0.00010498 0.0049564 -0.00225751 -0.00314763]]\n\n[[ 0.00537934 -0.00315608 -0.00329924 -0.00476248]\n [-0.00145952 0.00355204 0.00255612 -0.00200533]]\n\n",
"stream": "stdout"
}
],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"input": "# now if we set Kinv_0 small (so K_0 = 10*np.eye(4) is large), samples of A should be much more spread out\na = d.MNIW(nu_0=1000,S_0=np.eye(2),M_0=np.zeros((2,4)),Kinv_0=0.1*np.eye(4))",
"cell_type": "code",
"prompt_number": 20,
"outputs": [],
"language": "python",
"collapsed": false
},
{
"metadata": {},
"input": "for itr in range(5):\n    a.resample()\n    print(a.A)\n    print('')",
"cell_type": "code",
"prompt_number": 21,
"outputs": [
{
"output_type": "stream",
"text": "[[ 0.00045903 0.08387075 -0.00253193 0.12985548]\n [-0.04970261 0.12136798 0.18712929 0.00040652]]\n\n[[ 4.10872516e-02 6.47449685e-05 -4.32808116e-02 -7.80216619e-02]\n [ 9.50728420e-02 -2.20871548e-02 3.10194686e-02 -3.45082306e-02]]\n\n[[-0.04491345 -0.17798445 -0.02367881 0.09619346]\n [-0.02357882 0.02988262 0.0379438 -0.05779257]]\n\n[[-0.07156387 -0.03290019 0.105906 0.06760734]\n [-0.32260077 0.09236214 -0.09012521 0.05573863]]\n\n[[-0.11951635 0.10534495 -0.01634302 0.09663397]\n [ 0.13913644 -0.1410027 -0.10963274 0.06805615]]\n\n",
"stream": "stdout"
}
],
"language": "python",
"collapsed": false
}
],
"metadata": {}
}
],
"metadata": {
"name": "MNIW Hyperparameters",
"gist_id": "8889225"
},
"nbformat": 3
}
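To make the "Prior over $\Sigma$" discussion concrete without depending on the `distributions` module, here is a minimal sketch using `scipy.stats.invwishart`, assuming its `(df, scale)` parametrization matches the notebook's `(nu_0, S_0)`. It compares the empirical mean of many draws with the exact inverse-Wishart mean $S_0/(\nu_0-n-1)$ and shows how the spread shrinks as `nu_0` grows. The names `nominal`, `num_draws`, and `results` are illustrative and not part of the original notebook.

```python
import numpy as np
from scipy.stats import invwishart

n = 2                        # dimension of each observation y_t
nominal = 2 * np.eye(n)      # covariance we want the prior centered near
num_draws = 4000

results = {}
for nu_0 in (1000, 10):
    # Assumed convention (matching the notebook): S_0 = nu_0 * nominal covariance
    S_0 = nu_0 * nominal
    draws = invwishart(df=nu_0, scale=S_0).rvs(size=num_draws, random_state=0)
    empirical_mean = draws.mean(axis=0)
    exact_mean = S_0 / (nu_0 - n - 1)    # inverse-Wishart mean, defined for nu_0 > n + 1
    spread = draws[:, 0, 0].std()        # spread of the (0, 0) variance entry
    results[nu_0] = (empirical_mean, exact_mean, spread)
    print(nu_0, np.round(empirical_mean, 2).tolist(),
          np.round(exact_mean, 2).tolist(), round(spread, 3))
```

With `nu_0=1000` the draws cluster tightly near `2*np.eye(2)`, while `nu_0=10` gives a visibly wider spread. The value `nu_0=3` from the weak-prior cell is left out here because, for n = 2, its mean does not exist, so an empirical mean would not settle down.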
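The matrix-normal construction in the "Prior over $A$" cell can also be checked numerically. The sketch below is a plain-NumPy stand-in for `MNIW.resample()`, not the `distributions` implementation: it fixes `Sigma = 0.001*np.eye(2)` (roughly what `nu_0=1000`, `S_0=np.eye(2)` produces) instead of sampling it, converts each `Kinv_0` to `K_0` by matrix inversion, and draws $A = M_0 + \Sigma^{1/2} G K_0^{1/2}$. The helpers `sqrtm_psd` and `sample_matrix_normal` are hypothetical names introduced here. A final check with non-diagonal $U$ and $V$ compares the empirical covariance of $\text{vec}(A)$ with $V \otimes U$.

```python
import numpy as np

def sqrtm_psd(M):
    """Symmetric square root of a symmetric positive semidefinite matrix."""
    w, Q = np.linalg.eigh(M)
    return (Q * np.sqrt(np.clip(w, 0.0, None))) @ Q.T

def sample_matrix_normal(M, U, V, size, rng):
    """Draw `size` samples of A ~ MN(M, U, V) via A = M + U^{1/2} G V^{1/2}."""
    G = rng.standard_normal((size,) + M.shape)   # i.i.d. N(0, 1) entries
    return M + sqrtm_psd(U) @ G @ sqrtm_psd(V)   # matmul broadcasts over the leading axis

rng = np.random.default_rng(0)
n, p = 2, 4
M_0 = np.zeros((n, p))
Sigma = 0.001 * np.eye(n)    # stand-in for a Sigma draw with S_0 = np.eye(2), nu_0 = 1000

# The two Kinv_0 settings from the notebook; the row covariance K_0 is the inverse
stds = {}
for kinv_scale in (100.0, 0.1):
    K_0 = np.linalg.inv(kinv_scale * np.eye(p))
    A = sample_matrix_normal(M_0, Sigma, K_0, size=5000, rng=rng)
    predicted = np.sqrt(Sigma[0, 0] * K_0[0, 0])   # entry std when U and V are diagonal
    stds[kinv_scale] = (A.std(), predicted)
    print(kinv_scale, round(A.std(), 4), round(predicted, 4))

# Kronecker check with non-diagonal U and V: cov(vec(A)) = V kron U (column-stacking vec)
U = np.array([[1.0, 0.5], [0.5, 2.0]])
V = np.array([[1.0, -0.3, 0.0], [-0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
N = 20000
samples = sample_matrix_normal(np.zeros((2, 3)), U, V, size=N, rng=rng)
vecs = samples.transpose(0, 2, 1).reshape(N, -1)   # row k is vec(samples[k])
max_err = np.abs(np.cov(vecs, rowvar=False) - np.kron(V, U)).max()
print(round(max_err, 3))
```

The predicted entry standard deviations, about 0.003 for `Kinv_0=100*np.eye(4)` and 0.1 for `Kinv_0=0.1*np.eye(4)`, are in line with the scale of the stored outputs above. Symmetric square roots are used to mirror the formula in the text; Cholesky factors $L_U, L_V$ would work equally well via $A = M + L_U G L_V^T$.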