Skip to content

Instantly share code, notes, and snippets.

@morkrispil
Last active October 11, 2019 01:51
Show Gist options
  • Save morkrispil/66108f7f424fab1b5d8e1d78c2dd4543 to your computer and use it in GitHub Desktop.
Save morkrispil/66108f7f424fab1b5d8e1d78c2dd4543 to your computer and use it in GitHub Desktop.
Balance the training dataset so that its positive-class rate matches the rate reported for the unseen dataset
import pandas as pd
def balance_train_ds(df_train, unseen_pos_rate, train_y_field):
    """Oversample ``df_train`` until its positive rate matches ``unseen_pos_rate``.

    Rows are only ever added, never dropped: when the target rate is lower
    than the current one, random negative rows (label ``0``) are duplicated;
    when it is higher, random positive rows (label ``1``) are duplicated.
    Duplicates are drawn with replacement, so more rows can be drawn than the
    class currently holds.

    The positive rate is measured over *all* rows of the frame (positives
    divided by the row count), both before and after balancing.

    Args:
        df_train: pandas DataFrame holding the training set.
        unseen_pos_rate: target positive rate, a number in ``[0, 1]``.
        train_y_field: name of the label column; ``1`` marks positives and
            ``0`` marks negatives.

    Returns:
        A DataFrame with the original rows followed by the duplicated ones.
        When the rates are already equal, or the rounded number of rows to
        add is zero, ``df_train`` itself is returned unchanged.

    Raises:
        ValueError: if ``df_train`` is empty, if ``unseen_pos_rate`` is
            outside ``[0, 1]``, or if the target cannot be reached by adding
            rows (a rate of exactly 0 or 1 while the other class is present,
            or no rows of the class that would have to be duplicated).
    """
    total = df_train.shape[0]
    if total == 0:
        raise ValueError('df_train is empty, there is nothing to balance')

    r2 = float(unseen_pos_rate)
    # Written as a negated chained comparison so that NaN is rejected too.
    if not 0.0 <= r2 <= 1.0:
        raise ValueError(
            'unseen_pos_rate must be within [0, 1], got {0}'.format(unseen_pos_rate))

    df_train_pos = df_train[df_train[train_y_field] == 1]
    df_train_neg = df_train[df_train[train_y_field] == 0]
    p = df_train_pos.shape[0]
    n = df_train_neg.shape[0]

    # float() keeps this a true division on Python 2 as well.
    r1 = float(p) / float(total)
    print('train ds pos rate {0}, unseen ds reported pos rate {1}'.format(r1, unseen_pos_rate))

    # pos_rate = r1 = p / total
    # Solve for the number of rows to add so that the rate becomes r2, only
    # ever adding samples (pos or neg) and never losing any.
    # pandas "sample" with "replace=True" allows drawing more rows than the
    # class currently holds, if needed.
    if r2 < r1:
        # Add more neg samples: p / (total + balance) = r2
        if r2 == 0.0:
            raise ValueError(
                'a positive rate of 0 cannot be reached by adding rows '
                'while positive rows exist')
        if n == 0:
            raise ValueError('there are no negative rows to duplicate')
        # round() instead of truncation: float error can leave an exact
        # solution a hair below the integer, which int() alone would cut
        # down by one whole row.
        balance = int(round(p / r2 - total))
        print('duplicating {0} random negatives'.format(balance))
        if balance > 0:
            df_train = pd.concat([df_train, df_train_neg.sample(n=balance, replace=True)])
    elif r2 > r1:
        # Add more pos samples: (p + balance) / (total + balance) = r2
        if r2 == 1.0:
            raise ValueError(
                'a positive rate of 1 cannot be reached by adding rows '
                'while non-positive rows exist')
        if p == 0:
            raise ValueError('there are no positive rows to duplicate')
        balance = int(round((r2 * total - p) / (1.0 - r2)))
        print('duplicating {0} random positives'.format(balance))
        if balance > 0:
            df_train = pd.concat([df_train, df_train_pos.sample(n=balance, replace=True)])

    # Re-check the rate that was actually achieved.
    new_pos = df_train[df_train[train_y_field] == 1].shape[0]
    train_pos_rate = float(new_pos) / float(df_train.shape[0])
    print('train ds re-balanced to {0}'.format(train_pos_rate))
    return df_train
if __name__ == '__main__':
    # Example run: set the positive rate reported for the unseen dataset,
    # then re-balance the training CSV towards it.
    unseen_pos_rate = 0.12
    df_train = balance_train_ds(
        pd.read_csv('train.csv', header=0),
        unseen_pos_rate,
        'is_duplicate',
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment