@SHi-ON
Created March 10, 2019 17:08
An experiment to show how the stratify option works
# Experiment to confirm the effect of the stratify option in scikit-learn's train_test_split() function.
# by Shayan Amani
from sklearn.model_selection import train_test_split
import pandas as pd
raw_data = pd.read_csv("codebase/adrel/dataset/train.csv")
cnt = raw_data.groupby('label').count()
''' experiment begins '''
''' Part One: stratify is ON '''
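# NOTE: `seed` is not defined in the original snippet. Any fixed integer works;
# the value below is a placeholder, so results may differ from the recorded output.
seed = 42
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed, stratify=raw_data['label'])
tr = train.groupby('label').count()
for i in range(9):
    # share of class i's rows that ended up in the training split
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    # assert that every label class keeps about 90% of its rows in the training split
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)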
''' Output:
0.9000484027105518
0.9000853970964987
0.8999281781182668
0.9000229832222477
0.900049115913556
0.8998682476943346
0.8999274836838289
0.9000227221086117
0.9000738370662072
'''
''' Part Two: stratify is OFF'''
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed)
tr = train.groupby('label').count()
for i in range(9):
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
''' Output:
0.9010164569215876
0.8936806148590948
0.8889154895858271
Traceback (most recent call last):
  File "/home/shi-on/anaconda3/envs/PyON36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3267, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-115-5835d3554147>", line 4, in <module>
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
AssertionError: Ratio is not following the rules 2
'''
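
The script above depends on a local train.csv that is not included here. As a rough, self-contained sketch of the same comparison, the block below builds a synthetic nine-class dataset (the class sizes, the `feature` column, and the helper name `train_fraction_per_class` are all assumptions made for illustration) and reports, for each class, the fraction of rows that lands in the training split with and without stratify.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for train.csv: 9 classes with uneven sizes (assumed values).
rng = np.random.default_rng(0)
sizes = [1000, 800, 600, 500, 400, 300, 200, 150, 100]
labels = np.repeat(np.arange(9), sizes)
data = pd.DataFrame({"feature": rng.normal(size=labels.size), "label": labels})


def train_fraction_per_class(frame, stratify):
    """Return, per class, the fraction of rows that lands in the training split."""
    train, _ = train_test_split(
        frame,
        test_size=0.1,
        random_state=0,
        stratify=frame["label"] if stratify else None,
    )
    train_counts = train["label"].value_counts().sort_index()
    total_counts = frame["label"].value_counts().sort_index()
    return train_counts / total_counts


stratified = train_fraction_per_class(data, stratify=True)
unstratified = train_fraction_per_class(data, stratify=False)
print(pd.DataFrame({"stratified": stratified, "unstratified": unstratified}))
```

With stratify on, every class should sit very close to 0.9; without it, the per-class fractions drift, and the drift is usually largest for the smallest classes.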
@tatevkaren5

Hi, I applied your approach to my rating data: "train_data, test_data = train_test_split(rating_data, test_size=test_size, stratify=rating_data['reviewerID'])", but it gives the following error: "ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2."
Is there a way to use the same function to split my rating data into train and test so that 80% of each user's reviews go to the training set and 20% to the test set? Thank you in advance!

@ShojibDE

I am also getting the same error.
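
The ValueError quoted in this thread is raised because stratify treats every distinct reviewerID as a class, and a reviewer with a single review cannot be divided between two sets. One possible workaround, sketched below under assumptions, is to skip train_test_split and sample a fraction of each user's rows with pandas instead. The helper name `split_per_user` and the toy data are hypothetical, and the sketch assumes the DataFrame has a unique index and pandas 1.1 or newer (for `groupby(...).sample`).

```python
import pandas as pd


def split_per_user(ratings, user_col="reviewerID", test_frac=0.2, seed=0):
    """Hold out roughly test_frac of each user's rows; the rest go to train.

    Assumes ratings has a unique index, because the test rows are removed
    from the training set by index label.
    """
    # For each user, sample round(test_frac * n_reviews) rows for the test set.
    test = ratings.groupby(user_col).sample(frac=test_frac, random_state=seed)
    train = ratings.drop(test.index)
    return train, test


# Toy data (hypothetical): user "a" has 10 reviews, "b" has 5, "c" has only 1.
rating_data = pd.DataFrame({
    "reviewerID": ["a"] * 10 + ["b"] * 5 + ["c"],
    "rating": [5, 4, 3, 5, 2, 4, 4, 1, 5, 3, 2, 5, 4, 3, 1, 5],
})
train_data, test_data = split_per_user(rating_data)
print(test_data["reviewerID"].value_counts())
```

Because the per-user sample size is rounded, users with very few reviews contribute no test rows and stay entirely in the training set, which avoids the error at the cost of an exact 80/20 ratio for those users.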
