Created
August 6, 2023 06:31
-
-
Save rajpurkar/f96c131ba3aeffb1927255d4363496a9 to your computer and use it in GitHub Desktop.
Model Comparison Using Bootstrapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Title: Model Comparison Using Bootstrapping | |
Description: | |
This code provides a framework for comparing the performance of two binary classification models using bootstrapping. | |
It consists of a data simulator class, utility functions, and a main routine that combines everything together. | |
Key Features: | |
1. Data Simulation: The DataSimulator class allows for the generation of simulated model predictions and test labels. | |
Users can control the fraction of correct predictions for each model through parameters. | |
2. Model Comparison: The code calculates differences in metrics (AUC or accuracy) between two models using | |
bootstrapping and computes the p-value to test the null hypothesis that there is no difference in the performance. | |
3. Visualization: An optional histogram plot of bootstrapped differences can be generated to visualize the comparison. | |
4. Flexibility: The code is designed to be easily extended and adapted. Users can pass their own predictions and labels | |
to calculate_bootstrap_difference, customize the metrics, or adjust the simulation parameters. | |
Usage: | |
The code can be run from the command line with various options to specify the metric, number of samples, | |
number of resamples, and whether to plot the graph. For example: | |
python script_name.py --metric auc --num-samples 500 --n-resamples 1000 --plot-graph | |
Potential Applications: | |
- Model Selection: Comparing different machine learning models to select the one with better performance. | |
- Statistical Testing: Evaluating whether the observed differences in model performance are statistically significant. | |
- Teaching and Learning: A hands-on tool for understanding bootstrapping, hypothesis testing, and model evaluation. | |
""" | |
import numpy as np | |
from sklearn.metrics import roc_auc_score, accuracy_score | |
import matplotlib.pyplot as plt | |
import argparse | |
class DataSimulator:
    """
    Generate random binary test labels and simulated scores for two models.

    Parameters:
    num_samples (int): Number of samples to simulate.
    model_1_match_fraction (float): Fraction of samples whose model 1 score is overwritten with the true label.
    model_2_match_fraction (float): Fraction of samples whose model 2 score is overwritten with the true label.
    random_seed (int or None): If not None, seeds NumPy's global random state at the start of generate_data().
    """

    def __init__(self, num_samples, model_1_match_fraction=0.2, model_2_match_fraction=0.2, random_seed=None):
        self.num_samples = num_samples
        self.model_1_match_fraction = model_1_match_fraction
        self.model_2_match_fraction = model_2_match_fraction
        self.random_seed = random_seed

    def generate_data(self):
        """
        Draw labels, then scores for model 1, then scores for model 2.

        Returns:
        tuple: (model_1_predictions, model_2_predictions, test_labels).
        """
        seed = self.random_seed
        if seed is not None:
            # Seeds the *global* NumPy generator, so later np.random calls are affected too.
            np.random.seed(seed)

        labels = np.random.randint(0, 2, size=self.num_samples)

        # Model 1 must be simulated before model 2 to keep the random stream order.
        fractions = (self.model_1_match_fraction, self.model_2_match_fraction)
        scores = [self._generate_model_predictions(labels, fraction) for fraction in fractions]
        return scores[0], scores[1], labels

    def _generate_model_predictions(self, test_labels, match_fraction):
        """
        Return uniform random scores in which a random subset equals the true label.

        The subset holds int(num_samples * match_fraction) distinct positions.
        """
        total = self.num_samples
        scores = np.random.rand(total)
        matched_count = int(total * match_fraction)
        matched_positions = np.random.choice(total, size=matched_count, replace=False)
        scores[matched_positions] = test_labels[matched_positions]
        return scores
def calculate_metric(metric, model_predictions, test_labels):
    """
    Score one model's predictions against the test labels.

    Parameters:
    metric (str): 'auc' or 'accuracy'.
    model_predictions (numpy.ndarray): Model scores; for 'accuracy' they are binarised at 0.5.
    test_labels (numpy.ndarray): True labels.

    Returns:
    The value returned by roc_auc_score ('auc') or accuracy_score ('accuracy').

    Raises:
    ValueError: If metric is neither 'auc' nor 'accuracy'.
    """
    if metric == 'auc':
        return roc_auc_score(test_labels, model_predictions)

    if metric == 'accuracy':
        # Scores at or above 0.5 count as class 1.
        hard_predictions = (model_predictions >= 0.5).astype(int)
        return accuracy_score(test_labels, hard_predictions)

    raise ValueError(f"Invalid metric: {metric}")
def print_model_results(model_name, metric, model_predictions, test_set_labels):
    """
    Print the model's name, its metric value, and a trailing blank line.

    Parameters:
    model_name (str): Label printed on the first line.
    metric (str): Metric name passed to calculate_metric.
    model_predictions (numpy.ndarray): Model predictions.
    test_set_labels (numpy.ndarray): True labels.
    """
    # The header is printed before the metric is computed, so it still appears if scoring fails.
    print(f"{model_name}:")
    metric_label = metric.capitalize()
    score = calculate_metric(metric, model_predictions, test_set_labels)
    print(f"{metric_label}: {score}")
    print()
def calculate_bootstrap_difference(metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=1000):
    """
    Calculate the metric differences between two models using bootstrapping.

    Each resample draws len(test_set_labels) indices with replacement and applies
    the same indices to both models and the labels, so the comparison is paired.
    The returned differences are shifted by the observed difference (each one is
    bootstrap difference minus observed difference), which centres them around
    zero for use as a null distribution.

    Parameters:
    metric (str): The metric to calculate ('auc' or 'accuracy').
    model_1_predictions (array-like): Model 1 predictions.
    model_2_predictions (array-like): Model 2 predictions.
    test_set_labels (array-like): Binary test labels.
    n_resamples (int): Number of resamples in bootstrapping.

    Returns:
    numpy.ndarray: Array of centred differences in the metric, one per bootstrapped sample.
    float: Observed difference in the metric on the original test set (model 1 minus model 2).

    Raises:
    ValueError: If the predictions and labels do not all have the same length.
    """
    # Index-array selection below needs ndarrays; this also lets callers pass lists.
    model_1_predictions = np.asarray(model_1_predictions)
    model_2_predictions = np.asarray(model_2_predictions)
    test_set_labels = np.asarray(test_set_labels)

    n_samples = len(test_set_labels)
    if len(model_1_predictions) != n_samples or len(model_2_predictions) != n_samples:
        raise ValueError(
            "model_1_predictions, model_2_predictions and test_set_labels must have the same length "
            f"(got {len(model_1_predictions)}, {len(model_2_predictions)} and {n_samples})"
        )

    observed_difference = (
        calculate_metric(metric, model_1_predictions, test_set_labels)
        - calculate_metric(metric, model_2_predictions, test_set_labels)
    )

    # AUC is undefined for a resample that contains a single class, which is likely
    # on small or imbalanced test sets. Such resamples are redrawn, but only when the
    # full test set has both classes, so the redraw loop is guaranteed to end.
    # NOTE(review): how roc_auc_score reports the single-class case depends on the
    # scikit-learn version - confirm against the installed release.
    redraw_single_class = (
        metric == 'auc'
        and n_samples > 0
        and test_set_labels.min() != test_set_labels.max()
    )

    differences = np.empty(n_resamples)
    for i in range(n_resamples):
        # Passing the int avoids rebuilding an array from range(n_samples) every iteration.
        bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
        new_test_set_labels = test_set_labels[bootstrap_indices]
        while redraw_single_class and new_test_set_labels.min() == new_test_set_labels.max():
            bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
            new_test_set_labels = test_set_labels[bootstrap_indices]

        resampled_metric_1 = calculate_metric(metric, model_1_predictions[bootstrap_indices], new_test_set_labels)
        resampled_metric_2 = calculate_metric(metric, model_2_predictions[bootstrap_indices], new_test_set_labels)
        differences[i] = resampled_metric_1 - resampled_metric_2

    # Centre the bootstrap distribution on zero.
    differences = differences - observed_difference
    return differences, observed_difference
def calculate_p_value(differences, observed_difference):
    """
    Compute a two-sided bootstrap p-value.

    The p-value is the fraction of entries in differences whose absolute value
    is at least as large as the absolute observed difference.

    Parameters:
    differences (array-like): Bootstrapped metric differences.
    observed_difference (float): Difference measured on the original test set.

    Returns:
    float: Fraction of differences at least as extreme as the observed one.

    Raises:
    ZeroDivisionError: If differences is empty.
    """
    # Vectorised count; the builtin sum() would iterate the array element by element.
    extreme_count = np.count_nonzero(np.abs(differences) >= np.abs(observed_difference))
    return extreme_count / len(differences)
def interpret_p_value(p_value, alpha=0.05):
    """
    Print the outcome of the hypothesis test for the given p-value.

    Parameters:
    p_value (float): P-value to interpret.
    alpha (float): Significance level; the null hypothesis is rejected when p_value < alpha.
        Defaults to 0.05.

    Raises:
    ValueError: If alpha is not strictly between 0 and 1.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")

    null_hypothesis = "There is no difference in the performance of the two models."
    alternative_hypothesis = "There is a difference in the performance of the two models."

    if p_value < alpha:
        print("Reject the null hypothesis in favor of the alternative hypothesis.")
        print(f"{alternative_hypothesis} (p-value = {p_value:e})")
    else:
        print("Fail to reject the null hypothesis.")
        print(f"{null_hypothesis} (p-value = {p_value:e})")
def plot_histogram(differences, observed_difference, metric):
    """
    Show a histogram of differences with the observed difference marked.

    Parameters:
    differences (array-like): Bootstrapped metric differences to bin.
    observed_difference (float): Position of the dashed red vertical line.
    metric (str): Metric name, capitalised for the title and x-axis label.
    """
    _, axes = plt.subplots()
    axes.hist(differences, bins='auto')
    axes.axvline(observed_difference, color='r', linestyle='dashed', linewidth=2, label='Observed Difference')

    metric_label = metric.capitalize()
    axes.set_title(f'Histogram of Bootstrapped Differences ({metric_label})')
    axes.set_xlabel(f'Difference in {metric_label}')
    axes.set_ylabel('Frequency')
    axes.legend()

    plt.show()
def _positive_int(value):
    """argparse type: parse value as an integer and require it to be at least 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_arguments(argv=None):
    """
    Parse the command-line options of the comparison script.

    Parameters:
    argv (list of str or None): Arguments to parse. None (the default) makes argparse
        read sys.argv[1:].

    Returns:
    argparse.Namespace: With attributes print_results, metric, num_samples, n_resamples and plot_graph.

    argparse exits the process (SystemExit) on invalid options, including
    non-positive values for --num-samples or --n-resamples.
    """
    parser = argparse.ArgumentParser(description="Compare two models using bootstrapping.")
    parser.add_argument('--no-print', dest='print_results', action='store_false', default=True, help="Disable printing of model metrics and p-values.")
    parser.add_argument('--metric', type=str, choices=['auc', 'accuracy'], default='auc', help="Metric to calculate (auc or accuracy).")
    # Zero or negative counts only fail later with an obscure error, so reject them here.
    parser.add_argument('--num-samples', type=_positive_int, default=500, help="Number of samples in the test set.")
    parser.add_argument('--n-resamples', type=_positive_int, default=1000, help="Number of resamples in bootstrapping.")
    parser.add_argument('--plot-graph', action='store_true', default=False, help="Plot the histogram of bootstrapped differences.")
    return parser.parse_args(argv)
def main():
    """
    Run the full comparison on simulated data.

    Parses the command-line options, simulates predictions for two models,
    bootstraps the metric difference, computes the p-value, and then prints
    and/or plots the results according to the options.
    """
    options = parse_arguments()
    metric = options.metric

    # Simulated scores for both models plus the labels they are scored against.
    predictions_1, predictions_2, labels = DataSimulator(num_samples=options.num_samples).generate_data()

    # Per-model metric values, unless --no-print was given.
    if options.print_results:
        for model_name, predictions in (("Model 1", predictions_1), ("Model 2", predictions_2)):
            print_model_results(model_name, metric, predictions, labels)

    # Bootstrapped differences and the resulting p-value for the chosen metric.
    bootstrap_differences, observed = calculate_bootstrap_difference(
        metric, predictions_1, predictions_2, labels, n_resamples=options.n_resamples
    )
    p_value = calculate_p_value(bootstrap_differences, observed)

    if options.print_results:
        print(f"{metric.upper()}:")
        interpret_p_value(p_value)

    # Optional histogram of the bootstrapped differences.
    if options.plot_graph:
        plot_histogram(bootstrap_differences, observed, metric=metric)


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment