Skip to content

Instantly share code, notes, and snippets.

@tansey
Last active April 29, 2022 03:32
Show Gist options
  • Save tansey/2c34db232d19455c61d77ced03d9310a to your computer and use it in GitHub Desktop.
Save tansey/2c34db232d19455c61d77ced03d9310a to your computer and use it in GitHub Desktop.
Fast multivariate normal sampling for some common cases
'''Fast sampling from a multivariate normal with covariance or precision
parameterization. Supports sparse arrays. Params:
- mu: If provided, assumes the model is N(mu, Q), or N(mu, Q^-1) when
precision=True.
- mu_part: If provided, assumes the model is N(Q mu_part, Q), or
N(Q^-1 mu_part, Q^-1) when precision=True.
This is common in many conjugate Gibbs steps.
- sparse: If true, assumes we are working with a sparse Q
- precision: If true, assumes Q is a precision matrix (inverse covariance)
- chol_factor: If true, assumes Q is a (lower triangular) Cholesky
decomposition of the covariance matrix
(or of the precision matrix if precision=True).
Author: Wesley Tansey
Date: 5/8/2019
'''
import numpy as np
import scipy as sp
from scipy.sparse import issparse, coo_matrix, csc_matrix, vstack
from scipy.linalg import solve_triangular
from collections import defaultdict
from sksparse.cholmod import cholesky
def sample_mvn_from_precision(Q, mu=None, mu_part=None, sparse=True, chol_factor=False, Q_shape=None):
    '''Draw one sample from a multivariate normal given its precision matrix.

    Supports dense numpy arrays and sparse matrices (via CHOLMOD). Params:
    - Q: the precision matrix (inverse covariance), or its Cholesky factor
      if chol_factor=True. In sparse mode a factor is a CHOLMOD Factor as
      returned by sksparse.cholmod.cholesky; in dense mode it is the lower
      triangular L with L L' = precision.
    - mu: If provided, assumes the model is N(mu, Q^-1)
    - mu_part: If provided, assumes the model is N(Q^-1 mu_part, Q^-1).
      Takes precedence over mu if both are given.
    - sparse: If true, assumes we are working with a sparse Q
    - chol_factor: If true, assumes Q is a (lower triangular) Cholesky
      decomposition of the precision matrix
    - Q_shape: shape of the precision matrix. Only consulted when Q is a
      sparse factor; if omitted, the dimension is taken from the factor's
      fill-reducing permutation vector.
    Returns a 1-D numpy array holding the sample.
    '''
    if sparse:
        # CHOLMOD factors a *permuted* matrix: P Q P' = L L', where P is a
        # fill-reducing permutation chosen by the library.
        factor = Q if chol_factor else cholesky(Q)
        if not chol_factor:
            dim = Q.shape[0]
        elif Q_shape is not None:
            dim = Q_shape[0]
        else:
            dim = len(factor.P())
        z = np.random.normal(size=dim)
        # h = L'^-1 z has covariance (L L')^-1 = P Q^-1 P', i.e. it is a draw
        # of the permuted variable; P' h undoes the permutation so the
        # returned sample has covariance Q^-1.
        result = factor.apply_Pt(factor.solve_Lt(z, use_LDLt_decomposition=False))
        if mu_part is not None:
            # solve_A solves Q x = mu_part and handles the permutation itself.
            result += factor.solve_A(mu_part)
        elif mu is not None:
            result += mu
        return result

    # Dense path. numpy returns the lower factor L with L L' = Q; for
    # z ~ N(0, I) the vector L'^-1 z has covariance (L L')^-1 = Q^-1, so we
    # never need to form the covariance explicitly.
    Lt = Q.T if chol_factor else np.linalg.cholesky(Q).T
    z = np.random.normal(size=Lt.shape[0])
    result = solve_triangular(Lt, z, lower=False)
    if mu_part is not None:
        # With lower=False, cho_solve treats Lt as the upper factor U with
        # U'U = Q, so this solves Q x = mu_part.
        result += sp.linalg.cho_solve((Lt, False), mu_part)
    elif mu is not None:
        result += mu
    return result
def sample_mvn_from_covariance(Q, mu=None, mu_part=None, sparse=True, chol_factor=False):
    '''Draw one sample from a multivariate normal given its covariance matrix.

    Supports dense numpy arrays and sparse matrices (via CHOLMOD). Params:
    - Q: the covariance matrix, or its Cholesky factor if chol_factor=True.
      In sparse mode a factor is a CHOLMOD Factor as returned by
      sksparse.cholmod.cholesky; in dense mode it is the lower triangular L
      with L L' = covariance.
    - mu: If provided, assumes the model is N(mu, Q)
    - mu_part: If provided, assumes the model is N(Q mu_part, Q).
      Takes precedence over mu if both are given.
    - sparse: If true, assumes we are working with a sparse Q
    - chol_factor: If true, assumes Q is a (lower triangular) Cholesky
      decomposition of the covariance matrix
    Returns a 1-D numpy array holding the sample.
    '''
    if sparse:
        # CHOLMOD factors a *permuted* matrix: P Q P' = L L', where P is a
        # fill-reducing permutation chosen by the library.
        factor = Q if chol_factor else cholesky(Q)
        L = factor.L()
        z = np.random.normal(size=L.shape[0])
        # L z has covariance L L' = P Q P' (the permuted variable); applying
        # P' maps it back so the sample has covariance Q.
        result = factor.apply_Pt(L.dot(z))
        if mu_part is not None:
            if chol_factor:
                # Q mu_part = P' L L' P mu_part. Rebuilding Q as L L' would
                # give the permuted matrix, so go through the permutation.
                b = np.asarray(mu_part, dtype=np.float64)
                result += factor.apply_Pt(L.dot(L.T.dot(factor.apply_P(b))))
            else:
                result += Q.dot(mu_part)
        elif mu is not None:
            result += mu
        return result

    # Dense path. numpy returns the lower factor L with L L' = Q, so for
    # z ~ N(0, I) the vector L z has covariance Q.
    L = Q if chol_factor else np.linalg.cholesky(Q)
    z = np.random.normal(size=L.shape[0])
    result = L.dot(z)
    if mu_part is not None:
        if chol_factor:
            # Q mu_part = L (L' mu_part): two O(n^2) mat-vecs instead of
            # forming Q = L L' at O(n^3).
            result += L.dot(L.T.dot(mu_part))
        else:
            result += Q.dot(mu_part)
    elif mu is not None:
        result += mu
    return result
def sample_mvn(Q, mu=None, mu_part=None, sparse=True, precision=False, chol_factor=False, Q_shape=None):
    '''Fast sampling from a multivariate normal with covariance or precision
    parameterization. Supports sparse arrays. Params:
    - Q: the covariance matrix, or the precision matrix if precision=True
      (or a Cholesky factor of it if chol_factor=True).
    - mu: If provided, assumes the model is N(mu, Q), or N(mu, Q^-1) when
      precision=True.
    - mu_part: If provided, assumes the model is N(Q mu_part, Q), or
      N(Q^-1 mu_part, Q^-1) when precision=True.
    - sparse: If true, assumes we are working with a sparse Q
    - precision: If true, assumes Q is a precision matrix (inverse covariance)
    - chol_factor: If true, assumes Q is a (lower triangular) Cholesky
      decomposition of the covariance matrix
      (or of the precision matrix if precision=True).
    - Q_shape: shape of the precision matrix; forwarded to the precision
      sampler, which uses it to size the noise vector when Q is a sparse
      Cholesky factor. Ignored in covariance mode.
    Raises ValueError if both mu and mu_part are provided.
    '''
    # The mean and mean-part are mutually exclusive. An explicit check
    # (rather than assert) keeps the validation active under `python -O`.
    if mu is not None and mu_part is not None:
        raise ValueError('mu and mu_part are mutually exclusive; provide at most one')
    if precision:
        return sample_mvn_from_precision(Q,
                                         mu=mu, mu_part=mu_part,
                                         sparse=sparse,
                                         chol_factor=chol_factor,
                                         Q_shape=Q_shape)
    return sample_mvn_from_covariance(Q,
                                      mu=mu, mu_part=mu_part,
                                      sparse=sparse,
                                      chol_factor=chol_factor)
if __name__ == '__main__':
    ####################### TESTS FOR MVN SAMPLERS ABOVE #######################
    # Smoke test: draw from the same 2-D Gaussian (covariance Q) through all
    # eight (parameterization, dense/sparse, pre-factored) combinations,
    # scatter-plot each cloud in its own titled panel (all panels should look
    # alike), and print each run's empirical covariance next to the target.
    import matplotlib.pyplot as plt

    n_samples = 1000
    Q = np.array([[1, 0.3], [0.3, 1]])
    L = np.linalg.cholesky(Q)  # dense lower factor, L L' = Q
    Q_inv = np.linalg.inv(Q)
    L_inv = np.linalg.cholesky(Q_inv)  # dense lower factor of the precision
    sp_Q = csc_matrix(Q)
    sp_factor = cholesky(sp_Q)
    sp_Q_inv = csc_matrix(Q_inv)
    sp_factor_inv = cholesky(sp_Q_inv)

    # (panel title, matrix argument, keyword arguments for sample_mvn), in
    # row-major panel order: covariance runs on the top row, precision below.
    cases = [
        ('Covariance, dense, no factor', Q,
         dict(sparse=False, chol_factor=False, precision=False)),
        ('Covariance, dense, with factor', L,
         dict(sparse=False, chol_factor=True, precision=False)),
        ('Covariance, sparse, no factor', sp_Q,
         dict(sparse=True, chol_factor=False, precision=False)),
        ('Covariance, sparse, with factor', sp_factor,
         dict(sparse=True, chol_factor=True, precision=False)),
        ('Precision, dense, no factor', Q_inv,
         dict(sparse=False, chol_factor=False, precision=True)),
        ('Precision, dense, with factor', L_inv,
         dict(sparse=False, chol_factor=True, precision=True)),
        ('Precision, sparse, no factor', sp_Q_inv,
         dict(sparse=True, chol_factor=False, precision=True)),
        ('Precision, sparse, with factor', sp_factor_inv,
         dict(sparse=True, chol_factor=True, precision=True, Q_shape=Q.shape)),
    ]

    fig, axarr = plt.subplots(2, 4, figsize=(20, 10), sharex=True)
    print('Target covariance:\n{}'.format(Q))
    for ax, (title, mat, kwargs) in zip(axarr.flat, cases):
        X = np.array([sample_mvn(mat, **kwargs) for _ in range(n_samples)])
        ax.scatter(X[:, 0], X[:, 1])
        ax.set_title(title)
        print('{} -- empirical covariance:\n{}'.format(title, np.cov(X.T)))
    plt.tight_layout()
    plt.savefig('mvn-tests.pdf', bbox_inches='tight')
    plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment