# Simulation of confidence-interval coverage for R-squared
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.linear_model import LinearRegression | |
from scipy.special import expit, logit | |
from itertools import product | |
import pandas as pd | |
import seaborn as sns | |
def make_regression_data(n, alpha, sigma):
    """Simulate a one-feature linear model y = alpha * x + noise.

    The feature is drawn from a standard normal and the noise from
    N(0, sigma**2), both via the global ``np.random`` state (feature first,
    then noise).

    Parameters
    ----------
    n : int
        Number of observations.
    alpha : float
        True slope.
    sigma : float
        Standard deviation of the additive noise.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` has shape ``(n, 1)`` and ``y`` has shape ``(n,)``.
    """
    feature = np.random.normal(size=n)
    noise = np.random.normal(0, sigma, size=n)
    target = alpha * feature + noise
    # Column-vector design matrix, the 2-D shape LinearRegression expects.
    design = feature[:, np.newaxis]
    return design, target
def interval(ytest, ypred, z=1.96):
    """Return a normal-approximation confidence interval for R^2.

    A symmetric interval ``mean +/- z * SE`` is built for the mean squared
    error (SE = sample std of the squared errors, ddof=1, over sqrt(size)).
    It is then mapped to R^2 through ``R^2 = 1 - MSE / Var(ytest)``, where
    ``Var(ytest)`` is the plain (ddof=0) variance. That variance is treated
    as a fixed constant; its own sampling uncertainty is not propagated.

    Parameters
    ----------
    ytest : array-like
        Observed targets.
    ypred : array-like
        Predictions for ``ytest``, same length.
    z : float, default 1.96
        Normal critical value; 1.96 gives a nominal two-sided 95% interval.

    Returns
    -------
    numpy.ndarray
        ``[lower, upper]`` bounds for R^2. The bounds are not clipped, so the
        upper bound can exceed 1 and the lower bound can be negative.
    """
    ytest = np.asarray(ytest, dtype=float)
    ypred = np.asarray(ypred, dtype=float)
    squared_error = (ytest - ypred) ** 2
    mse = squared_error.mean()
    standard_error = squared_error.std(ddof=1) / np.sqrt(squared_error.size)
    # R^2 falls as MSE rises, so the MSE upper bound maps to the R^2 lower
    # bound: list the MSE bounds high-then-low to get R^2 bounds low-then-high.
    mse_bounds = mse + np.array([z, -z]) * standard_error
    return 1 - mse_bounds / np.var(ytest)
def do_fit(n, alpha, sigma):
    """Simulate one dataset, fit OLS, and return the R^2 interval.

    Data come from ``make_regression_data(n, alpha, sigma)``. The fitted
    model's predictions on the same data are passed to ``interval``.

    NOTE(review): predictions are made on the training data, so the interval
    describes in-sample fit; confirm this is the intended target before
    comparing it with the population R^2.

    Returns
    -------
    numpy.ndarray
        ``[lower, upper]`` as returned by ``interval``.
    """
    features, response = make_regression_data(n, alpha, sigma)
    fitted = LinearRegression().fit(features, response)
    return interval(response, fitted.predict(features))
def experiment(R2, n, num_sims=5000):
    """Estimate empirical coverage of the R^2 interval for one (R2, n) setting.

    Simulates ``num_sims`` datasets of size ``n`` whose population R^2 equals
    ``R2``. For each dataset it builds an interval via ``do_fit`` and then
    summarizes the collection.

    Parameters
    ----------
    R2 : float
        Population R^2 of the simulated data; must satisfy 0 <= R2 < 1.
    n : int
        Sample size of each simulated dataset.
    num_sims : int, default 5000
        Number of simulated datasets.

    Returns
    -------
    coverage : float
        Fraction of intervals with ``lower < R2 < upper`` (strict inequalities).
    realistic : float
        Despite its name, the fraction of intervals that extend outside
        [0, 1] (``lower < 0`` or ``upper > 1``), i.e. intervals that contain
        impossible R^2 values.

    Raises
    ------
    ValueError
        If ``R2`` is not in [0, 1).
    """
    if not 0 <= R2 < 1:
        # R2 == 1 would divide by zero and R2 < 0 gives a NaN slope below.
        raise ValueError(f"R2 must be in [0, 1), got {R2!r}")
    sigma = 1
    # The feature is standard normal, so Var(alpha * x) = alpha**2 and the
    # population R^2 is alpha**2 / (alpha**2 + sigma**2); solve for alpha.
    alpha = sigma * np.sqrt(1 / (1 - R2) - 1)
    # reshape keeps a (num_sims, 2) layout even when num_sims == 0.
    confidence_intervals = np.array(
        [do_fit(n, alpha, sigma) for _ in range(num_sims)], dtype=float
    ).reshape(-1, 2)
    lower_limit, upper_limit = confidence_intervals.T
    coverage = np.mean((lower_limit < R2) & (upper_limit > R2))
    realistic = np.mean((lower_limit < 0) | (upper_limit > 1))
    return coverage, realistic
# Grid of population R^2 values (0.01, 0.11, ..., 0.91) crossed with sample sizes.
R2_values = np.arange(0.01, 0.99, 0.1)
sample_sizes = [50, 100, 250, 1000, 10000]
params = list(product(R2_values, sample_sizes))
results = [experiment(*p) for p in params]

df = pd.DataFrame(params, columns=['R2', 'sample_size'])
# Round only for display/pivot labels; experiment() already ran on the raw floats.
df['R2'] = df.R2.round(2)
df['coverage'] = [coverage for coverage, _ in results]
df['realistic'] = [realistic for _, realistic in results]

# DataFrame.pivot takes keyword-only arguments in pandas >= 2.0; the old
# positional form raises TypeError there.
pivoted_coverage = df.pivot(index='R2', columns='sample_size', values='coverage')
pivoted_realistic = df.pivot(index='R2', columns='sample_size', values='realistic')

# Heatmap of coverage, diverging around the nominal 95% level.
fig, ax = plt.subplots(dpi=240)
sns.heatmap(pivoted_coverage, square=True, cmap='RdBu_r', center=0.95, ax=ax)
plt.tight_layout()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment