Skip to content

Instantly share code, notes, and snippets.

@ivallesp
Created April 16, 2023 16:13
Show Gist options
  • Save ivallesp/1137a456af22001f644edfcb21e9110b to your computer and use it in GitHub Desktop.
Save ivallesp/1137a456af22001f644edfcb21e9110b to your computer and use it in GitHub Desktop.
Example of a simple test for covariate shift
import pandas as pd
import numpy as np
import scipy as sp
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from scipy.stats import ttest_1samp
# One seed used wherever randomness enters, so results are repeatable.
# Seeding only NumPy's global generator is not enough: with n_jobs=-1 the
# cross-validation forests may be fitted in worker processes where that
# global seed does not apply, so each estimator gets random_state explicitly.
RANDOM_STATE = 655321
np.random.seed(RANDOM_STATE)  # kept for any code that still uses the global RNG

# Prepare a toy dataset: the California housing train and test CSVs.
x_train = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_train.csv")
x_test = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_test.csv")
# Stack train rows first, then test rows. Both frames come back from read_csv
# with a 0-based RangeIndex, so ignore_index avoids duplicate row labels.
x = pd.concat([x_train, x_test], ignore_index=True)
x = x.drop("median_house_value", axis=1)  # Drop the target; compare covariates only
# is_test label: False for the first len(x_train) rows, True for the rest.
y = np.arange(len(x)) >= len(x_train)
# Interleave the two contiguous train/test blocks created by the concat.
x, y = shuffle(x, y, random_state=RANDOM_STATE)

# Adversarial validation: train a classifier to tell train rows from test rows
# and measure its cross-validated AUC(ROC). An AUC well above 0.5 means the two
# sets are distinguishable, so covariate shift is likely a problem. An AUC close
# to 0.5 means this model found no detectable shift (not proof there is none).
folds_aucs = cross_val_score(
    RandomForestClassifier(random_state=RANDOM_STATE),
    x,
    y,
    scoring="roc_auc",
    cv=10,
    n_jobs=-1,
)
# Two-sided one-sample t-test of the 10 fold AUCs against 0.5. The folds share
# most of their training rows, so the scores are not independent and this
# p-value should be read as a rough guide rather than an exact probability.
p_value = ttest_1samp(folds_aucs, 0.5).pvalue
print(f"AUC(ROC) = {folds_aucs.mean():.02f} ± {folds_aucs.std():.02f}")
print(f"p-value (H0: AUC(ROC) mean is 0.5): {p_value:.02}")

# Determine which variables best separate train from test.
# feature_importances_ is the forest's impurity-based importance. Interpret it
# only when the AUC above is clearly greater than 0.5; otherwise the forest is
# mostly fitting noise and the ranking carries little information.
m = RandomForestClassifier(n_jobs=-1, random_state=RANDOM_STATE).fit(x, y)
df_importances = pd.DataFrame(
    {"variable": x.columns, "importance": m.feature_importances_}
).sort_values(by="importance", ascending=False)
print("Importance of variables:")
print(df_importances.to_string(index=False))
# _________________________________________________
# Example output from one run (figures may vary between runs and library
# versions):
# AUC(ROC) = 0.51 ± 0.02
# p-value (H0: AUC(ROC) mean is 0.5): 0.34
# Importance of variables:
# variable importance
# median_income 0.137577
# population 0.137418
# total_rooms 0.135345
# total_bedrooms 0.126998
# households 0.126550
# longitude 0.122192
# latitude 0.118497
# housing_median_age 0.095424
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment