Skip to content

Instantly share code, notes, and snippets.

@zilto
Last active July 28, 2023 20:10
Show Gist options
  • Save zilto/d4b8dc6fc0244d149b9a626c7b9634fb to your computer and use it in GitHub Desktop.
Save zilto/d4b8dc6fc0244d149b9a626c7b9634fb to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
from hamilton.function_modifiers import extract_columns, save_to, source, check_output
# Names of the raw trip columns passed to `extract_columns` on `trips_raw`.
TRIPS_SOURCE_COLUMNS = [
    "event_timestamp",
    "driver_id",
    "rider_id",
    "trip_dist",
    "created",
]
# `extract_columns` allows you to split a dataframe into multiple pandas Series,
# one per listed column name
@extract_columns(*TRIPS_SOURCE_COLUMNS)
def trips_raw(trips_raw_path: str) -> pd.DataFrame:
    """Load the raw trips dataset, ordered by event time.

    Args:
        trips_raw_path: Location of the parquet file to read.

    Returns:
        The rows of the file sorted by ascending ``event_timestamp``.
    """
    trips = pd.read_parquet(trips_raw_path)
    # A stable sort keeps rows that share a timestamp in file order, so the
    # row order is reproducible from run to run.
    return trips.sort_values(by="event_timestamp", kind="stable")
def day_of_week(event_timestamp: pd.Series) -> pd.Series:
    """Encode the day of the week as an integer.

    Args:
        event_timestamp: Datetime-typed series (it is read through ``.dt``).

    Returns:
        A series named ``day_of_week`` on the same index as
        ``event_timestamp``; pandas numbers the days Monday=0 ... Sunday=6.
    """
    return pd.Series(
        event_timestamp.dt.day_of_week,
        name="day_of_week",
        index=event_timestamp.index,
    )
# see how this function depends on the return value of `day_of_week()`
# NOTE(review): `range=` is the keyword expected by Hamilton's default range
# validator -- confirm against the installed Hamilton version.
@check_output(data_type=np.int64, range=(0, 1), importance="warn")
def is_weekend(day_of_week: pd.Series) -> pd.Series:
    """Flag weekend days as 1 and all other days as 0.

    Args:
        day_of_week: Integer day-of-week codes.

    Returns:
        An int64 series named ``is_weekend`` on the same index as
        ``day_of_week``: 1 where the code is 5 or higher, else 0. With pandas'
        Monday=0 numbering that corresponds to Saturday and Sunday.
    """
    # np.where yields the platform default integer, which is not int64 on
    # every platform; cast so the result matches the validated data_type.
    weekend = np.where(day_of_week >= 5, 1, 0).astype(np.int64)
    return pd.Series(weekend, name="is_weekend", index=day_of_week.index)
def percentile_dist_rolling_3h(trip_dist: pd.Series, event_timestamp: pd.Series) -> pd.Series:
    """Compute the rolling 3-hour percentile rank of the trip distance.

    For each row, the trip distance is ranked (as a fraction, ``pct=True``)
    among the rows falling in the 3-hour window that ends at that row's
    timestamp.

    Args:
        trip_dist: Trip distances.
        event_timestamp: Timestamps sharing ``trip_dist``'s index. Time-based
            rolling in pandas expects these to be sorted -- `trips_raw` sorts
            by ``event_timestamp`` upstream.

    Returns:
        A series named ``percentile_trip_dist_rolling_3h`` on the index of
        ``event_timestamp``.
    """
    # Name the columns explicitly so the `on=` and column lookups below do not
    # depend on how the caller named the incoming series.
    df = pd.concat(
        [trip_dist.rename("trip_dist"), event_timestamp.rename("event_timestamp")],
        axis=1,
    )
    # A Timedelta window sidesteps the "3H" string alias, which newer pandas
    # releases deprecate in favour of lowercase "h".
    window = pd.Timedelta(hours=3)
    agg = df.rolling(window, on="event_timestamp")["trip_dist"].rank(pct=True)
    return pd.Series(agg, name="percentile_trip_dist_rolling_3h", index=event_timestamp.index)
# This function spans many lines, but all it does is explicitly assemble
# columns from the raw source and the computed features.
# The @save_to decorator makes it easy to save the result to a parquet file.
@save_to.parquet(path=source("trips_stats_3h_path"), output_name_="save_trips_stats_3h")
def trips_stats_3h(
    event_timestamp: pd.Series,
    driver_id: pd.Series,
    day_of_week: pd.Series,
    is_weekend: pd.Series,
    percentile_dist_rolling_3h: pd.Series,
) -> pd.DataFrame:
    """Global trip statistics over rolling 3h.

    Args:
        event_timestamp: Raw event timestamps.
        driver_id: Raw driver identifiers.
        day_of_week: Day-of-week feature.
        is_weekend: Weekend flag feature.
        percentile_dist_rolling_3h: Rolling 3h percentile of trip distance.

    Returns:
        A dataframe with the inputs placed side by side as columns, in the
        order listed above.
    """
    columns = [
        event_timestamp,
        driver_id,
        day_of_week,
        is_weekend,
        percentile_dist_rolling_3h,
    ]
    # axis=1 places each series in its own column, aligned on the index
    return pd.concat(columns, axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment