Skip to content

Instantly share code, notes, and snippets.

@jnothman
Last active February 22, 2017 23:53
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnothman/6bddbbcca71bdf9fd37e8495d70b42e8 to your computer and use it in GitHub Desktop.
Save jnothman/6bddbbcca71bdf9fd37e8495d70b42e8 to your computer and use it in GitHub Desktop.
Scikit-learn resampling as CV wrapper
import numpy as np
class Resample(object):
    """Cross-validation wrapper that class-balances every training fold.

    The wrapped splitter ``cv`` decides which samples go to train and test.
    For each split, the training indices are then resampled so that every
    class present in the training fold contributes the same number of
    indices. Test indices are passed through untouched.

    Parameters
    ----------
    cv : object
        Splitter exposing ``split(X, y, **kwargs)`` (and, for
        ``get_n_splits``, ``get_n_splits(X, y, groups)``).
    method : {'under', 'over'}, default='under'
        ``'under'`` shrinks every class to the size of the rarest class;
        ``'over'`` grows every class to the size of the most frequent class.
    random_state : None, int or numpy.random.RandomState, default=None
        Source of randomness. ``None`` uses numpy's global random state,
        an int seeds a fresh ``RandomState`` on every ``split`` call.
    """

    def __init__(self, cv, method='under', random_state=None):
        self.cv = cv
        self.method = method
        self.random_state = random_state

    def _rng(self):
        """Return the object used to draw samples (anything with ``choice``)."""
        seed = self.random_state
        if seed is None:
            # Fall back to the module-level global state.
            return np.random
        if isinstance(seed, np.random.RandomState):
            return seed
        return np.random.RandomState(seed)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splits produced by the wrapped splitter."""
        return self.cv.get_n_splits(X, y, groups)

    def split(self, X, y, groups=None, **kwargs):
        """Yield ``(balanced_train_idx, test_idx)`` for each wrapped split.

        Parameters
        ----------
        X : array-like
            Passed through to the wrapped splitter.
        y : array-like
            Class labels; any labels comparable with ``==`` are accepted
            (they need not be the integers ``0..n_classes-1``).
        groups : array-like, optional
            Forwarded to the wrapped splitter as ``groups=`` when given.
            Accepted positionally so callers using the
            ``split(X, y, groups)`` convention work.
        **kwargs
            Extra keyword arguments forwarded to the wrapped splitter.

        Raises
        ------
        ValueError
            If ``method`` is neither ``'under'`` nor ``'over'`` (raised when
            iteration starts, since this is a generator).
        """
        if self.method == 'under':
            pick_target = np.min
        elif self.method == 'over':
            pick_target = np.max
        else:
            raise ValueError(
                "method must be 'under' or 'over', got %r" % (self.method,))

        labels = np.asarray(y)  # allow lists and other sequences
        rng = self._rng()
        if groups is not None:
            # Only forward groups when supplied, so wrapped splitters whose
            # split() has no such parameter keep working.
            kwargs['groups'] = groups

        for train_idx, test_idx in self.cv.split(X, y, **kwargs):
            train_idx = np.asarray(train_idx)
            train_labels = labels[train_idx]
            # np.unique only reports classes that actually occur in the fold,
            # so absent or non-contiguous labels cannot force a zero count.
            classes, counts = np.unique(train_labels, return_counts=True)
            if classes.size == 0:
                # Nothing to balance in an empty training fold.
                yield train_idx, test_idx
                continue

            per_class = int(pick_target(counts))
            sampled = []
            for cls, count in zip(classes, counts):
                members = train_idx[train_labels == cls]
                # Draw without replacement whenever the class is large enough,
                # so no sample is duplicated (or dropped) unnecessarily.
                sampled.append(rng.choice(members, size=per_class,
                                          replace=bool(per_class > count)))
            yield np.concatenate(sampled), test_idx
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold
from resample import Resample
# Draw 1000 labels from {0, 1, 2, 3} with class frequencies of roughly
# 20%, 50%, 20% and 10%: uniform draws are binned at the cumulative edges.
y = np.digitize(np.random.rand(1000), bins=[0.2, 0.7, 0.9, 1.0])
# A single random feature column, one row per label.
X = np.random.rand(y.shape[0], 1)
# for each split, return (train_idx size, train_idx class distribution, test_idx class distribution):
def cv_distrib(cv):
    """Summarise the class balance of every split produced by ``cv``.

    Uses the module-level ``X`` and ``y``. Returns a list with one tuple per
    split: ``(number of training indices, class proportions in the training
    fold, class proportions in the test fold)``.
    """
    def proportions(idx):
        # Per-class counts divided by the fold size.
        return np.bincount(y[idx]) / len(idx)

    summary = []
    for train_idx, test_idx in cv.split(X, y):
        summary.append(
            (len(train_idx), proportions(train_idx), proportions(test_idx)))
    return summary
from pprint import pprint
# Baseline: plain 3-fold splitting. In the recorded run below, the class
# proportions of the train and test folds only roughly agree with each other.
# NOTE: the data is generated without a fixed seed, so exact numbers will
# differ from run to run.
pprint(cv_distrib(KFold(3)))
# Output:
# [(666,
# array([ 0.17717718, 0.51951952, 0.2012012 , 0.1021021 ]),
# array([ 0.19760479, 0.46706587, 0.23353293, 0.10179641])),
# (667,
# array([ 0.1904048 , 0.48575712, 0.21589205, 0.10794603]),
# array([ 0.17117117, 0.53453453, 0.2042042 , 0.09009009])),
# (667,
# array([ 0.1844078 , 0.50074963, 0.21889055, 0.09595202]),
# array([ 0.18318318, 0.5045045 , 0.1981982 , 0.11411411]))]
# Stratified 3-fold splitting. In the recorded run below, train and test
# folds show nearly identical class proportions (about 18/50/21/10 %).
pprint(cv_distrib(StratifiedKFold(3)))
# Output:
# [(665,
# array([ 0.18345865, 0.50225564, 0.21203008, 0.10225564]),
# array([ 0.18507463, 0.50149254, 0.2119403 , 0.10149254])),
# (667,
# array([ 0.1844078 , 0.50224888, 0.2113943 , 0.10194903]),
# array([ 0.18318318, 0.5015015 , 0.21321321, 0.1021021 ])),
# (668,
# array([ 0.18413174, 0.50149701, 0.21257485, 0.10179641]),
# array([ 0.18373494, 0.50301205, 0.21084337, 0.10240964]))]
# Over-sampling on top of StratifiedKFold(3). In the recorded run below, every
# training fold is balanced to 25% per class and grows to four times the
# largest class count, while the test-fold proportions are the same as those
# recorded for plain StratifiedKFold(3).
pprint(cv_distrib(Resample(StratifiedKFold(3), 'over')))
# Output:
# [(1336,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18507463, 0.50149254, 0.2119403 , 0.10149254])),
# (1340,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18318318, 0.5015015 , 0.21321321, 0.1021021 ])),
# (1340,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18373494, 0.50301205, 0.21084337, 0.10240964]))]
# Under-sampling on top of StratifiedKFold(3). In the recorded run below, every
# training fold is balanced to 25% per class and shrinks to four times the
# smallest class count (272 indices), while the test-fold proportions are the
# same as those recorded for plain StratifiedKFold(3).
pprint(cv_distrib(Resample(StratifiedKFold(3), 'under')))
# Output:
# [(272,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18507463, 0.50149254, 0.2119403 , 0.10149254])),
# (272,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18318318, 0.5015015 , 0.21321321, 0.1021021 ])),
# (272,
# array([ 0.25, 0.25, 0.25, 0.25]),
# array([ 0.18373494, 0.50301205, 0.21084337, 0.10240964]))]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment