@mirth
Created October 14, 2017 15:48
import os

import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Product-to-category table; category_id is the label used for stratification.
prod_to_cat = pd.read_csv('levchik_folds/prod_to_category.csv')

# Five shuffled folds with a fixed seed so the split is reproducible.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=4242)
split = skf.split(prod_to_cat.drop('category_id', axis=1), prod_to_cat['category_id'])

# Write each fold's train and test rows to levchik_folds/<i>_train.csv and <i>_test.csv.
for i, (train_index, test_index) in enumerate(split):
    train, test = prod_to_cat.iloc[train_index], prod_to_cat.iloc[test_index]
    train.to_csv(os.path.join('levchik_folds', '%d_train.csv' % i), sep=',', index=False)
    test.to_csv(os.path.join('levchik_folds', '%d_test.csv' % i), sep=',', index=False)
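
A quick way to sanity-check the split is to run the same StratifiedKFold settings on a small synthetic table and count the categories in each test fold. This is only a sketch: the `toy` frame, its `product_id` column, and the 50/30/20 class sizes are hypothetical stand-ins, not the contents of prod_to_category.csv.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Hypothetical stand-in for prod_to_category.csv: 100 products in 3 categories.
toy = pd.DataFrame({
    'product_id': np.arange(100),
    'category_id': np.repeat([0, 1, 2], [50, 30, 20]),
})

# Same splitter settings as the script above.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=4242)

fold_counts = []  # per-fold category counts in the test part
seen_test = []    # every row index that landed in some test fold
for train_index, test_index in skf.split(toy[['product_id']], toy['category_id']):
    test = toy.iloc[test_index]
    fold_counts.append(test['category_id'].value_counts().sort_index().tolist())
    seen_test.extend(test_index.tolist())

# Each entry lists how many rows of category 0, 1, 2 the fold's test part holds.
print(fold_counts)
```

If stratification works as intended, every test fold carries the same category proportions as the full table, and each row appears in exactly one test fold.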