Skip to content

Instantly share code, notes, and snippets.

Created August 6, 2023 06:31
Show Gist options
  • Save rajpurkar/f96c131ba3aeffb1927255d4363496a9 to your computer and use it in GitHub Desktop.
Save rajpurkar/f96c131ba3aeffb1927255d4363496a9 to your computer and use it in GitHub Desktop.
Model Comparison Using Bootstrapping
Title: Model Comparison Using Bootstrapping
This code provides a framework for comparing the performance of two binary classification models using bootstrapping.
It consists of a data simulator class, utility functions, and a main routine that combines everything together.
Key Features:
1. Data Simulation: The DataSimulator class allows for the generation of simulated model predictions and test labels.
Users can control the fraction of correct predictions for each model through parameters.
2. Model Comparison: The code calculates differences in metrics (AUC or accuracy) between two models using
bootstrapping and computes the p-value to test the null hypothesis that there is no difference in the performance.
3. Visualization: An optional histogram plot of bootstrapped differences can be generated to visualize the comparison.
4. Flexibility: The code is designed to be easily extended and adapted. Users can provide their own data to the main
function, customize the metrics, or adjust the simulation parameters.
The code can be run from the command line with various options to specify the metric, number of samples,
number of resamples, and whether to plot the graph. For example:
python <script_name>.py --metric auc --num-samples 500 --n-resamples 1000 --plot-graph
Potential Applications:
- Model Selection: Comparing different machine learning models to select the one with better performance.
- Statistical Testing: Evaluating whether the observed differences in model performance are statistically significant.
- Teaching and Learning: A hands-on tool for understanding bootstrapping, hypothesis testing, and model evaluation.
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
import argparse
class DataSimulator:
    """Simulate binary test labels and prediction scores for two models.

    Each model's scores start as uniform random values in [0, 1). A randomly
    chosen ``match_fraction`` of them is then overwritten with the true label,
    so a larger fraction produces a better-performing simulated model.
    """

    def __init__(self, num_samples, model_1_match_fraction=0.2, model_2_match_fraction=0.2, random_seed=None):
        """Store the simulation settings.

        Args:
            num_samples (int): Number of test examples to simulate.
            model_1_match_fraction (float): Fraction of model 1 scores that are
                set equal to the true label.
            model_2_match_fraction (float): Same, for model 2.
            random_seed (int | None): Seed applied to NumPy's global random
                generator by ``generate_data``; ``None`` leaves it untouched.
        """
        self.num_samples = num_samples
        self.model_1_match_fraction = model_1_match_fraction
        self.model_2_match_fraction = model_2_match_fraction
        self.random_seed = random_seed

    def generate_data(self):
        """Generate one simulated data set.

        Returns:
            tuple: ``(model_1_predictions, model_2_predictions, test_labels)``,
            three arrays of length ``num_samples``.
        """
        if self.random_seed is not None:
            # Re-seed on every call so the same seed always reproduces the same data.
            np.random.seed(self.random_seed)
        test_labels = np.random.randint(0, 2, size=self.num_samples)
        model_1_predictions = self._generate_model_predictions(test_labels, self.model_1_match_fraction)
        model_2_predictions = self._generate_model_predictions(test_labels, self.model_2_match_fraction)
        return model_1_predictions, model_2_predictions, test_labels

    def _generate_model_predictions(self, test_labels, match_fraction):
        """Return random scores with ``match_fraction`` of them replaced by the label."""
        predictions = np.random.rand(self.num_samples)
        # Sample without replacement so exactly int(n * fraction) distinct positions match.
        match_idx = np.random.choice(self.num_samples, size=int(self.num_samples * match_fraction), replace=False)
        predictions[match_idx] = test_labels[match_idx]
        return predictions
def calculate_metric(metric, model_predictions, test_labels):
    """Score a model's predictions against the labels.

    Args:
        metric (str): Either 'auc' or 'accuracy'.
        model_predictions (numpy.ndarray): Prediction scores for each example.
        test_labels (numpy.ndarray): Binary test labels.

    Returns:
        The value returned by ``roc_auc_score`` or ``accuracy_score``.

    Raises:
        ValueError: If ``metric`` is neither 'auc' nor 'accuracy'.
    """
    if metric == 'auc':
        score = roc_auc_score(test_labels, model_predictions)
    elif metric == 'accuracy':
        # Accuracy needs hard labels: scores of 0.5 and above count as class 1.
        hard_labels = (model_predictions >= 0.5).astype(int)
        score = accuracy_score(test_labels, hard_labels)
    else:
        raise ValueError(f"Invalid metric: {metric}")
    return score
def print_model_results(model_name, metric, model_predictions, test_set_labels):
    """Print one model's score on the test set, labelled with the model's name.

    Args:
        model_name (str): Label that identifies the model in the output line.
        metric (str): Metric name passed through to ``calculate_metric``.
        model_predictions (numpy.ndarray): The model's predictions.
        test_set_labels (numpy.ndarray): Binary test labels.
    """
    score = calculate_metric(metric, model_predictions, test_set_labels)
    # Include the model name so the lines printed for different models can be told apart.
    print(f"{model_name} {metric.capitalize()}: {score}")
def calculate_bootstrap_difference(metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=1000):
    """
    Calculate the metric differences between two models using bootstrapping.

    Args:
        metric (str): The metric to calculate ('auc' or 'accuracy').
        model_1_predictions (numpy.ndarray): Model 1 predictions.
        model_2_predictions (numpy.ndarray): Model 2 predictions.
        test_set_labels (numpy.ndarray): Binary test labels.
        n_resamples (int): Number of resamples in bootstrapping.

    Returns:
        numpy.ndarray: Array of differences in the metric for each bootstrapped
            sample, shifted by the observed difference so the distribution is
            centred the way it would be if the two models performed equally.
        float: Observed difference in the metric on the original test set.
    """
    # Accept any array-like input; fancy indexing below requires ndarrays.
    model_1_predictions = np.asarray(model_1_predictions)
    model_2_predictions = np.asarray(model_2_predictions)
    test_set_labels = np.asarray(test_set_labels)

    observed_difference = (
        calculate_metric(metric, model_1_predictions, test_set_labels)
        - calculate_metric(metric, model_2_predictions, test_set_labels)
    )

    differences = np.empty(n_resamples)
    n_samples = len(test_set_labels)
    # AUC is undefined for a resample that contains a single class. Such draws
    # are redrawn, but only when the full test set has both classes (otherwise
    # redrawing could never succeed).
    redraw_single_class = metric == 'auc' and np.unique(test_set_labels).size > 1

    for i in range(n_resamples):
        # The same indices are used for both models so the comparison stays paired.
        bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
        new_test_set_labels = test_set_labels[bootstrap_indices]
        while redraw_single_class and np.unique(new_test_set_labels).size < 2:
            bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
            new_test_set_labels = test_set_labels[bootstrap_indices]

        resampled_metric_1 = calculate_metric(metric, model_1_predictions[bootstrap_indices], new_test_set_labels)
        resampled_metric_2 = calculate_metric(metric, model_2_predictions[bootstrap_indices], new_test_set_labels)
        differences[i] = resampled_metric_1 - resampled_metric_2

    # Centre the bootstrap distribution on zero before it is compared with the
    # observed difference.
    differences = differences - observed_difference
    return differences, observed_difference
def calculate_p_value(differences, observed_difference):
    """Return the two-sided p-value from bootstrapped differences.

    Args:
        differences (array-like): Bootstrapped metric differences.
        observed_difference (float): Difference observed on the original test set.

    Returns:
        float: Fraction of ``differences`` whose absolute value is at least
        ``abs(observed_difference)``.

    Raises:
        ZeroDivisionError: If ``differences`` is empty.
    """
    at_least_as_extreme = np.abs(differences) >= np.abs(observed_difference)
    # Count in NumPy rather than with the builtin sum, which would loop over
    # the boolean array one element at a time in Python.
    return np.count_nonzero(at_least_as_extreme) / len(differences)
def interpret_p_value(p_value, alpha=0.05):
    """Print the outcome of the hypothesis test for a given p-value.

    Args:
        p_value (float): P-value of the test.
        alpha (float): Significance level; the null hypothesis is rejected
            when ``p_value < alpha``.
    """
    null_hypothesis = "There is no difference in the performance of the two models."
    alternative_hypothesis = "There is a difference in the performance of the two models."
    # Exactly one of the two conclusions is reported.
    if p_value < alpha:
        print("Reject the null hypothesis in favor of the alternative hypothesis.")
        print(f"{alternative_hypothesis} (p-value = {p_value:e})")
    else:
        print("Fail to reject the null hypothesis.")
        print(f"{null_hypothesis} (p-value = {p_value:e})")
def plot_histogram(differences, observed_difference, metric):
    """Show a histogram of the bootstrapped differences with the observed value marked.

    Args:
        differences (numpy.ndarray): Bootstrapped metric differences.
        observed_difference (float): Difference observed on the original test
            set, drawn as a dashed vertical line.
        metric (str): Metric name used in the title and x-axis label.
    """
    plt.hist(differences, bins='auto')
    plt.axvline(observed_difference, color='r', linestyle='dashed', linewidth=2, label='Observed Difference')
    plt.title(f'Histogram of Bootstrapped Differences ({metric.capitalize()})')
    plt.xlabel(f'Difference in {metric.capitalize()}')
    plt.ylabel('Frequency')
    # The label given to axvline only becomes visible once a legend is drawn.
    plt.legend()
    plt.show()
def _positive_int(text):
    """argparse ``type`` callable: parse ``text`` as an integer that is at least 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def parse_arguments(argv=None):
    """Parse the command-line options for the model comparison.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Has the attributes ``print_results``, ``metric``,
        ``num_samples``, ``n_resamples`` and ``plot_graph``.
    """
    parser = argparse.ArgumentParser(description="Compare two models using bootstrapping.")
    parser.add_argument('--no-print', dest='print_results', action='store_false', default=True, help="Disable printing of model metrics and p-values.")
    parser.add_argument('--metric', type=str, choices=['auc', 'accuracy'], default='auc', help="Metric to calculate (auc or accuracy).")
    # Sample and resample counts must be positive: zero resamples would make
    # the p-value computation divide by zero.
    parser.add_argument('--num-samples', type=_positive_int, default=500, help="Number of samples in the test set.")
    parser.add_argument('--n-resamples', type=_positive_int, default=1000, help="Number of resamples in bootstrapping.")
    parser.add_argument('--plot-graph', action='store_true', default=False, help="Plot the histogram of bootstrapped differences.")
    return parser.parse_args(argv)
def main():
    """Run the comparison on simulated data: simulate, score, bootstrap, report, plot."""
    args = parse_arguments()

    # Create data simulator
    simulator = DataSimulator(num_samples=args.num_samples)

    # Generate simulated data
    model_1_predictions, model_2_predictions, test_set_labels = simulator.generate_data()

    # Print results for each model if specified
    if args.print_results:
        print_model_results("Model 1", args.metric, model_1_predictions, test_set_labels)
        print_model_results("Model 2", args.metric, model_2_predictions, test_set_labels)

    # Calculate bootstrap differences for the specified metric
    differences, observed_difference = calculate_bootstrap_difference(
        args.metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=args.n_resamples
    )

    # Calculate the p-value for the specified metric
    p_value = calculate_p_value(differences, observed_difference)

    # Interpret the p-value for the specified metric
    if args.print_results:
        interpret_p_value(p_value)

    # Plot histogram of the bootstrapped differences for the specified metric.
    # This is independent of --no-print, which only silences text output.
    if args.plot_graph:
        plot_histogram(differences, observed_difference, metric=args.metric)
# Run the comparison only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment