-
-
Save r-brink/aa19f3a2201525495bc03cd0c3f53942 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl

# Number of rows generated for each benchmark column.
SIZE = 100_000_000

# Seed NumPy's global RNG so the generated columns are reproducible.
np.random.seed(42)

# Benchmark input: "a" is drawn from a normal distribution, "b" from
# np.random.random; only "a" is used by the rolling-median benchmarks below.
df = pl.DataFrame(
    {
        "a": np.random.normal(size=SIZE),
        "b": np.random.random(size=SIZE),
    }
)

# Pandas copy of the same data, used for the Pandas side of the comparison.
pandas_df = df.to_pandas()


def rolling_median_10k_window(df: pl.DataFrame):
    """Time Polars' rolling median on column "a" for windows up to 10,000.

    For each window size the rolling median is computed once and the elapsed
    time is recorded. The timings are written to
    ``rolling_median_v<polars version>_10k.csv`` with the columns
    "Window Size", "Time" and "Version".

    Args:
        df: Polars DataFrame containing a numeric column named "a".

    Returns:
        Whatever ``DataFrame.write_csv`` returns for a file path.
    """
    window_sizes = [1, 10, 100, 1000, 2000, 3000, 5000, 7500, 10_000]
    times = []
    for window_size in window_sizes:
        # perf_counter is a monotonic, high-resolution clock, so it is a
        # better fit for measuring intervals than wall-clock time.time().
        t0 = time.perf_counter()
        # The resulting frame is discarded; only the elapsed time matters.
        df.with_columns(
            rolling_median=pl.col("a").rolling_median(window_size=window_size)
        )
        times.append(time.perf_counter() - t0)

    results_df = pl.DataFrame(
        {
            "Window Size": window_sizes,
            "Time": times,
            # One version entry per measurement row.
            "Version": [pl.__version__] * len(window_sizes),
        }
    )
    return results_df.write_csv(f"rolling_median_v{pl.__version__}_10k.csv")


def run_rolling_median_test(df):
    """Time a rolling median over column "a" with Polars or Pandas.

    The backend is chosen from the type of ``df``. Each window size is timed
    once, and the timings are written to
    ``rolling_median_<Polars|Pandas>_v<library version>.csv`` with the columns
    "Window Size", "Time" and "Version".

    Args:
        df: A Polars or Pandas DataFrame containing a numeric column "a".

    Returns:
        Whatever ``DataFrame.write_csv`` returns for a file path.

    Raises:
        TypeError: If ``df`` is neither a Polars nor a Pandas DataFrame.
    """
    window_sizes = [10, 100, 500, 1000, 2000, 3000, 4000, 5000]

    # Select the backend-specific computation once so that a single timing
    # loop serves both libraries.
    if isinstance(df, pl.DataFrame):

        def compute(window_size):
            return df.with_columns(
                pl.col("a").rolling_median(window_size=window_size)
            )

        version = pl.__version__
        data_type = "Polars"
    elif isinstance(df, pd.DataFrame):

        def compute(window_size):
            return df["a"].rolling(window=window_size).median()

        version = pd.__version__
        data_type = "Pandas"
    else:
        raise TypeError(
            "The DataFrame must be either a Polars DataFrame or a Pandas DataFrame"
        )

    times = []
    for window_size in window_sizes:
        # perf_counter is a monotonic, high-resolution clock, so it is a
        # better fit for measuring intervals than wall-clock time.time().
        t0 = time.perf_counter()
        compute(window_size)  # result discarded; only the duration is kept
        times.append(time.perf_counter() - t0)

    results_df = pl.DataFrame(
        {
            "Window Size": window_sizes,
            "Time": times,
            # One version entry per measurement row.
            "Version": [version] * len(window_sizes),
        }
    )
    return results_df.write_csv(f"rolling_median_{data_type}_v{version}.csv")


def plot_results(pattern="*.csv", output="rolling_median_performance.png"):
    """Plot time against window size for every benchmark CSV, one line per version.

    Args:
        pattern: Path or glob passed to ``pl.read_csv``. The default reads
            every CSV in the current working directory, so all of them must
            share the "Window Size" / "Time" / "Version" columns.
        output: File name the figure is saved to.

    Returns:
        Whatever ``matplotlib.pyplot.savefig`` returns.
    """
    results = pl.read_csv(pattern)

    # maintain_order keeps the legend order stable between runs.
    for key, group in results.group_by("Version", maintain_order=True):
        # Newer Polars releases yield the group key as a tuple even for a
        # single grouping column; unwrap it so the legend shows the bare
        # version string.
        version = key[0] if isinstance(key, tuple) else key
        # Sort so the line runs left to right along the x axis.
        group = group.sort("Window Size")
        plt.plot(group["Window Size"], group["Time"], marker="o", label=version)

    plt.title("Benchmarking Rolling Median Improvements")
    plt.xlabel("Window Size")
    plt.ylabel("Time Taken (s)")
    plt.legend()
    plt.grid(True)
    return plt.savefig(output)


# Run the benchmarks, then plot every CSV found in the working directory.
run_rolling_median_test(df)
run_rolling_median_test(pandas_df)
# NOTE(review): this second Polars run writes to the same CSV file name as the
# first one, replacing its timings. Presumably a deliberate re-run; confirm
# or remove.
run_rolling_median_test(df)
plot_results()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@MarcoGorelli We decided to include versions up to 0.20.2 for this post; otherwise the blog would have become too long and we probably wouldn't have finished it before the next release 😅. I am already excited for the next post, because then we can play around with the new plotting feature!