Skip to content

Instantly share code, notes, and snippets.

@MovsisyanM
Created May 18, 2022 12:10
Show Gist options
  • Save MovsisyanM/618a3dcabfb8b3c7af9a9f4f5e8b93eb to your computer and use it in GitHub Desktop.
Save MovsisyanM/618a3dcabfb8b3c7af9a9f4f5e8b93eb to your computer and use it in GitHub Desktop.
Balanced generator for 0, 1 binary classification problems
def equigen(
        x: pd.DataFrame,
        y: pd.Series,
        batch_size: int = 256,
        seed: int = seed,
        preproc: callable = None,
        preproc_kwargs: dict = None):
    """Yield an endless stream of class-balanced batches for 0/1 labels.

    Each batch holds ``int(batch_size / 2)`` rows drawn without replacement
    from the rows of ``x`` whose label is 0 and the same number from the rows
    whose label is 1, shuffled together.  Successive batches differ from one
    another, but the whole sequence is reproducible for a given ``seed``.

    Args:
        x: Feature frame; rows are aligned positionally with ``y``.
        y: Labels; only rows labelled exactly 0 or 1 are ever sampled.
        batch_size: Target batch size.  An odd value yields
            ``batch_size - 1`` rows because each class contributes
            ``int(batch_size / 2)``.
        seed: Seed for the sampling sequence.  The default is read from a
            module-level ``seed`` when the function is defined.
        preproc: Optional callable invoked as
            ``preproc(sample, **preproc_kwargs)``.  ``sample`` still carries
            the temporary ``"_label"`` column, and the callable must return
            a ``(sample, labels)`` pair.
        preproc_kwargs: Extra keyword arguments forwarded to ``preproc``.

    Yields:
        ``(sample, labels)``.  Without ``preproc``, ``sample`` is the batch
        with the ``"_label"`` column removed and ``labels`` is that column.

    Raises:
        ValueError: Raised by ``DataFrame.sample`` when either class has
            fewer than ``int(batch_size / 2)`` rows.
    """
    preproc_kwargs = {} if preproc_kwargs is None else preproc_kwargs

    # Retained from the original so callers relying on the global NumPy
    # seed being set still see that side effect.
    np.random.seed(seed)
    # A single stateful RNG shared by every draw: passing the integer seed
    # to each .sample() call would reproduce the exact same batch forever.
    rng = np.random.RandomState(seed)

    label_values = np.asarray(y)
    # Negative observations
    x_n = x[label_values == 0].copy()
    x_n["_label"] = 0
    # Positive observations
    x_p = x[label_values == 1].copy()
    x_p["_label"] = 1

    per_class = int(batch_size / 2)
    while True:
        sample = pd.concat([
            x_n.sample(per_class, random_state=rng),
            x_p.sample(per_class, random_state=rng),
        ]).sample(frac=1, random_state=rng)  # shuffle the two classes together
        if preproc:
            sample, labels = preproc(sample, **preproc_kwargs)
        else:
            # No preprocessing requested: split the helper column back out
            # so the features and labels are returned separately.
            labels = sample.pop("_label")
        yield sample, labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment