Skip to content

Instantly share code, notes, and snippets.

@SilvaEmerson
Created May 20, 2019 12:58
Show Gist options
  • Save SilvaEmerson/23b7d54f72fd8f470f1b3bfdd60b6a98 to your computer and use it in GitHub Desktop.
Save SilvaEmerson/23b7d54f72fd8f470f1b3bfdd60b6a98 to your computer and use it in GitHub Desktop.
"""
Stratified K-Fold implementation
"""
def stratified_k_fold(arr, k=None, class_ratio=.5):
    """Split labelled rows into ``k`` folds with a fixed class mix.

    Each row's last element (``row[-1]``) is its class label: truthy labels
    form the "ones" class, falsy labels the "zeros" class.

    Every fold is first filled with ``int(class_ratio * fold_size)`` ones and
    ``fold_size - that`` zeros, where ``fold_size = len(arr) // k``. If one
    class runs out, the fold is topped up from the other class. The
    ``len(arr) % k`` rows left after all ``k`` folds are built are appended
    to the last fold. As a result, every row of ``arr`` lands in exactly one
    fold.

    Rows keep their input order within each class; no shuffling is done here.

    Args:
        arr: Sized sequence of rows (e.g. tuples) whose last item is the label.
        k: Number of folds. ``None`` makes the function return ``None``.
        class_ratio: Target fraction of each fold drawn from the truthy
            class, in ``[0, 1]``. The per-fold count is truncated with
            ``int()``, so the realised fraction can be slightly lower.

    Returns:
        A list of ``k`` lists of rows, or ``None`` when ``k`` is ``None``.

    Raises:
        ValueError: If ``k`` is less than 1 or ``class_ratio`` is outside
            ``[0, 1]``.
    """
    from itertools import islice

    if k is None:
        return None
    if k < 1:
        raise ValueError("k must be a positive integer, got {!r}".format(k))
    if not 0 <= class_ratio <= 1:
        raise ValueError(
            "class_ratio must be in [0, 1], got {!r}".format(class_ratio))

    fold_size = len(arr) // k
    ones_per_fold = int(class_ratio * fold_size)
    zeros_per_fold = fold_size - ones_per_fold

    # Consuming iterators hand out each row once, in order, without the
    # repeated front-deletions of a list.
    ones = iter([row for row in arr if row[-1]])
    zeros = iter([row for row in arr if not row[-1]])

    folds = []
    for _ in range(k):
        fold = [*islice(ones, ones_per_fold), *islice(zeros, zeros_per_fold)]
        # If either class is exhausted, top the fold up from whatever is
        # left. Otherwise the fold would come out short and its share of
        # rows would never be placed anywhere.
        fold += islice(ones, fold_size - len(fold))
        fold += islice(zeros, fold_size - len(fold))
        folds.append(fold)

    # Each fold now holds exactly fold_size rows, so len(arr) % k rows remain
    # (possibly more than fold_size when fold_size < k). They go to the last
    # fold, ones before zeros.
    folds[-1] += [*ones, *zeros]
    return folds
if __name__ == "__main__":
    # Intentionally a no-op when this module is executed directly.
    pass
import unittest
from random import shuffle
from functools import reduce
from collections import Counter
import StratifiedKFold as SKF
class MainTest(unittest.TestCase):
def setUp(self):
self.class_ratio = .5
labels = [*[1] * 11, *[0] * 10]
shuffle(labels)
self.arr = [*zip(range(21), labels)]
def test_should_return_None(self):
self.assertIsNone(SKF.stratified_k_fold(self.arr))
def test_should_return_4_folds(self):
result_len = len(SKF.stratified_k_fold(self.arr, k=4))
self.assertEqual(result_len, 4)
def test_sould_not_return_even_one_empty_fold(self):
result = SKF.stratified_k_fold(self.arr, k=4)
self.assertTrue(all(result))
def test_sould_return_same_amount_of_elements(self):
result = SKF.stratified_k_fold(self.arr, k=4)
total = reduce(lambda acc, curr: acc + len(curr), result, 0)
self.assertEqual(total, len(self.arr))
def test_self_class_ratio_should_be_equal_as_passed(self):
result = SKF.stratified_k_fold(self.arr, k=4, class_ratio=self.class_ratio)
ratios = []
for fold in result:
fold_class_ratio = sum([*map(lambda el: el[-1], fold)]) / len(fold)
ratios.append(round(fold_class_ratio, 1))
most_common_ratio = Counter(ratios).most_common(1)[0][0]
self.assertLessEqual(most_common_ratio, self.class_ratio)
if __name__ == "__main__":
    # Run the TestCase classes defined in this script ('__main__' is the
    # default module unittest.main inspects; spelled out for clarity).
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment