Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save pancodia/593308bab1e1202f7213f3c0b1e3976d to your computer and use it in GitHub Desktop.
Save pancodia/593308bab1e1202f7213f3c0b1e3976d to your computer and use it in GitHub Desktop.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import Ridge, RidgeCV
# import seaborn as sns  # NOTE: kept disabled -- with seaborn imported, the plot markers failed to show
# Download the Hitters dataset as a DataFrame.
hitters_df = pd.read_csv(
    'https://raw.githubusercontent.com/selva86/datasets/master/Hitters.csv'
)
# Discard every row containing a missing value, then dummy-encode the
# categorical columns (drop_first=True drops the first level of each one).
hitters_clean_df = pd.get_dummies(hitters_df.dropna(), drop_first=True)
# Split the cleaned table into the response (Salary) and the predictors
# (every remaining column).
y = hitters_clean_df['Salary']
X = hitters_clean_df.drop(['Salary'], axis='columns')
# Fit one ridge regression per penalty strength lambda (scikit-learn calls
# the penalty `alpha`).
models_dict = {}  # maps each alpha to the Ridge model fitted with it
alphas = 10**np.linspace(-4, 2, 100)  # grid of 100 lambdas, evenly spaced on a log scale

# Ridge's `normalize=True` option was deprecated in scikit-learn 1.0 and
# removed in 1.2, so the normalisation is done by hand here. Note that it does
# NOT standardise the regressors by their standard deviation: each column is
# centred and then divided by the L2 norm of the centred column.
X_float = X.astype(float)
X_means = X_float.mean(axis=0)
X_centered = X_float - X_means
col_norms = np.sqrt((X_centered ** 2).sum(axis=0))
# A constant column has zero norm; leave it unscaled instead of dividing by 0.
col_norms = col_norms.where(col_norms != 0.0, 1.0)
X_normalized = X_centered / col_norms

# Build a model for each lambda
for alpha in alphas:
    regr_model = Ridge(alpha=alpha, fit_intercept=True)
    regr_model.fit(X_normalized, y)
    # Map the fitted parameters back to the scale of the raw features so that
    # `coef_`, `intercept_` and `predict` all refer to the original X.
    regr_model.coef_ = regr_model.coef_ / col_norms.to_numpy()
    regr_model.intercept_ = regr_model.intercept_ - X_means.to_numpy() @ regr_model.coef_
    models_dict[alpha] = regr_model
# Collect the coefficient path of every predictor as a function of lambda:
# row i of `ridge_coefs` holds the coefficients of the model fitted with
# alphas[i], one column per predictor.
ridge_coefs = np.empty((len(alphas), X.shape[1]))
# Look each model up by its alpha rather than iterating over
# `models_dict.values()`: dict iteration order is only guaranteed to follow
# insertion order from Python 3.7 on, and the rows must line up with `alphas`
# for the plot to be correct.
for row, alpha in enumerate(alphas):
    ridge_coefs[row] = models_dict[alpha].coef_
# Draw the coefficient paths: one line per predictor, lambda on a log-scaled x-axis.
fig, ax = plt.subplots(figsize=(16, 8))
ax.plot(alphas, ridge_coefs, linewidth=2.0)
ax.set_xscale('log')

# Title and axis labels.
ax.set_title(r'Ridge Coefficients vs $\lambda$', fontsize=20)
ax.set_xlabel(r'$\lambda$ (log-scale)', fontsize=15)
ax.set_ylabel('Ridge Coeffs', fontsize=15)

# One legend entry per predictor column, anchored outside the right edge of the axes.
predictor_names = X.columns.tolist()
ax.legend(
    predictor_names,
    loc='right',
    bbox_to_anchor=(1.2, 0.5),
    ncol=1,
    fontsize=15,
    shadow=True,
)

plt.show()
@pancodia
Copy link
Author

pancodia commented Jul 12, 2017

Plot from Python 3.5

plot_python35

Plot from Python 3.6

plot_python36

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment