Skip to content

Instantly share code, notes, and snippets.

@jeffbaumes
Created June 18, 2018 12:49
Show Gist options
  • Save jeffbaumes/83bd453ed70af48352d4329136ad4a84 to your computer and use it in GitHub Desktop.
Save jeffbaumes/83bd453ed70af48352d4329136ad4a84 to your computer and use it in GitHub Desktop.
Pulling out seaborn regression plotter methods into utility functions
def fit_regression(x, y, n_boot=1000, units=None, ci=95, order=1, logistic=False, lowess=False, robust=False, logx=False):
    """Fit a regression of ``y`` on ``x`` and evaluate it on a grid.

    The model is chosen by the first matching option, in this order:
    ``order > 1`` (polynomial), ``logistic``, ``lowess``, ``robust``,
    ``logx``; otherwise a plain linear fit is used.

    Parameters
    ----------
    x, y : array-like
        Data to fit. ``x`` must provide ``.min()`` and ``.max()``.
    n_boot : int
        Number of bootstrap resamples forwarded to the fitting helper.
    units : array-like or None
        Forwarded unchanged to the fitting helper's bootstrap call.
    ci : number or None
        Size of the confidence interval passed to ``utils.ci``; ``None``
        disables the interval. Forced to ``None`` for lowess fits.
    order : int
        Polynomial order; values above 1 select ``fit_poly``.
    logistic, lowess, robust, logx : bool
        Model selection flags (see above).

    Returns
    -------
    grid, yhat, err_bands
        Grid of x values, fitted values on that grid, and the error bands
        (``None`` when no confidence interval is computed).
    """
    # Create the grid for the regression: 100 evenly spaced points
    # spanning the observed range of x.
    grid = np.linspace(x.min(), x.max(), 100)

    # Default so the interval step below is well-defined on every branch.
    yhat_boots = None

    # Fit the regression
    if order > 1:
        yhat, yhat_boots = fit_poly(x, y, ci, grid, order, n_boot, units)
    elif logistic:
        from statsmodels.genmod.generalized_linear_model import GLM
        from statsmodels.genmod.families import Binomial
        yhat, yhat_boots = fit_statsmodels(
            x, y, ci, grid, GLM, n_boot, units, family=Binomial())
    elif lowess:
        # fit_lowess yields no bootstrap samples and returns its own grid,
        # so a confidence interval cannot be computed for this model.
        ci = None
        grid, yhat = fit_lowess(x, y)
    elif robust:
        from statsmodels.robust.robust_linear_model import RLM
        yhat, yhat_boots = fit_statsmodels(
            x, y, ci, grid, RLM, n_boot, units)
    elif logx:
        yhat, yhat_boots = fit_logx(x, y, ci, grid, n_boot, units)
    else:
        yhat, yhat_boots = fit_fast(x, y, ci, grid, n_boot, units)

    # Compute the confidence interval at each grid point
    if ci is None:
        err_bands = None
    else:
        err_bands = utils.ci(yhat_boots, ci, axis=0)

    return grid, yhat, err_bands
def fit_fast(x, y, ci, grid, n_boot, units):
    """Ordinary least-squares fit and prediction via the pseudo-inverse.

    Builds an intercept-plus-x design matrix, solves for the coefficients
    with ``np.linalg.pinv`` and evaluates the line on ``grid``. Returns
    ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when ``ci`` is
    ``None``, otherwise the predictions from each bootstrap resample.
    """
    def solve(design, target):
        # Least-squares coefficients for target ~ design.
        return np.linalg.pinv(design).dot(target)

    design_matrix = np.c_[np.ones(len(x)), x]
    grid_matrix = np.c_[np.ones(len(grid)), grid]

    coefs = solve(design_matrix, y)
    yhat = grid_matrix.dot(coefs)

    yhat_boots = None
    if ci is not None:
        # One coefficient vector per resample; transpose so that the
        # grid matrix can be multiplied through all of them at once.
        boot_coefs = algo.bootstrap(design_matrix, y, func=solve,
                                    n_boot=n_boot, units=units)
        yhat_boots = grid_matrix.dot(boot_coefs.T).T

    return yhat, yhat_boots
def fit_poly(x, y, ci, grid, order, n_boot, units):
    """Polynomial regression of the given ``order`` using numpy.

    Fits with ``np.polyfit`` and evaluates the polynomial on ``grid``.
    Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when ``ci``
    is ``None``, otherwise the grid predictions from each bootstrap
    resample.
    """
    def predict_on_grid(x_sample, y_sample):
        # Fit the polynomial to one sample and evaluate it on the grid.
        coefficients = np.polyfit(x_sample, y_sample, order)
        return np.polyval(coefficients, grid)

    yhat = predict_on_grid(x, y)

    if ci is not None:
        boots = algo.bootstrap(x, y, func=predict_on_grid,
                               n_boot=n_boot, units=units)
        return yhat, boots

    return yhat, None
def fit_statsmodels(x, y, ci, grid, model, n_boot, units, **kwargs):
    """More general regression function using statsmodels objects.

    ``model`` is a statsmodels model class called as
    ``model(y, X, **kwargs)``, where ``X`` is ``x`` with a leading column
    of ones. The fitted model is used to predict on ``grid`` (also given
    a leading column of ones).

    Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when ``ci``
    is ``None``, otherwise the grid predictions from each bootstrap
    resample. A fit that raises ``PerfectSeparationError`` yields an
    all-NaN prediction instead of propagating the error.
    """
    # Import the exception from the module that defines it rather than
    # relying on the GLM module happening to re-export it.
    from statsmodels.tools.sm_exceptions import PerfectSeparationError

    X = np.c_[np.ones(len(x)), x]
    grid = np.c_[np.ones(len(grid)), grid]

    def reg_func(_x, _y):
        try:
            yhat = model(_y, _x, **kwargs).fit().predict(grid)
        except PerfectSeparationError:
            # Keep the output shape so bootstrap results stay stackable.
            yhat = np.empty(len(grid))
            yhat.fill(np.nan)
        return yhat

    yhat = reg_func(X, y)
    if ci is None:
        return yhat, None

    yhat_boots = algo.bootstrap(X, y, func=reg_func,
                                n_boot=n_boot, units=units)
    return yhat, yhat_boots
def fit_lowess(x, y):
    """Fit a locally-weighted regression, which returns its own grid.

    Calls statsmodels' ``lowess(y, x)`` and splits its two-column result
    into ``(grid, yhat)``.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess

    smoothed = lowess(y, x)
    # Transposing turns the two columns into two rows to unpack.
    grid, yhat = smoothed.T
    return grid, yhat
def fit_logx(x, y, ci, grid, n_boot, units):
    """Fit a linear model of ``y`` on ``log(x)``.

    The design matrix is built from the raw ``x`` values; the natural log
    is applied inside the solver. Predictions are made on ``log(grid)``.
    Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when ``ci``
    is ``None``, otherwise the predictions from each bootstrap resample.
    """
    def solve_in_log_space(design, target):
        # Column 0 is the intercept; column 1 holds the untransformed x.
        logged = np.c_[design[:, 0], np.log(design[:, 1])]
        return np.linalg.pinv(logged).dot(target)

    raw_design = np.c_[np.ones(len(x)), x]
    log_grid = np.c_[np.ones(len(grid)), np.log(grid)]

    coefs = solve_in_log_space(raw_design, y)
    yhat = log_grid.dot(coefs)

    yhat_boots = None
    if ci is not None:
        boot_coefs = algo.bootstrap(raw_design, y, func=solve_in_log_space,
                                    n_boot=n_boot, units=units)
        yhat_boots = log_grid.dot(boot_coefs.T).T

    return yhat, yhat_boots
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment