@nbertagnolli
Created June 28, 2020 22:52
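import numpy as np
import pandas as pd


def split_on_date(data: pd.DataFrame, train_percent: float = 0.9, seed: int = 1234) -> pd.DataFrame:
    """Splits a DataFrame into train and validation sets based on the date.

    All rows that share a date are assigned to the same split.

    Args:
        data: The data we want to split. It must contain a "date" column.
        train_percent: The probability that each unique date goes to training,
            so roughly this fraction of dates (not rows) ends up in train.
        seed: The random seed to use for selecting the sets.

    Returns:
        A new DataFrame: data plus a "split" column of "train" or "val".
    """
    # Sort the unique dates: a set has no guaranteed order, so the seeded
    # assignment would not be reproducible across runs.
    dates_df = pd.DataFrame({"date": sorted(data["date"].unique())})
    # A local Generator avoids mutating NumPy's global random state.
    rng = np.random.default_rng(seed)
    dates_df["split"] = rng.choice(
        ["train", "val"], size=len(dates_df), p=[train_percent, 1 - train_percent]
    )
    # Attach each row's split by joining on its date.
    return data.merge(dates_df, on="date")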
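
Because split_on_date assigns each unique date, not each row, to a split, two properties are worth checking on its output. First, no date should appear in both sets. Second, the share of train dates (which train_percent controls only in expectation) can differ from the share of train rows. The sketch below checks both. summarize_split is a hypothetical helper, not part of the original gist, and it assumes the "date" and "split" columns that split_on_date produces.

```python
import pandas as pd


def summarize_split(split_df: pd.DataFrame) -> dict:
    """Checks a date-based split for leakage and reports its proportions.

    Hypothetical helper: assumes split_df has a "date" column and a "split"
    column holding "train" or "val", like the output of split_on_date.
    """
    # Number of distinct splits each date was assigned to; should always be 1.
    splits_per_date = split_df.groupby("date")["split"].nunique()
    if (splits_per_date > 1).any():
        raise ValueError("some dates appear in both train and val")

    # One row per date, so this fraction counts dates rather than rows.
    dates = split_df.drop_duplicates("date")
    return {
        # train_percent controls this ratio (in expectation)...
        "train_date_fraction": (dates["split"] == "train").mean(),
        # ...while this one also depends on how many rows each date has.
        "train_row_fraction": (split_df["split"] == "train").mean(),
    }


# Toy data: three rows on 2020-01-01 and one row on each of two other dates.
toy = pd.DataFrame({
    "date": ["2020-01-01"] * 3 + ["2020-01-02", "2020-01-03"],
    "split": ["train"] * 3 + ["train", "val"],
})
stats = summarize_split(toy)
print(stats)
```

On real data the call would be summarize_split(split_on_date(df)). With only a handful of unique dates, expect the date fraction to stray noticeably from train_percent, since each date is drawn independently.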