Skip to content

Instantly share code, notes, and snippets.

@raven4752
Last active June 11, 2018 11:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raven4752/ee9ca307e3e6e96de36957332b6e5a06 to your computer and use it in GitHub Desktop.
Save raven4752/ee9ca307e3e6e96de36957332b6e5a06 to your computer and use it in GitHub Desktop.
Cross-validating a model with multiple inputs/outputs
def multi_array_shuffle(*arrays, random_state=1):
    """Shuffle several arrays along axis 0 with one shared permutation.

    Rows that are aligned across the input arrays (e.g. a feature row and its
    target) stay aligned after shuffling.

    :param arrays: arrays supporting NumPy integer-array indexing, all with the
        same length along axis 0
    :param random_state: seed of the permutation; the same seed and length
        always produce the same permutation
    :return: list of shuffled copies, in the same order as ``arrays``
        (an empty list when no arrays are given)
    :raises ValueError: if the arrays differ in length along axis 0
    """
    if not arrays:
        return []
    array_length = arrays[0].shape[0]
    for array in arrays[1:]:
        # Without this check a longer array would be silently truncated to
        # array_length rows, and a shorter one would fail with an IndexError.
        if array.shape[0] != array_length:
            raise ValueError(
                "all arrays must have the same length along axis 0, "
                "got {} and {}".format(array.shape[0], array_length)
            )
    # A private RandomState yields the same permutation as the former
    # np.random.seed() + np.random.permutation() pair, but does not reseed
    # NumPy's global generator for the rest of the program.
    permutation = np.random.RandomState(random_state).permutation(array_length)
    return [array[permutation, ...] for array in arrays]
def cv_model_func(model_func, inputs, targets, scores_func, label=None, seed=1, fold=5, **kwargs):
    """
    Cross-validate a model that takes multiple inputs and/or multiple targets.

    All ``inputs`` and ``targets`` (and ``label``, when given as an array) are
    shuffled with the same seeded permutation. The shuffled data is then split
    into contiguous folds, stratified on ``label`` when one is given.

    :param model_func: callable ``model_func(inputs_train, targets_train, **kwargs)``
        that trains and returns a fitted model exposing ``predict(inputs)``
    :param inputs: list of feature arrays, aligned along axis 0
    :param targets: list of target arrays, aligned along axis 0
    :param scores_func: callable ``scores_func(targets_test, predictions)``;
        its result is recorded for every fold
    :param label: labels used to stratify the split. ``None`` gives a plain
        K-fold split. An array must be aligned with the *unshuffled* samples.
        A callable receives the shuffled ``targets`` list and returns the labels
    :param seed: seed of the shuffle that precedes the split
    :param fold: number of folds (>= 2). ``1`` performs a single train/test
        split: the first fold of a 5-fold partition (80% train / 20% test)
    :param kwargs: forwarded to ``model_func``. If it contains ``predict_param``
        (a dict), that dict is also unpacked into ``model.predict``
    :return: ``np.array`` of the per-fold ``scores_func`` results
    """
    num_sample = targets[0].shape[0]
    inputs = multi_array_shuffle(*inputs, random_state=seed)
    targets = multi_array_shuffle(*targets, random_state=seed)
    if callable(label):
        # The callable sees the already-shuffled targets, so its result is aligned.
        label = label(targets)
    elif label is not None:
        # An explicit label array must get the same permutation as inputs and
        # targets; otherwise stratification would use other samples' labels.
        label = multi_array_shuffle(np.asarray(label), random_state=seed)[0]

    # The data is already shuffled above, so the splitter runs with
    # shuffle=False. random_state is deliberately not passed: it has no effect
    # without shuffling, and recent scikit-learn versions reject that combination.
    n_splits = 5 if fold == 1 else fold
    if label is None:
        splitter = KFold(n_splits=n_splits, shuffle=False)
    else:
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=False)
    holder = np.zeros([num_sample, 1])  # split() only needs the sample count
    splits = splitter.split(holder, label)
    if fold == 1:
        # Single hold-out split, as documented for fold=1.
        splits = [next(iter(splits))]

    predict_kwargs = kwargs.get('predict_param', {})
    scores = []
    for train_index, test_index in splits:
        input_tr = [feature[train_index, ...] for feature in inputs]
        input_te = [feature[test_index, ...] for feature in inputs]
        target_tr = [target[train_index, ...] for target in targets]
        target_te = [target[test_index, ...] for target in targets]
        model = model_func(input_tr, target_tr, **kwargs)
        predictions = model.predict(input_te, **predict_kwargs)
        # Drop the reference so the fitted model can be garbage-collected
        # before the next fold trains a new one.
        del model
        scores.append(scores_func(target_te, predictions))
    return np.array(scores)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment