Created
September 8, 2019 06:39
-
-
Save MattEding/4714c5e1c137f214c6d02f304ff9da65 to your computer and use it in GitHub Desktop.
SMOTE sampling strategy comparison - random vs evenly distributed minority class indices to oversample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import functools | |
import numpy as np | |
import pandas as pd | |
from imblearn.datasets import fetch_datasets | |
import imblearn.datasets._zenodo as zenodo | |
from imblearn.metrics import specificity_score | |
from imblearn.over_sampling import SMOTE | |
from sklearn.exceptions import UndefinedMetricWarning | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.metrics import accuracy_score | |
from sklearn.metrics import precision_score | |
from sklearn.metrics import recall_score | |
from sklearn.model_selection import GridSearchCV | |
from sklearn.model_selection import train_test_split | |
def trial(name, sampling_strategy, k_neighbors, n_jobs):
    """Fit logistic regression on one zenodo dataset and score it.

    Fetches the named imblearn benchmark dataset, holds out a stratified
    test split, optionally oversamples the training split with SMOTE
    (skipped when sampling_strategy == 'none'), then returns the
    (accuracy, precision, recall, specificity) tuple from score().
    """
    bunch = fetch_datasets()[name]
    features, target = bunch.data, bunch.target
    # random_state fixed so every strategy is compared on the same split
    X_tr, X_te, y_tr, y_te = train_test_split(
        features, target, stratify=target, random_state=0)
    if sampling_strategy != 'none':
        sampler = SMOTE(
            sampling_strategy,
            k_neighbors=k_neighbors,
            n_jobs=n_jobs,
            random_state=0,
        )
        X_tr, y_tr = sampler.fit_resample(X_tr, y_tr)
    model = LogisticRegression(solver='liblinear')
    model.fit(X_tr, y_tr)
    return score(y_te, model.predict(X_te))
def score(y_true, y_pred):
    """Return (accuracy, precision, recall, specificity) for binary labels."""
    metrics = (accuracy_score, precision_score, recall_score, specificity_score)
    return tuple(metric(y_true, y_pred) for metric in metrics)
def all_trials(sampling_strategy, k_neighbors, n_jobs):
    """Run trial() over every zenodo benchmark dataset (IDs 1-27).

    Returns a DataFrame indexed by dataset name with one column per
    metric: Accuracy, Precision, Recall, Specificity.

    Note: the fourth value returned by trial() is specificity_score,
    so the column is labeled Specificity (the original 'Sensitivity'
    label was wrong -- sensitivity is the same thing as recall).
    """
    # Build the ID -> name list once; the original rebuilt range(1, 28)
    # a second time just to construct the index.
    names = [zenodo.MAP_ID_NAME[i] for i in range(1, 28)]
    accuracy = []
    precision = []
    recall = []
    specificity = []
    for name in names:
        a, p, r, s = trial(name, sampling_strategy, k_neighbors, n_jobs)
        accuracy.append(a)
        precision.append(p)
        recall.append(r)
        specificity.append(s)
    return pd.DataFrame(
        dict(
            Accuracy=accuracy,
            Precision=precision,
            Recall=recall,
            Specificity=specificity,
        ),
        index=pd.Index(names, name='Name'),
    )
def main():
    """CLI entry point: run one dataset trial (or all) and print/pickle it.

    The dataset argument may be a zenodo ID (1-27), a dataset name, or
    '0'/'all' to run every dataset via all_trials().
    """
    parser = argparse.ArgumentParser(zenodo.__doc__)
    parser.add_argument('dataset', help='zenodo dataset name or ID')
    parser.add_argument('--n_jobs', '-j', type=int, help='n_jobs for SMOTE')
    # Bug fix: the original listed 'not majority' twice and omitted
    # 'not minority', which is a valid SMOTE sampling_strategy.
    choices = ['minority', 'not minority', 'not majority', 'all', 'auto', 'none']
    parser.add_argument('--sampling_strategy', '-s', default='auto',
                        choices=choices, help='sampling_strategy for SMOTE')
    parser.add_argument('--k_neighbors', '-k', default=5, type=int,
                        help='k_neighbors for SMOTE')
    parser.add_argument('--file', '-f', help='file to save results to')
    args = parser.parse_args()

    if args.dataset in ['0', 'all']:
        result = all_trials(args.sampling_strategy, args.k_neighbors, args.n_jobs)
    else:
        try:
            name = zenodo.MAP_ID_NAME[int(args.dataset)]
        except (ValueError, KeyError):
            # Not a numeric ID (or unknown ID): treat the argument as a name.
            name = args.dataset
        a, p, r, s = trial(name, args.sampling_strategy, args.k_neighbors,
                           args.n_jobs)
        # Label fix: the fourth metric is specificity_score, not sensitivity.
        result = pd.Series(
            dict(Accuracy=a, Precision=p, Recall=r, Specificity=s), name=name)

    if args.file is not None:
        # NOTE(review): results are written one directory above the cwd --
        # presumably intentional for this gist's layout; confirm before reuse.
        result.to_pickle(f'../{args.file}.pkl')
    else:
        print(result)
if __name__ == '__main__': | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment