Skip to content

Instantly share code, notes, and snippets.

@zhiyzuo
Last active September 23, 2018 22:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhiyzuo/04dd11a8fed2f400a31be5cb543790f9 to your computer and use it in GitHub Desktop.
Save zhiyzuo/04dd11a8fed2f400a31be5cb543790f9 to your computer and use it in GitHub Desktop.
A wrapper function for OLS linear regression analysis (coefficients, p-values, confidence intervals and variance inflation factors) using statsmodels.
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
def get_vif(data, add_constant=False):
    """Return the variance inflation factor (VIF) of every column in *data*.

    Each VIF comes from statsmodels' ``variance_inflation_factor``, which
    regresses one column on all the other columns of the same matrix.

    Parameters
    ----------
    data : pandas.DataFrame
        Explanatory variables, one per column; values must be numeric.
    add_constant : bool, default False
        ``variance_inflation_factor`` does not add an intercept to its
        auxiliary regressions, so if *data* has no constant column the VIFs
        are based on an uncentered R-squared. Pass True to prepend a column
        of ones before computing; the constant itself is not reported.
        The default keeps the historical behavior.

    Returns
    -------
    pandas.Series
        One VIF per column of *data*, indexed by column name.
    """
    # Materialize the matrix once instead of once per column:
    # ``DataFrame.values`` can build a fresh copy on every access
    # (e.g. for mixed-dtype frames).
    exog = data.values
    offset = 0
    if add_constant:
        import numpy as np
        exog = np.column_stack([np.ones(len(exog)), exog])
        offset = 1  # data column i now sits at matrix column i + 1
    vifs = {name: variance_inflation_factor(exog, i + offset)
            for i, name in enumerate(data.columns.values)}
    return pd.Series(vifs)
def my_fitmodel(formula, data, target_col=-1):
    """Fit an OLS model from a formula and tabulate per-term statistics.

    Parameters
    ----------
    formula : str
        Patsy-style formula passed to ``statsmodels.formula.api.ols``,
        e.g. ``"y ~ x1 + x2"``.
    data : pandas.DataFrame
        Data in which the formula's variables are looked up.
    target_col : int, optional
        Unused; accepted only so existing callers that pass it keep working.

    Returns
    -------
    dict
        ``'result'``: DataFrame with one row per design-matrix term and
        columns ``vif``, ``params``, ``pvalues``, ``tvalues``, ``HC0_se``,
        ``conf_int_lo`` and ``conf_int_hi`` (``vif`` is NaN for the
        ``Intercept`` row); ``'rsquared_adj'`` and ``'rsquared'`` taken from
        the fitted model.
    """
    model = smf.ols(formula, data=data).fit()

    # Compute VIFs on the fitted design matrix rather than on column names
    # re-parsed from the formula string. This supports transformed or
    # categorical terms (C(x), np.log(x), x:y). It uses exactly the rows the
    # model was fit on. It also includes the model's intercept column, which
    # variance_inflation_factor needs in order to produce centered VIFs.
    design = pd.DataFrame(model.model.exog, columns=model.model.exog_names)
    # The intercept's own VIF is not meaningful; leaving it out keeps that
    # row NaN after the DataFrame alignment below.
    vif = get_vif(design).drop('Intercept', errors='ignore')

    result = {'vif': vif}
    for item in ['params', 'pvalues', 'tvalues', 'HC0_se']:
        result[item] = getattr(model, item).copy()
    conf_int = model.conf_int()  # columns 0 / 1 hold the lower / upper bounds
    result['conf_int_lo'] = conf_int[0]
    result['conf_int_hi'] = conf_int[1]
    return {'result': pd.DataFrame(result),
            'rsquared_adj': model.rsquared_adj,
            'rsquared': model.rsquared,
            }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment