@rajpurkar
Created August 6, 2023 06:31
Model Comparison Using Bootstrapping
"""
Title: Model Comparison Using Bootstrapping
Description:
This code provides a framework for comparing the performance of two binary classification models using bootstrapping.
It consists of a data simulator class, utility functions, and a main routine that ties everything together.
Key Features:
1. Data Simulation: The DataSimulator class generates simulated model predictions and binary test labels.
Users can control, for each model, the fraction of predictions that are forced to match the true labels.
2. Model Comparison: The code bootstraps the difference in a metric (AUC or accuracy) between the two models and
computes a p-value for the null hypothesis that there is no difference in performance.
3. Visualization: An optional histogram of the bootstrapped differences can be plotted to visualize the comparison.
4. Flexibility: The code is designed to be easily extended and adapted. Users can plug in their own predictions and
labels (a minimal programmatic usage sketch follows the script), customize the metric, or adjust the simulation parameters.
Usage:
The code can be run from the command line with various options to specify the metric, number of samples,
number of resamples, and whether to plot the graph. For example:
python script_name.py --metric auc --num-samples 500 --n-resamples 1000 --plot-graph
Potential Applications:
- Model Selection: Comparing different machine learning models to select the one with better performance.
- Statistical Testing: Evaluating whether the observed differences in model performance are statistically significant.
- Teaching and Learning: A hands-on tool for understanding bootstrapping, hypothesis testing, and model evaluation.
"""
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
import argparse


class DataSimulator:
    def __init__(self, num_samples, model_1_match_fraction=0.2, model_2_match_fraction=0.2, random_seed=None):
        self.num_samples = num_samples
        self.model_1_match_fraction = model_1_match_fraction
        self.model_2_match_fraction = model_2_match_fraction
        self.random_seed = random_seed

    def generate_data(self):
        if self.random_seed is not None:
            np.random.seed(self.random_seed)
        test_labels = np.random.randint(0, 2, size=self.num_samples)
        model_1_predictions = self._generate_model_predictions(test_labels, self.model_1_match_fraction)
        model_2_predictions = self._generate_model_predictions(test_labels, self.model_2_match_fraction)
        return model_1_predictions, model_2_predictions, test_labels

    def _generate_model_predictions(self, test_labels, match_fraction):
        # Random scores in [0, 1), with a chosen fraction overwritten to match the true labels exactly.
        predictions = np.random.rand(self.num_samples)
        match_idx = np.random.choice(self.num_samples, size=int(self.num_samples * match_fraction), replace=False)
        predictions[match_idx] = test_labels[match_idx]
        return predictions


def calculate_metric(metric, model_predictions, test_labels):
    if metric == 'auc':
        return roc_auc_score(test_labels, model_predictions)
    elif metric == 'accuracy':
        binary_predictions = (model_predictions >= 0.5).astype(int)
        return accuracy_score(test_labels, binary_predictions)
    else:
        raise ValueError(f"Invalid metric: {metric}")


def print_model_results(model_name, metric, model_predictions, test_set_labels):
    print(f"{model_name}:")
    print(f"{metric.capitalize()}: {calculate_metric(metric, model_predictions, test_set_labels)}")
    print()


def calculate_bootstrap_difference(metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=1000):
    """
    Calculate the metric differences between two models using bootstrapping.

    Parameters:
        metric (str): The metric to calculate ('auc' or 'accuracy').
        model_1_predictions (numpy.ndarray): Model 1 predictions.
        model_2_predictions (numpy.ndarray): Model 2 predictions.
        test_set_labels (numpy.ndarray): Binary test labels.
        n_resamples (int): Number of resamples in bootstrapping.

    Returns:
        numpy.ndarray: Differences in the metric for each bootstrapped sample, shifted so they are
            centered on zero (the null hypothesis of no difference).
        float: Observed difference in the metric on the original test set.
    """
    model_1_metric = calculate_metric(metric, model_1_predictions, test_set_labels)
    model_2_metric = calculate_metric(metric, model_2_predictions, test_set_labels)
    observed_difference = model_1_metric - model_2_metric

    differences = np.empty(n_resamples)
    n_samples = len(test_set_labels)
    for i in range(n_resamples):
        # Resample the test set with replacement and recompute the metric difference.
        bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
        new_test_set_labels = test_set_labels[bootstrap_indices]
        new_model_1_predictions = model_1_predictions[bootstrap_indices]
        new_model_2_predictions = model_2_predictions[bootstrap_indices]
        model_1_metric = calculate_metric(metric, new_model_1_predictions, new_test_set_labels)
        model_2_metric = calculate_metric(metric, new_model_2_predictions, new_test_set_labels)
        differences[i] = model_1_metric - model_2_metric

    # Center the bootstrap distribution on zero so it approximates the distribution of the
    # difference under the null hypothesis of equal performance.
    differences = differences - observed_difference
    return differences, observed_difference


def calculate_p_value(differences, observed_difference):
    # Two-sided p-value: the fraction of null-centered bootstrap differences at least as
    # extreme as the observed difference.
    return np.sum(np.abs(differences) >= np.abs(observed_difference)) / len(differences)


def interpret_p_value(p_value):
    null_hypothesis = "There is no difference in the performance of the two models."
    alternative_hypothesis = "There is a difference in the performance of the two models."
    if p_value < 0.05:
        print("Reject the null hypothesis in favor of the alternative hypothesis.")
        print(f"{alternative_hypothesis} (p-value = {p_value:e})")
    else:
        print("Fail to reject the null hypothesis.")
        print(f"{null_hypothesis} (p-value = {p_value:e})")


def plot_histogram(differences, observed_difference, metric):
    plt.figure()
    plt.hist(differences, bins='auto')
    plt.axvline(observed_difference, color='r', linestyle='dashed', linewidth=2, label='Observed Difference')
    plt.title(f'Histogram of Bootstrapped Differences ({metric.capitalize()})')
    plt.xlabel(f'Difference in {metric.capitalize()}')
    plt.ylabel('Frequency')
    plt.legend()
    plt.show()


def parse_arguments():
    parser = argparse.ArgumentParser(description="Compare two models using bootstrapping.")
    parser.add_argument('--no-print', dest='print_results', action='store_false', default=True,
                        help="Disable printing of model metrics and p-values.")
    parser.add_argument('--metric', type=str, choices=['auc', 'accuracy'], default='auc',
                        help="Metric to calculate (auc or accuracy).")
    parser.add_argument('--num-samples', type=int, default=500,
                        help="Number of samples in the test set.")
    parser.add_argument('--n-resamples', type=int, default=1000,
                        help="Number of resamples in bootstrapping.")
    parser.add_argument('--plot-graph', action='store_true', default=False,
                        help="Plot the histogram of bootstrapped differences.")
    return parser.parse_args()


def main():
    args = parse_arguments()

    # Create the data simulator and generate simulated predictions and labels.
    simulator = DataSimulator(num_samples=args.num_samples)
    model_1_predictions, model_2_predictions, test_set_labels = simulator.generate_data()

    # Print the metric for each model if requested.
    if args.print_results:
        print_model_results("Model 1", args.metric, model_1_predictions, test_set_labels)
        print_model_results("Model 2", args.metric, model_2_predictions, test_set_labels)

    # Bootstrap the differences in the specified metric and compute the p-value.
    differences, observed_difference = calculate_bootstrap_difference(
        args.metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=args.n_resamples)
    p_value = calculate_p_value(differences, observed_difference)

    # Interpret the p-value.
    if args.print_results:
        print(f"{args.metric.upper()}:")
        interpret_p_value(p_value)

    # Plot the histogram of bootstrapped differences.
    if args.plot_graph:
        plot_histogram(differences, observed_difference, metric=args.metric)


if __name__ == "__main__":
    main()
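
Programmatic usage (a minimal sketch, not part of the gist above): to compare real models instead of the simulated ones, the bootstrap routines can be imported and called directly with your own arrays. The snippet assumes the script has been saved as compare_models.py (a hypothetical filename) and uses randomly generated arrays as stand-ins for real model scores and labels.

import numpy as np
from compare_models import calculate_bootstrap_difference, calculate_p_value, interpret_p_value

rng = np.random.default_rng(0)

# Stand-in data: 500 binary labels and two sets of scores, where model 1 tracks the labels
# somewhat more closely than model 2 (replace with real model outputs and test labels).
labels = rng.integers(0, 2, size=500)
model_1_scores = np.clip(labels + rng.normal(0, 0.6, size=500), 0, 1)
model_2_scores = np.clip(labels + rng.normal(0, 0.8, size=500), 0, 1)

differences, observed = calculate_bootstrap_difference(
    'auc', model_1_scores, model_2_scores, labels, n_resamples=2000)
p_value = calculate_p_value(differences, observed)
interpret_p_value(p_value)

Passing 'accuracy' instead of 'auc' compares the models on thresholded predictions (scores >= 0.5) using the same resampling procedure.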