@MattEding
Created September 8, 2019 06:39
SMOTE sampling strategy comparison - random vs evenly distributed minority class indices to oversample
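For context, a minimal sketch (not part of the gist; the toy dataset and class ratio below are made up) of how the string sampling_strategy values exposed by the script behave on a binary problem:

from collections import Counter

from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification

# Toy imbalanced dataset: roughly 900 majority vs 100 minority samples.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
print(Counter(y))

for strategy in ('minority', 'not majority', 'all', 'auto'):
    X_res, y_res = SMOTE(sampling_strategy=strategy, random_state=0).fit_resample(X, y)
    # With only two classes, each of these strategies oversamples the
    # minority class up to the majority count.
    print(strategy, Counter(y_res))

The script below runs a single chosen strategy (or skips resampling with 'none') over the imbalanced-learn zenodo datasets and reports accuracy, precision, recall, and specificity for a logistic regression classifier.
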
import argparse
import functools

import numpy as np
import pandas as pd

from imblearn.datasets import fetch_datasets
import imblearn.datasets._zenodo as zenodo
from imblearn.metrics import specificity_score
from imblearn.over_sampling import SMOTE
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split


def trial(name, sampling_strategy, k_neighbors, n_jobs):
    """Fit logistic regression on one zenodo dataset, optionally SMOTE-resampled."""
    dataset = fetch_datasets()[name]
    X, y = dataset.data, dataset.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0)
    if sampling_strategy != 'none':
        # Oversample only the training split so the test split stays untouched.
        smote = SMOTE(sampling_strategy=sampling_strategy, k_neighbors=k_neighbors,
                      n_jobs=n_jobs, random_state=0)
        X_train, y_train = smote.fit_resample(X_train, y_train)
    logreg = LogisticRegression(solver='liblinear')
    logreg.fit(X_train, y_train)
    y_pred = logreg.predict(X_test)
    return score(y_test, y_pred)


def score(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    specificity = specificity_score(y_true, y_pred)
    return accuracy, precision, recall, specificity


def all_trials(sampling_strategy, k_neighbors, n_jobs):
    """Run trial() over every zenodo dataset and collect the scores."""
    accuracy = []
    precision = []
    recall = []
    specificity = []
    ids = range(1, 28)
    index = pd.Index([zenodo.MAP_ID_NAME[i] for i in ids], name='Name')
    for i in ids:
        name = zenodo.MAP_ID_NAME[i]
        a, p, r, s = trial(
            name, sampling_strategy, k_neighbors, n_jobs)
        accuracy.append(a)
        precision.append(p)
        recall.append(r)
        specificity.append(s)
    # trial() returns specificity as its fourth value, so label the column accordingly.
    df = pd.DataFrame(
        dict(Accuracy=accuracy, Precision=precision, Recall=recall, Specificity=specificity),
        index=index,
    )
    return df


def main():
    parser = argparse.ArgumentParser(description=zenodo.__doc__)
    parser.add_argument('dataset', help='zenodo dataset name or ID')
    parser.add_argument('--n_jobs', '-j', type=int, help='n_jobs for SMOTE')
    # 'none' is handled by trial() and skips resampling entirely.
    choices = ['minority', 'not minority', 'not majority', 'all', 'auto', 'none']
    parser.add_argument('--sampling_strategy', '-s', default='auto', choices=choices,
                        help='sampling_strategy for SMOTE')
    parser.add_argument('--k_neighbors', '-k', default=5, type=int, help='k_neighbors for SMOTE')
    parser.add_argument('--file', '-f', help='file to save results to')
    args = parser.parse_args()
    if args.dataset in ['0', 'all']:
        result = all_trials(args.sampling_strategy, args.k_neighbors, args.n_jobs)
    else:
        try:
            # Accept a numeric zenodo ID; otherwise treat the input as a dataset name.
            name = zenodo.MAP_ID_NAME[int(args.dataset)]
        except (KeyError, ValueError):
            name = args.dataset
        a, p, r, s = trial(name, args.sampling_strategy, args.k_neighbors, args.n_jobs)
        result = pd.Series(dict(Accuracy=a, Precision=p, Recall=r, Specificity=s), name=name)
    if args.file is not None:
        result.to_pickle(f'../{args.file}.pkl')
    else:
        print(result)


if __name__ == '__main__':
    main()
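
For reference, an example invocation (assuming the script is saved as, say, smote_trial.py; the gist itself does not give a filename):

python smote_trial.py ecoli -s minority -k 5
python smote_trial.py all -f results

The positional dataset argument accepts a zenodo dataset name, a numeric ID (1-27), or '0'/'all' to run every dataset; with --file the result is pickled to ../<file>.pkl, otherwise it is printed.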