Skip to content

Instantly share code, notes, and snippets.

@GermanCM
Last active August 10, 2020 08:01
Show Gist options
  • Save GermanCM/11e9bee6caaacf5501d1e4921963826e to your computer and use it in GitHub Desktop.
Save GermanCM/11e9bee6caaacf5501d1e4921963826e to your computer and use it in GitHub Desktop.
Bootstrap confidence interval for linear regression
# bootstrap confidence intervals
import numpy as np
# std is provided by numpy itself; numpy.random has no std, so importing it from there fails.
from numpy import mean, median, percentile, std
from numpy.random import seed, rand, randint, randn
from scipy.stats import linregress
# Fix the global NumPy RNG seed so every run draws the same synthetic data.
seed(1)

# Synthetic dataset: x is Gaussian (mean 100, standard deviation 20) and
# y is x shifted by Gaussian noise (mean 50, standard deviation 10).
x = 100 + 20 * randn(1000)
noise = 50 + 10 * randn(1000)
y = x + noise

### LINEAR REGRESSION FIT WITH ALL THE DATASET AT ONCE
# linregress returns (slope, intercept, r value, p value, slope standard error).
whole_ds_fit = linregress(x, y)
b1, b0, r_value, p_value, std_err = whole_ds_fit

# Fitted line evaluated at every observed x.
yhat_whole_ds = b1 * x + b0
### LINEAR REGRESSION FIT WITH BOOTSTRAP SAMPLES
# Each iteration draws (x, y) pairs with replacement, refits the line and records
# the fitted [slope, intercept]. The spread of those coefficients across
# iterations is what the percentile intervals are computed from.
n_bootstrap = 100   # number of bootstrap refits
sample_size = 550   # observations drawn (with replacement) per refit
# NOTE(review): a standard bootstrap draws len(x) observations per resample; drawing
# fewer makes the coefficient spread wider than for the full dataset - confirm intended.

regression_scores = []
for _ in range(n_bootstrap):
    # Bootstrap sample: random row indices into the dataset, repeats allowed.
    indices = randint(0, len(x), sample_size)
    # Simple linear regression on the resample. Only slope and intercept are read,
    # so the whole-dataset r_value / p_value / std_err are left untouched.
    sample_fit = linregress(x[indices], y[indices])
    regression_scores.append([sample_fit.slope, sample_fit.intercept])

# Column 0 holds the slopes (b1), column 1 the intercepts (b0).
bootstrap_coefficients = np.array(regression_scores)
b0_scores = bootstrap_coefficients[:, 1]
b1_scores = bootstrap_coefficients[:, 0]

# Point estimates: the median of the bootstrap distribution of each coefficient.
b0_estimate = median(b0_scores)
b1_estimate = median(b1_scores)
yhat_bootstrapped = b0_estimate + b1_estimate * x
### MODELS VALIDATION
from sklearn.metrics import mean_squared_error

# mean_squared_error returns the MSE. Take the square root so the values match
# their *_rmse names and are expressed in the same units as y. np.sqrt also
# yields a NumPy scalar, so .round works whether the metric returns a Python
# float or a NumPy float.
yhat_whole_ds_rmse = np.sqrt(mean_squared_error(y, yhat_whole_ds)).round(2)
yhat_bootstrapped_rmse = np.sqrt(mean_squared_error(y, yhat_bootstrapped)).round(2)
### PERCENTILE CONFIDENCE INTERVALS FOR THE COEFFICIENTS
# alpha is the significance level in percent: alpha = 5.0 gives the central 95%
# interval, bounded by the 2.5th and 97.5th percentiles of the bootstrap scores.
alpha = 5.0
lower_p = alpha / 2.0
upper_p = 100 - (alpha / 2.0)


def _report_interval(name, scores, lower_percent, upper_percent):
    """Print the median and percentile interval of *scores*; return the median.

    name: coefficient label used in the printed lines (e.g. 'b0').
    scores: 1-D array of bootstrap estimates of that coefficient.
    lower_percent / upper_percent: percentiles (0-100) bounding the interval.
    """
    estimate = median(scores)
    lower_bound = percentile(scores, lower_percent)
    upper_bound = percentile(scores, upper_percent)
    print('%s 50th percentile (median) = %.3f' % (name, estimate))
    print('confidence interval: ', [lower_bound.round(2), upper_bound.round(2)])
    print('%s estimate: ' % name,
          [lower_bound.round(2), estimate.round(2), upper_bound.round(2)])
    return estimate


b0_estimate = _report_interval('b0', b0_scores, lower_p, upper_p)
b1_estimate = _report_interval('b1', b1_scores, lower_p, upper_p)
# Those confidence intervals are a rough estimate. To capture how the interval varies along the x values, let's proceed as follows:
import pandas as pd

# Tabular view of the bootstrap (intercept, slope) pairs, one row per refit.
slope_intercept_pairs_df = pd.DataFrame({'b0_scores': b0_scores, 'b1_scores': b1_scores})

# For each x value, compute the prediction of every bootstrapped line at once:
# row i holds intercept + slope * x[i] for all (intercept, slope) pairs.
# Broadcasting replaces the per-row DataFrame lookups of a nested Python loop.
x_column = np.asarray(x)[:, np.newaxis]
bootstrap_predictions = b0_scores[np.newaxis, :] + b1_scores[np.newaxis, :] * x_column

# Lower / upper percentile of the predictions at each x (one value per row);
# lower_p and upper_p are the percentiles already used for the coefficient intervals.
low_band, high_band = percentile(bootstrap_predictions, [lower_p, upper_p], axis=1)

# Kept as lists, in the same order as x.
yhat_bootstrapped_predictions_low_values = list(low_band)
yhat_bootstrapped_predictions_high_values = list(high_band)
import plotly.graph_objects as go

# 'lines' traces join points in the order given, and fill='tonexty' builds its
# area from those same paths. x is unsorted random data, so order every series
# by x first to get a single clean curve and a well-formed filled band.
plot_order = np.argsort(x)
x_sorted = np.asarray(x)[plot_order]
band_low_sorted = np.asarray(yhat_bootstrapped_predictions_low_values)[plot_order]
band_high_sorted = np.asarray(yhat_bootstrapped_predictions_high_values)[plot_order]
fit_line_sorted = np.asarray(yhat_whole_ds)[plot_order]

fig = go.Figure()

# Lower edge of the band; nothing is filled beneath it.
fig.add_trace(go.Scatter(
    x=x_sorted,
    y=band_low_sorted,
    fill=None,
    mode='lines',
    line_color='indigo',
))

# Whole-dataset regression line; fills down to the lower edge.
fig.add_trace(go.Scatter(
    x=x_sorted,
    y=fit_line_sorted,
    fill='tonexty',
    mode='lines',
    line_color='red',
))

# Upper edge of the band; fills down to the regression line, completing the
# area between the lower and upper bootstrap prediction bounds.
fig.add_trace(go.Scatter(
    x=x_sorted,
    y=band_high_sorted,
    fill='tonexty',
    mode='lines',
    line_color='indigo',
))

# Caption centred at the top of the plotting area (paper coordinates).
fig.update_layout(annotations=[dict(
    xref='paper',
    yref='paper',
    x=0.5, y=1,
    showarrow=False,
    text='95% confidence interval',
)])
fig.show()
@GermanCM
Copy link
Author

GermanCM commented May 4, 2020

linear_reg_CI_example

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment