Skip to content

Instantly share code, notes, and snippets.

@miracleyoo
Created November 23, 2018 03:13
Show Gist options
  • Save miracleyoo/fecb3151c712854e66086fa7b84788b5 to your computer and use it in GitHub Desktop.
Save miracleyoo/fecb3151c712854e66086fa7b84788b5 to your computer and use it in GitHub Desktop.
[K-fold train function on Keras] #python #keras
from sklearn.model_selection import KFold
from sklearn.metrics import *
def _kf_default(value, global_name):
    """Return *value*, or the module-level global *global_name* if value is None."""
    if value is not None:
        return value
    try:
        # Looked up at call time (not definition time) so the module can be
        # imported before the data exists and later reassignments are seen.
        return globals()[global_name]
    except KeyError:
        raise ValueError(
            "kf_fit: no data passed and module-level global {!r} is not "
            "defined".format(global_name)) from None


def kf_fit(model, x_train=None, y_train=None, test_data=None, n_splits=10,
           epochs=12, batch_size=256, pred_batch_size=512,
           random_state=42069, verbose=0):
    """Train and evaluate a Keras model with shuffled K-fold cross-validation.

    For every fold the model is reset to the weights it had when this
    function was called, fitted on the training split, scored by ROC AUC on
    the held-out split, and then used to predict ``test_data``.

    Args:
        model: A compiled Keras model (anything exposing ``fit``,
            ``predict``, ``get_weights`` and ``set_weights``).
        x_train: Training features; must support integer-array indexing
            (e.g. a NumPy array). Defaults to the module-level ``X``,
            looked up at call time.
        y_train: Training labels, indexable like ``x_train``. The positive
            class is assumed to be labelled ``1``. Defaults to the
            module-level ``y``.
        test_data: Inputs predicted after each fold. Defaults to the
            module-level ``test``.
        n_splits: Number of folds.
        epochs: Training epochs per fold.
        batch_size: Batch size passed to ``fit``.
        pred_batch_size: Batch size passed to ``predict``.
        random_state: Seed for the fold shuffle.
        verbose: Verbosity passed to ``fit``.

    Returns:
        A ``(test_preds, cv_auc)`` tuple. ``test_preds`` is the list of
        per-fold predictions on ``test_data`` (average them for an ensemble
        prediction); ``cv_auc`` is the mean validation ROC AUC over folds.

    Raises:
        ValueError: If a data argument is omitted and the matching
            module-level global (``X``, ``y`` or ``test``) is not defined.
    """
    x_train = _kf_default(x_train, "X")
    y_train = _kf_default(y_train, "y")
    test_data = _kf_default(test_data, "test")

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Snapshot the starting weights so every fold trains from the same state.
    # Without this, fold k would start from weights already fitted on fold k's
    # validation rows (they were training rows in earlier folds), leaking
    # validation data into training and inflating the reported AUC.
    # NOTE: set_weights restores layer weights only; optimizer state such as
    # momentum/Adam moments is not reset.
    initial_weights = model.get_weights()

    test_preds = []
    fold_aucs = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(x_train), start=1):
        model.set_weights(initial_weights)

        x_train_f, y_train_f = x_train[train_idx], y_train[train_idx]
        x_val_f, y_val_f = x_train[val_idx], y_train[val_idx]

        model.fit(x_train_f, y_train_f,
                  batch_size=batch_size,
                  epochs=epochs,
                  verbose=verbose,
                  validation_data=(x_val_f, y_val_f))

        # Score the held-out fold by ROC AUC (positive class labelled 1).
        preds_val = model.predict(x_val_f, batch_size=pred_batch_size)
        fpr, tpr, _ = roc_curve(y_val_f, preds_val, pos_label=1)
        fold_auc = auc(fpr, tpr)
        fold_aucs.append(fold_auc)
        print('Fold {}, AUC = {}'.format(fold, fold_auc))

        test_preds.append(model.predict(test_data, batch_size=pred_batch_size))

    # Divide by the actual number of folds rather than a hard-coded 10.
    cv_auc = sum(fold_aucs) / len(fold_aucs)
    print("Cross Validation AUC = {}".format(cv_auc))
    return test_preds, cv_auc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment