Last active
October 11, 2019 01:51
-
-
Save morkrispil/66108f7f424fab1b5d8e1d78c2dd4543 to your computer and use it in GitHub Desktop.
Balance the training dataset so that its positive-class rate matches the positive-class rate reported for the unseen dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
def balance_train_ds(df_train, unseen_pos_rate, train_y_field):
    """Oversample one class so the training positive rate matches a target.

    Rows are only ever added, by sampling with replacement from the existing
    positive or negative rows; no original row is dropped.

    Args:
        df_train: pandas DataFrame holding the training set.
        unseen_pos_rate: target positive rate, a number within [0, 1].
        train_y_field: name of the label column. Rows equal to 1 are treated
            as positives, rows equal to 0 as negatives.

    Returns:
        ``df_train`` itself when it already has the target rate; otherwise a
        new DataFrame made of the original rows followed by the duplicated
        rows (the index is not reset, so index values may repeat).

    Raises:
        ValueError: if ``df_train`` is empty, if ``unseen_pos_rate`` is
            outside [0, 1], or if the target cannot be reached by adding
            rows (target of exactly 0 or 1 that is not already met, or the
            class that must be duplicated has no rows to sample from).
    """
    total = df_train.shape[0]
    if total == 0:
        raise ValueError('df_train is empty; cannot compute a positive rate')
    # Written as a negated chained comparison so NaN is rejected as well.
    if not 0.0 <= unseen_pos_rate <= 1.0:
        raise ValueError(
            'unseen_pos_rate must be within [0, 1], got {0}'.format(
                unseen_pos_rate))

    df_train_pos = df_train[df_train[train_y_field] == 1]
    df_train_neg = df_train[df_train[train_y_field] == 0]
    p = df_train_pos.shape[0]
    n = df_train_neg.shape[0]
    train_pos_rate = float(p) / float(total)
    print('train ds pos rate {0}, unseen ds reported pos rate {1}'.format(
        train_pos_rate, unseen_pos_rate))

    # pos_rate = r1 = p / total
    # Solve for the number of rows to add so the rate becomes r2, only ever
    # adding samples (pos or neg), never losing any. pandas "sample" with
    # "replace=True" allows drawing more rows than the class currently has.
    # The total row count is used as the denominator so the result agrees
    # with the rate reported above and re-checked below.
    r1 = train_pos_rate
    r2 = unseen_pos_rate

    if r2 < r1:
        # Add more negatives: p / (total + balance) = r2
        if r2 == 0:
            raise ValueError(
                'a positive rate of 0 cannot be reached by adding negatives')
        if n == 0:
            raise ValueError('no negative rows available to duplicate')
        # round() rather than truncation: float round-off (e.g. 9.999...)
        # must not silently drop a row.
        balance = int(round((p - (r2 * total)) / r2))
        print('duplicating {0} random negatives'.format(balance))
        df_train = pd.concat(
            [df_train, df_train_neg.sample(n=balance, replace=True)])
    elif r2 > r1:
        # Add more positives: (p + balance) / (total + balance) = r2
        if r2 == 1:
            raise ValueError(
                'a positive rate of 1 cannot be reached by adding positives')
        if p == 0:
            raise ValueError('no positive rows available to duplicate')
        balance = int(round(((r2 * total) - p) / (1 - r2)))
        print('duplicating {0} random positives'.format(balance))
        df_train = pd.concat(
            [df_train, df_train_pos.sample(n=balance, replace=True)])

    # Re-check the achieved rate on the (possibly) enlarged frame.
    df_train_pos = df_train[df_train[train_y_field] == 1]
    train_pos_rate = float(df_train_pos.shape[0]) / float(df_train.shape[0])
    print('train ds re-balanced to {0}'.format(train_pos_rate))
    return df_train
if __name__ == '__main__':
    # Positive-class rate reported for the unseen dataset; adjust and re-run.
    unseen_pos_rate = 0.12
    # Load the training CSV (first row is the header) and re-balance it on
    # the 'is_duplicate' label column in one step.
    df_train = balance_train_ds(
        pd.read_csv('train.csv', header=0),
        unseen_pos_rate,
        'is_duplicate',
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment