Skip to content

Instantly share code, notes, and snippets.

@wassname
Created November 29, 2019 08:03
Show Gist options
  • Save wassname/f9c3a9ee00d67e82413979dfe76b096f to your computer and use it in GitHub Desktop.
Save wassname/f9c3a9ee00d67e82413979dfe76b096f to your computer and use it in GitHub Desktop.
unbalance_dask_dataframe.py
%pylab inline
import math

import dask.dataframe as dd
import numpy as np
import pandas as pd
def get_unbal_df(size=100, balance=None):
    """Build a shuffled, single-column pandas frame with two unbalanced classes.

    The frame has one column, ``label``, holding ``size`` rows of one class
    and ``abs(balance) * size`` rows of the other.

    Args:
        size: Number of rows in the smaller class.
        balance: Signed imbalance ratio. A negative value makes ``0`` the
            larger class, a positive value makes ``1`` the larger class.
            An explicit ``0`` yields ``size`` rows that are all ``0``.
            When ``None``, a non-zero ratio is drawn at random from
            ``[-100, 100)``.

    Returns:
        A ``pandas.DataFrame`` with a single ``label`` column in random order.
    """
    if balance is None:
        # A random draw of 0 would produce a frame containing only one class,
        # which cannot be re-balanced afterwards, so redraw until non-zero.
        balance = 0
        while balance == 0:
            balance = np.random.randint(-100, 100)

    # The sign of ``balance`` only selects which label is over-represented.
    majority, minority = (0, 1) if balance < 0 else (1, 0)
    data = [majority] * abs(balance * size) + [minority] * size

    # Shuffle in place so the two classes are interleaved.
    np.random.shuffle(data)
    return pd.DataFrame(data, columns=['label'])
def balance_ddf(df, label_col, balance, classes=(0, 1), npartitions=5):
    """Subsample a two-class dask dataframe to a fixed class ratio.

    Rows are taken from the head of each class so that the result holds
    ``balance`` rows of the first class for every row of the second class.

    Args:
        df: Dask dataframe with two classes in ``label_col``.
        label_col: Name of the label column.
        balance: How many times more rows of the first class to keep than of
            the second class. Must be at least 1.
        classes: ``(first, second)`` label values; defaults to ``(0, 1)``.
        npartitions: Number of partitions of the returned dataframe.

    Returns:
        A lazy dask dataframe containing ``min_len * balance`` rows of the
        first class and ``min_len`` rows of the second class, where
        ``min_len`` is the largest count both classes can supply. It is empty
        when either class cannot supply a single ratio unit.

    Raises:
        ValueError: If ``balance`` is smaller than 1.
    """
    if balance < 1:
        raise ValueError(f"balance must be >= 1, got {balance!r}")

    first, second = classes
    # Boolean masks select each class without going through a list-keyed
    # groupby/get_group lookup.
    a = df[df[label_col] == first]
    b = df[df[label_col] == second]

    # len() of a dask dataframe has to compute the row count.
    la = len(a)
    lb = len(b)

    # Number of complete "ratio units" that both classes can provide.
    min_len = min(la // balance, lb)

    # npartitions=-1 lets head() look across every partition;
    # compute=False keeps the result lazy.
    a = a.head(min_len * balance, compute=False, npartitions=-1)
    b = b.head(min_len, compute=False, npartitions=-1)

    # NOTE(review): dask's sample() appears to shuffle within each partition
    # rather than globally - confirm if a global shuffle is required.
    return dd.concat([a, b]).sample(frac=1).repartition(npartitions=npartitions)
# Demo: build a randomly unbalanced frame, re-balance it 2:1 and check the ratio.
df = get_unbal_df()
ddf = dd.from_pandas(df, npartitions=5)
print('label_mean', df['label'].mean())

target_bal = 2  # keep two rows of class 0 for every row of class 1
target_mean = 1 / (target_bal + 1)  # expected mean of the 0/1 labels afterwards
df2 = balance_ddf(ddf, 'label', target_bal).compute()

# The mean is a float, so compare with a tolerance instead of exact equality
# and report both values if the check fails.
balanced_mean = df2['label'].mean()
assert math.isclose(balanced_mean, target_mean), (balanced_mean, target_mean)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment