Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active July 27, 2021 00:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/8bcd25081776dd6e24972d9980aa9a9a to your computer and use it in GitHub Desktop.
Save brandonwillard/8bcd25081776dd6e24972d9980aa9a9a to your computer and use it in GitHub Desktop.
SVD-based Gaussian conjugate update
import numpy as np
# Variance of the simulated observation noise: simulate_regression() draws
# its noise with standard deviation sqrt(sigma2_true).
sigma2_true = 2.0
def simulate_regression(N=100, M=1000, noise_var=None):
    """Simulate a Gaussian linear-regression data set.

    Parameters
    ----------
    N : int
        Number of observations (rows of the design matrix).
    M : int
        Number of covariates (columns of the design matrix).
    noise_var : float, optional
        Variance of the additive observation noise.  When omitted, the
        module-level ``sigma2_true`` is used.

    Returns
    -------
    X : ndarray, shape (N, M)
        Design matrix with i.i.d. standard-normal entries.
    y : ndarray, shape (N,)
        Responses ``X.dot(beta_true) + noise``.
    beta_true : ndarray, shape (M,)
        Standard-normal coefficients used to generate ``y``.
    """
    if noise_var is None:
        noise_var = sigma2_true
    # Draw order (X, beta_true, noise) is kept fixed so that seeding the
    # global NumPy generator reproduces the same data set.
    X = np.random.normal(size=(N, M))
    beta_true = np.random.normal(size=(M,))
    noise = np.random.normal(0, np.sqrt(noise_var), size=(N,))
    y = X.dot(beta_true) + noise
    return X, y, beta_true
# Simulate the data set used by the solves that follow.
X, y, beta_true = simulate_regression()

# `S_inv` is the M x M term added to X' V^-1 X in the solves (the prior
# precision of the coefficients, variance `sigma2`).  It is sized from `X`
# because `M` is only a parameter of `simulate_regression` and is not
# defined at module scope.
sigma2 = 2.0
S_inv = np.eye(X.shape[1]) / sigma2

# `V_inv` is the N x N observation precision (noise variance `omega2`), with
# one row/column per row of `X`; sized from `X` for the same reason.
omega2 = 0.2
V_inv = np.eye(X.shape[0]) / omega2
def naive_solve(X, y, *, obs_precision=None, prior_precision=None):
    """Return the Gaussian conjugate posterior mean of the coefficients.

    Solves the ``M x M`` system ``(X' V^-1 X + S^-1) beta = X' V^-1 y``
    directly, where ``V^-1`` is the observation precision and ``S^-1`` the
    prior precision.

    Parameters
    ----------
    X : ndarray, shape (N, M)
        Design matrix.
    y : ndarray, shape (N,)
        Observed responses.
    obs_precision : ndarray, shape (N, N), optional
        Observation precision ``V^-1``.  Defaults to the module-level
        ``V_inv``.
    prior_precision : ndarray, shape (M, M), optional
        Prior precision ``S^-1``.  Defaults to the module-level ``S_inv``.

    Returns
    -------
    beta_hat : ndarray, shape (M,)
        Solution of the linear system above.
    """
    if obs_precision is None:
        obs_precision = V_inv
    if prior_precision is None:
        prior_precision = S_inv
    # X' V^-1 appears on both sides of the system, so compute it once.
    weighted_X_T = X.T.dot(obs_precision)
    E = weighted_X_T.dot(X) + prior_precision
    # The right-hand side must carry the same precision weighting as `E`;
    # the unweighted X' y mis-scales the estimate whenever the observation
    # precision is not the identity.
    beta_hat = np.linalg.solve(E, weighted_X_T.dot(y))
    return beta_hat
beta_hat_ref, *_ = np.linalg.lstsq(X, y)
beta_hat = naive_solve(X, y)
%timeit naive_solve(X, y)
# 26 ms ± 8.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.lstsq(X, y)
# 13 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# The norms against the true beta should be comparable
np.linalg.norm(beta_hat_ref - beta_true, 2)
np.linalg.norm(beta_hat - beta_true, 2)
# Thin SVD of the design matrix: X = U_X @ diag(d_X) @ V_X_T.
U_X, d_X, V_X_T = np.linalg.svd(X, full_matrices=False)
# Broadcasting scales column j of U_X by d_X[j], i.e. R_X = U_X @ diag(d_X),
# so that X = R_X @ V_X_T.
R_X = U_X * d_X
# Term added to R_X' V^-1 R_X in the reduced solve; it needs one row/column
# per singular value (per column of R_X).  Sized from `d_X` because `N` is
# not defined at module scope; this equals X.shape[0] whenever X has no more
# rows than columns.
S_inv_svd = np.eye(d_X.shape[0]) / sigma2
def naive_svd_solve(X, y, *, svd_factors=None, obs_precision=None,
                    prior_precision=None):
    """Return the conjugate posterior mean via precomputed SVD factors.

    With ``X = R @ V_T`` from a thin SVD, this solves the reduced system
    ``(R' V^-1 R + P) eta = R' V^-1 y`` and maps the solution back with
    ``beta = V_T' eta``.

    Parameters
    ----------
    X : ndarray, shape (N, M)
        Design matrix.  Not read by the solve itself; it is kept so the
        signature matches ``naive_solve``.  The factors that are used must
        have been computed from this ``X``.
    y : ndarray, shape (N,)
        Observed responses.
    svd_factors : tuple (R, V_T), optional
        ``R = U * d`` and ``V_T`` from ``np.linalg.svd(X,
        full_matrices=False)``.  Defaults to the module-level
        ``(R_X, V_X_T)``.
    obs_precision : ndarray, shape (N, N), optional
        Observation precision ``V^-1``.  Defaults to the module-level
        ``V_inv``.
    prior_precision : ndarray, shape (K, K), optional
        Precision term ``P`` in the reduced coordinates, where ``K`` is the
        number of columns of ``R``.  Defaults to the module-level
        ``S_inv_svd``.

    Returns
    -------
    beta_hat : ndarray, shape (M,)
        Coefficient estimate in the original coordinates.
    """
    if svd_factors is None:
        svd_factors = (R_X, V_X_T)
    R, V_T = svd_factors
    if obs_precision is None:
        obs_precision = V_inv
    if prior_precision is None:
        prior_precision = S_inv_svd
    # R' V^-1 appears on both sides of the reduced system; compute it once.
    weighted_R_T = R.T.dot(obs_precision)
    Z = weighted_R_T.dot(R) + prior_precision
    # The right-hand side must carry the same precision weighting as `Z`;
    # the unweighted R' y mis-scales the estimate whenever the observation
    # precision is not the identity.
    eta_hat = np.linalg.solve(Z, weighted_R_T.dot(y))
    # Map the reduced solution back to the original coefficient space.
    beta_hat = V_T.T.dot(eta_hat)
    return beta_hat
# Solve again, this time through the precomputed SVD factors.
beta_hat = naive_svd_solve(X, y)
# Distance to the least-squares reference, then to the simulated
# coefficients (bare expressions: shown only in an interactive session).
np.linalg.norm(beta_hat_ref - beta_hat, 2)
np.linalg.norm(beta_hat - beta_true, 2)
# `%timeit` is an IPython magic; the comment below records the observed timing.
%timeit naive_svd_solve(X, y)
# 298 µs ± 14.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# This will take forever using the non-SVD approach
# (N keeps its default of 100; only the number of columns grows.)
X, y, beta_true = simulate_regression(M=50000)
# This is only ever performed once for a given `X`
# naive_svd_solve reads the module-level R_X, V_X_T and S_inv_svd, so they
# are rebound here for the new `X`.
U_X, d_X, V_X_T = np.linalg.svd(X, full_matrices=False)
# R_X = U_X @ diag(d_X): broadcasting scales column j of U_X by d_X[j].
R_X = U_X * d_X
# Sized by the number of rows of `X`.
S_inv_svd = np.eye(X.shape[0]) / sigma2
beta_hat = naive_svd_solve(X, y)
# `%timeit` is an IPython magic; the comment below records the observed timing.
%timeit naive_svd_solve(X, y)
# 3.21 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment