@kingjr · Created December 3, 2019
OLS vs PLS
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import cross_val_score
from scipy.stats import pearsonr
ols = LinearRegression()
pls = PLSRegression()  # sklearn default: n_components=2
n = 1000       # samples per simulation
dx = 2         # input dimension
dy = 2         # output dimension
repeat = 200   # number of simulations
results = np.empty((repeat, 3))  # columns: [r**2, OLS score, PLS score]
for idx in range(repeat):
    # Z is a latent factor that makes the columns of X correlated.
    # y depends only on the second column of X (the first row of B is zeroed),
    # so we compare how OLS and PLS cope with correlated inputs.
    # Z => X -> y
    Z = np.random.randn(n, 1)
    A = np.random.randn(1, dx)
    B = np.random.randn(dx, dy)
    B[0] = 0  # remove the first input's contribution to y
    X = Z @ A + np.random.randn(n, dx)
    y = X @ B + np.random.randn(n, dy)
    # correlation between the two input columns
    r, _ = pearsonr(X[:, 0], X[:, 1])
    # 2-fold cross-validated R2 for each model
    ols_score = cross_val_score(ols, X, y, cv=2, scoring='r2').mean()
    pls_score = cross_val_score(pls, X, y, cv=2, scoring='r2').mean()
    results[idx] = [r ** 2, ols_score, pls_score]
# Plot out-of-sample R2 against the squared input correlation
plt.scatter(results[:, 0], results[:, 1], label='Multiple regression (OLS)')
plt.scatter(results[:, 0], results[:, 2], label='Partial least squares', marker='+')
plt.xlabel('Squared input correlation')
plt.ylabel('Out-of-sample R2 score')
plt.legend()
plt.show()
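As a side note, here is a minimal sanity check (not part of the original gist) of how correlated the simulated inputs actually are. Under the generative model above, X = Z @ A + noise with standard-normal Z and noise, so for a given draw of A the correlation between the two columns of X is A_1 * A_2 / sqrt((A_1**2 + 1) * (A_2**2 + 1)). The sketch below compares this closed-form value against pearsonr on a large sample; the sample size and variable names are arbitrary choices.

# Sanity check: theoretical vs empirical correlation of the two input columns
# for one draw of the mixing vector A (assumed setup mirrors the simulation above).
import numpy as np
from scipy.stats import pearsonr

n = 100_000
Z = np.random.randn(n, 1)                  # latent factor
A = np.random.randn(1, 2)                  # mixing weights
X = Z @ A + np.random.randn(n, 2)          # correlated inputs

theoretical = (A[0, 0] * A[0, 1]) / np.sqrt((A[0, 0] ** 2 + 1) * (A[0, 1] ** 2 + 1))
empirical, _ = pearsonr(X[:, 0], X[:, 1])
print(f"theoretical r = {theoretical:.3f}, empirical r = {empirical:.3f}")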