willemolding/f0f442b10be9efe1678f87ec754496d0
Created October 24, 2017 01:51
Condition a multivariate Gaussian on observations
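The `condition` function in this gist implements the standard conditioning result for a partitioned Gaussian. Write x1 for the unobserved (NaN) entries and x2 for the observed entries, with observed value a. Partition the mean and covariance the same way, which gives `mu1`, `mu2`, `sigma_11`, `sigma_12`, `sigma_21`, `sigma_22` in the code. The conditional distribution is then:

```latex
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix} \sim \mathcal{N}\left(
  \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix},
  \begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{pmatrix}\right)
\quad\Longrightarrow\quad
x_1 \mid x_2 = a \sim \mathcal{N}(\bar{\mu}, \bar{\Sigma}),

\bar{\mu} = \mu_1 + \Sigma_{12}\Sigma_{22}^{-1}(a - \mu_2), \qquad
\bar{\Sigma} = \Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21}
```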
import numpy as np


def condition(mu, sigma, observations):
    """
    Condition a multivariate Gaussian on observations to produce a new mean and covariance matrix.

    Parameters:
    ----------
    mu - the mean vector
    sigma - the covariance matrix
    observations - the observed values.
        Entries for variables that are not observed must be NaN.
        The resulting distribution is over these unobserved variables, in the order they appear.

    returns:
    -------
    mu_bar - the conditional mean over the unobserved variables
    sigma_bar - the conditional covariance over the unobserved variables
    """
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    observations = np.asarray(observations, dtype=float)

    not_obs = np.isnan(observations)  # True for the unobserved variables
    obs = np.logical_not(not_obs)  # True for the observed variables
    a = observations[obs]  # observed values only

    # Partition the mean and covariance into unobserved (1) and observed (2) blocks
    mu1 = mu[not_obs]
    mu2 = mu[obs]
    sigma_11 = sigma[np.ix_(not_obs, not_obs)]
    sigma_22 = sigma[np.ix_(obs, obs)]
    sigma_21 = sigma[np.ix_(obs, not_obs)]
    sigma_12 = sigma[np.ix_(not_obs, obs)]

    # Solve linear systems with sigma_22 rather than forming its explicit inverse (numerically more stable)
    mu_bar = mu1 + sigma_12.dot(np.linalg.solve(sigma_22, a - mu2))
    sigma_bar = sigma_11 - sigma_12.dot(np.linalg.solve(sigma_22, sigma_21))
    return mu_bar, sigma_bar
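One way to sanity-check `condition` is to compute the same conditional from the precision matrix Lambda = inv(sigma) instead of from the Schur complement of `sigma_22`. The identities are sigma_bar = inv(Lambda_11) and mu_bar = mu1 - inv(Lambda_11) Lambda_12 (a - mu2). The sketch below is a minimal cross-check, not part of the original gist. The helper name `condition_via_precision` is hypothetical, and it assumes `sigma` is invertible (positive definite) and that `observations` uses NaN for unobserved entries, as `condition` does.

```python
import numpy as np


def condition_via_precision(mu, sigma, observations):
    """Hypothetical cross-check helper: conditional mean and covariance of the
    NaN entries of `observations`, computed from the precision matrix
    Lambda = inv(sigma) instead of the Schur complement of sigma_22."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    observations = np.asarray(observations, dtype=float)

    not_obs = np.isnan(observations)  # True for unobserved variables
    obs = ~not_obs  # True for observed variables
    a = observations[obs]  # observed values only

    precision = np.linalg.inv(sigma)  # assumes sigma is invertible
    lam_11 = precision[np.ix_(not_obs, not_obs)]
    lam_12 = precision[np.ix_(not_obs, obs)]

    sigma_bar = np.linalg.inv(lam_11)  # conditional covariance = inv(Lambda_11)
    mu_bar = mu[not_obs] - sigma_bar @ lam_12 @ (a - mu[obs])
    return mu_bar, sigma_bar


if __name__ == "__main__":
    # Bivariate example: observe x2 = 3 and leave x1 unobserved
    mu = np.array([0.0, 1.0])
    sigma = np.array([[2.0, 1.0], [1.0, 2.0]])
    observations = np.array([np.nan, 3.0])
    mu_bar, sigma_bar = condition_via_precision(mu, sigma, observations)
    print(mu_bar, sigma_bar)  # approximately [1.] [[1.5]]
```

For the bivariate example, the Schur-complement formulas give mu_bar = 0 + (1/2)(3 - 1) = 1 and sigma_bar = 2 - (1)(1/2)(1) = 1.5. Passing the same inputs to `condition` should therefore agree with this helper up to floating-point error.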