Created
April 22, 2013 18:21
-
-
Save bmcfee/5437267 to your computer and use it in GitHub Desktop.
This function is a wrapper for sklearn.cross_validation.Stratified*. Sometimes you have data where multiple samples are related to the same object (e.g., several audio clips from the same song), and blindly partitioning the data into train/test without accounting for this can bias your estimator. This wrapper allows you to specify a "meta id" for each sample, so that all samples sharing a meta id end up on the same side of every split.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def filtered_stratified_split(ids, splitter, Y, **kwargs): | |
'''Cross-validation split filtration. Ensures that points of the same meta-id end up on the same side of the split | |
input: | |
ids: n-by-1 mapping of data points to meta-id | |
splitter: handle to the cross-validation class (eg, StratifiedShuffleSplit) | |
Y: n-by-1 vector of class labels | |
**kwargs: arguments to the cross-validation class | |
yields: | |
(train, test) indices | |
''' | |
n = len(Y) | |
indices = ('indices' in kwargs) and (kwargs['indices']) | |
kwargs['indices'] = True | |
def unfold(meta_ids, X_id, indices): | |
split_ids = [] | |
for i in meta_ids: | |
split_ids.extend(X_id[i]) | |
split_ids = numpy.array(split_ids) | |
if not indices: | |
z = numpy.zeros(n, dtype=bool) | |
z[split_ids] = True | |
return z | |
# 1: make a new label vector Yid | |
X_id = [] | |
Y_id = [] | |
last_id = None | |
for i in xrange(len(ids)): | |
if i > 0 and last_id == ids[i]: | |
X_id[-1].append(i) | |
else: | |
last_id = ids[i] | |
X_id.append([i]) | |
Y_id.append(Y[i]) | |
# 2: CV split on Yid | |
splits = splitter(Y_id, **kwargs) | |
# 3: Map CV indices back to Y space | |
for meta_train, meta_test in splits: | |
yield (unfold(meta_train, X_id, indices), | |
unfold(meta_test, X_id, indices)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment