Skip to content

Instantly share code, notes, and snippets.

@bilzard
Created February 19, 2023 06:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bilzard/71f4c4723ede10f3f009be407d16e686 to your computer and use it in GitHub Desktop.
Save bilzard/71f4c4723ede10f3f009be407d16e686 to your computer and use it in GitHub Desktop.
Fold Split by Polars
def split_fold(
    df: pl.DataFrame,
    target_col="label",
    group_cols=("group",),
    num_folds=5,
    seed=42,
) -> pl.DataFrame:
    """Return ``df`` with an added ``_fold`` column of validation-fold ids.

    Rows are split with scikit-learn's ``StratifiedGroupKFold`` (shuffled,
    seeded with ``seed``). Each row's ``_fold`` value is the index of the
    fold in which that row belongs to the validation part.

    Args:
        df: Input frame. It is not modified; a new frame is returned.
        target_col: Name of the column used for stratification.
        group_cols: Name, or sequence of names, of the column(s) that
            identify a group. With several columns, rows belong to the
            same group only when they agree on every listed column.
        num_folds: Number of folds (``n_splits``).
        seed: Random state passed to the splitter.

    Returns:
        A frame with every original column plus an ``Int64`` column
        ``_fold``. An existing ``_fold`` column is replaced.

    Raises:
        ValueError: If ``group_cols`` is empty.
    """
    # Accept a bare column name as well as a sequence of names.
    if isinstance(group_cols, str):
        group_cols = [group_cols]
    group_cols = list(group_cols)
    if not group_cols:
        raise ValueError("group_cols must name at least one column")

    # scikit-learn expects one group label per row (shape (n_samples,)).
    if len(group_cols) == 1:
        # Pass the raw values so fold assignment for a given seed is the
        # same as when the one-column frame was handed over directly.
        groups = df[group_cols[0]].to_numpy()
    else:
        # Collapse each distinct combination of values into one integer id;
        # a 2-D group table would be flattened by the splitter otherwise.
        group_ids = {}
        groups = [
            group_ids.setdefault(key, len(group_ids))
            for key in df.select(group_cols).rows()
        ]

    skf = StratifiedGroupKFold(n_splits=num_folds, random_state=seed, shuffle=True)

    # The splitter yields positional row indices, so the fold ids can be
    # written straight into a list aligned with the frame's row order. No
    # helper index column (and no particular input column) is required.
    folds = [None] * df.height
    for fold, (_, idx_valid) in enumerate(skf.split(df, df[target_col], groups)):
        for idx in idx_valid.tolist():
            folds[idx] = fold

    return df.with_columns(pl.Series("_fold", folds, dtype=pl.Int64))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment