from dataclasses import dataclass
from pathlib import Path


@dataclass
class NetworkFiles:
    neighbor_filename: Path
    distance_filename: Path


@dataclass
class Config:
    tree_data_filename: str
    tree_columns: list[str]
    network_files: NetworkFiles
    tree_id_field: str
    plot_id_field: str
    synthetic_id_field: str
    k: int
    d: int
    r: float
    output_filename: str
    min_synthetic_id: int | None = None
    max_synthetic_id: int | None = None


ROOT = Path("M:/research/synthetic_plots/aa")

QUANTILE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "quantile_mesh/quantile_mesh_k25_neighbors.csv",
    distance_filename=ROOT / "quantile_mesh/quantile_mesh_k25_distances.csv",
)
REFERENCE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "reference_network/reference_network_k25_neighbors.csv",
    distance_filename=ROOT / "reference_network/reference_network_k25_distances.csv",
)
EQUAL_INTERVAL_NETWORK = NetworkFiles(
    neighbor_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_neighbors.csv",
    distance_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_distances.csv",
)
FUZZED_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "fuzzed_network/fuzzed_network_k25_neighbors.csv",
    distance_filename=ROOT / "fuzzed_network/fuzzed_network_k25_distances.csv",
)

QUANTILE_K10_D0_R1 = Config(
    tree_data_filename="test_tree_data.parquet",
    tree_columns=["LIVE_ID", "FCID", "SPP_SYMBOL", "DBH_CM", "TPH_FC"],
    network_files=QUANTILE_NETWORK,
    tree_id_field="LIVE_ID",
    plot_id_field="FCID",
    synthetic_id_field="SYNTHETIC_PLOT_ID",
    k=10,
    d=0,
    r=1.0,
    output_filename="quantile_k10_d0_r10.csv",
    min_synthetic_id=1,
    max_synthetic_id=10_000,
)

CONFIG_LOOKUP = {
    "quantile_k10_d0_r10": QUANTILE_K10_D0_R1,
}
# from dask.distributed import Client
# client = Client("tcp://127.0.0.1:60183")
import math

import click
import dask.dataframe as dd
import numpy as np
import pandas as pd

from config import CONFIG_LOOKUP, Config, NetworkFiles

# Trees-per-hectare thresholds for censoring trees, corresponding
# to the radii (in feet) of the microplot, subplot, and macroplot,
# summed across all four subplots.
THRESHOLDS = {
    0: 1.0 / (4.0 * math.pi * ((6.8 * 0.3049) ** 2) / 10000),
    1: 1.0 / (4.0 * math.pi * ((24.0 * 0.3049) ** 2) / 10000),
    2: 1.0 / (4.0 * math.pi * ((58.9 * 0.3049) ** 2) / 10000),
}
def get_synthetic_network_data(
    network: NetworkFiles,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read in neighbors and distances from the synthetic network."""
    neighbor_df = pd.read_csv(network.neighbor_filename)
    distance_df = pd.read_csv(network.distance_filename)
    return neighbor_df, distance_df


def melt_and_merge_network(
    neighbor_df: pd.DataFrame,
    distance_df: pd.DataFrame,
    synthetic_id_field: str,
    plot_id_field: str,
) -> dd.DataFrame:
    """Melt synthetic neighbors and distances and merge together."""

    def _melt(df: pd.DataFrame, value_name: str) -> pd.DataFrame:
        melted_df = df.melt(
            id_vars=[synthetic_id_field], var_name="NEIGHBOR", value_name=value_name
        )
        melted_df["NEIGHBOR"] = melted_df["NEIGHBOR"].str.replace("NN", "").astype(int)
        return melted_df

    neighbor_df = _melt(neighbor_df, value_name=plot_id_field)
    distance_df = _melt(distance_df, value_name="DISTANCE")
    return dd.from_pandas(
        neighbor_df.merge(distance_df, on=[synthetic_id_field, "NEIGHBOR"])
    )
def get_weights_df(df: dd.DataFrame, k: int, d: int) -> dd.DataFrame:
    """Return neighbor weights based on gradient distance and weighting scheme (d)."""
    weights = df.groupby("NEIGHBOR", as_index=False).DISTANCE.first().head(k)
    weights["WEIGHT"] = (1.0 / weights.DISTANCE) ** d
    weights["WEIGHT"] = weights.WEIGHT / weights.WEIGHT.sum()
    return weights


def censor_by_threshold(df: dd.DataFrame, tph_field: str, r: float) -> dd.DataFrame:
    """Censor the tree data based on the TPH thresholds."""
    # TODO: Calculating each time - could be more efficient
    r_thresholds = {k: v * r for k, v in THRESHOLDS.items()}
    # Bin the DBH values into groups based on the microplot, subplot, and
    # macroplot diameter thresholds
    df["DBH_GROUP"] = np.digitize(df.DBH_CM, [12.7, 54.0])
    # Calculate the sum of TPH for each species and DBH group
    grouped = (
        df.groupby(["SPP_SYMBOL", "DBH_GROUP"], as_index=False)[tph_field]
        .agg("sum")
        .rename(columns={tph_field: "SUM_TPH"})
    )
    # Add a new field to indicate whether the species / DBH group met the
    # censoring threshold
    grouped["INCLUDE"] = np.where(
        grouped.SUM_TPH >= grouped.DBH_GROUP.map(r_thresholds), 1, 0
    )
    # Identify all species where at least one DBH group met the threshold
    species = grouped.groupby("SPP_SYMBOL", as_index=False).INCLUDE.agg("max")
    df = df.merge(species, on=["SPP_SYMBOL"])
    # Zero out the TPH associated with species that did not meet the threshold
    df[f"{tph_field}_CENSORED"] = np.where(df.INCLUDE == 1, df[tph_field], 0.0)
    return df
def weight_and_censor(df: dd.DataFrame, k: int, d: float, r: float) -> dd.DataFrame:
    """Weight and censor the tree data based on the k, d, and r parameters."""
    # Assign weights to records based on the k and d parameters
    weight_df = get_weights_df(df, k, d)
    df = df.merge(weight_df[["NEIGHBOR", "WEIGHT"]], on="NEIGHBOR")
    # For plot neighbors that do not have treelists, fill in null values
    #
    # TODO: We're treating SPP_SYMBOL separately as the only string column.
    # This needs to be generalized
    df["SPP_SYMBOL"] = df["SPP_SYMBOL"].fillna("")
    df = df.fillna(0)
    # Calculate the modified TPH based on the weight and only retain
    # records with a positive TPH. Note that the original network lists
    # may contain records with neighbors all the way up to k=25, so this
    # removes all records >= k
    df[f"TPH_FC_K{k}"] = df.TPH_FC * df.WEIGHT
    df = df[df[f"TPH_FC_K{k}"] > 0.0]
    df = df.drop(columns="WEIGHT")
    # Using the modified TPH, censor the data based on the r parameter
    df = censor_by_threshold(df, f"TPH_FC_K{k}", r)
    # Return only the records that are above the censoring threshold
    return df[df[f"TPH_FC_K{k}_CENSORED"] > 0.0].drop(columns=["DBH_GROUP", "INCLUDE"])
def _main(config_key: str) -> None:
    # Set the configuration for this run
    config: Config = CONFIG_LOOKUP[config_key]
    # Read in the tree data and subset to just the columns we need.
    # Note that all plots that might be candidate neighbors are represented
    # in this list, including plots with no tree records. In this case,
    # all fields except the plot_id_field will be null.
    tree_df = dd.read_parquet(config.tree_data_filename)
    drop_columns = list(set(tree_df.columns) - set(config.tree_columns))
    tree_df = tree_df.drop(columns=drop_columns)
    # Read in neighbors and distances from the synthetic network
    syn_nn_df, syn_dist_df = get_synthetic_network_data(config.network_files)
    # Subset the data if min_synthetic_id or max_synthetic_id are set
    if config.min_synthetic_id or config.max_synthetic_id:
        min_synthetic_id = config.min_synthetic_id or 1
        max_synthetic_id = (
            config.max_synthetic_id or syn_nn_df[config.synthetic_id_field].max()
        )

        def subset_df(df: pd.DataFrame) -> pd.DataFrame:
            conds = (df[config.synthetic_id_field] >= min_synthetic_id) & (
                df[config.synthetic_id_field] <= max_synthetic_id
            )
            return df[conds]

        syn_nn_df = subset_df(syn_nn_df)
        syn_dist_df = subset_df(syn_dist_df)

    # Melt and merge the neighbors and distances
    syn_df = melt_and_merge_network(
        syn_nn_df, syn_dist_df, config.synthetic_id_field, config.plot_id_field
    )
    # Join the synthetic plot data and the tree data
    syn_tree_df = syn_df.merge(tree_df, on=config.plot_id_field)
    syn_tree_df = syn_tree_df.repartition(npartitions=30)
    # Describe the data types for the apply function
    meta = {
        "NEIGHBOR": "int64",
        config.plot_id_field: "int64",
        "DISTANCE": "float64",
        config.tree_id_field: "int64",
        "SPP_SYMBOL": "string",
        "DBH_CM": "float64",
        "TPH_FC": "float64",
        f"TPH_FC_K{config.k}": "float64",
        f"TPH_FC_K{config.k}_CENSORED": "float64",
    }
    # Apply the k, d, and r parameters to the tree data
    syn_tree_df = (
        syn_tree_df.groupby(config.synthetic_id_field)
        .apply(
            weight_and_censor,
            k=config.k,
            d=config.d,
            r=config.r,
            include_groups=False,
            meta=meta,
        )
        .compute()
        .reset_index()
        .drop(columns="level_1")
    )
    # Change types before export
    syn_tree_df = syn_tree_df.astype({config.tree_id_field: "int32"})
    # Set columns for export
    columns = [
        config.synthetic_id_field,
        config.tree_id_field,
        f"TPH_FC_K{config.k}_CENSORED",
    ]
    syn_tree_df.sort_values([config.synthetic_id_field])[columns].to_csv(
        config.output_filename, index=False, float_format="%.4f"
    )
@click.command()
@click.argument("config-key", type=str)
def main(config_key: str):
    import time

    start = time.time()
    _main(config_key)
    print(f"Elapsed time: {time.time() - start}")


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])
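As a quick orientation for running the script above, here is a minimal usage sketch. The gist doesn't name the script file, so `build_tree_lists.py` below is a hypothetical filename; the key passed in must be one of the keys defined in `CONFIG_LOOKUP`.

```python
# Hypothetical usage sketch -- assumes the script above is saved as
# build_tree_lists.py alongside config.py (the gist doesn't give file names).

# From the command line, pass a key defined in CONFIG_LOOKUP:
#   python build_tree_lists.py quantile_k10_d0_r10

# Or bypass the click wrapper and call the worker function directly:
from build_tree_lists import _main

_main("quantile_k10_d0_r10")  # writes quantile_k10_d0_r10.csv per the config
```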
@aazuspan, that is completely bonkers that you reduced the time so much! Thank you so much for looking at this. I'm just starting up the review of the estimator wrapper, but will see what you've done here and hopefully follow along!
@aazuspan, just got a chance to go through the code commit by commit. So cool. All of your incremental changes made really good sense, but I'm blown away that more or less taking all the logic that was in the `apply` out of the `apply` makes it this much faster. I don't think I've really ever used `groupby.transform` too much, so that was cool to see how you did that, even on something like normalizing the weights.
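For anyone reading along, here is a minimal sketch of the `groupby.transform` pattern being discussed, using a toy frame with the gist's column names rather than the actual code from the private repo: each synthetic plot's inverse-distance weights are normalized with a grouped transform instead of a per-group `apply`.

```python
import pandas as pd

# Toy frame: two synthetic plots, each with a few neighbor distances
df = pd.DataFrame(
    {
        "SYNTHETIC_PLOT_ID": [1, 1, 1, 2, 2],
        "DISTANCE": [0.5, 1.0, 2.0, 0.25, 0.75],
    }
)

d = 1  # inverse-distance weighting exponent
df["WEIGHT"] = (1.0 / df["DISTANCE"]) ** d

# transform("sum") broadcasts each group's total back onto its rows,
# so the weights can be normalized without a per-group apply
df["WEIGHT"] /= df.groupby("SYNTHETIC_PLOT_ID")["WEIGHT"].transform("sum")

print(df.groupby("SYNTHETIC_PLOT_ID")["WEIGHT"].sum())  # each plot sums to 1.0
```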
> If you want to try running these directly, I put the data files in a `./data` directory in the repo root, and ran `main` first to make the reference file.

I don't actually see this directory in the repo, but maybe better that we don't put it up there. I've just created a `data` directory locally and will continue to test with that.
@grovduck I had some down time while waiting for models to train, so I took a stab at re-writing this with Polars, which you can find on the `polars` branch. That turned out to be a pretty painless process considering I'd never used it before, and it brings the execution time down from 4s to 0.8s (EDIT: 0.7s with some optimizations!).

I'm not sure that's enough of a speedup to justify mixing or switching dataframe libraries, but it might be worth a closer look if the `pandas` code turned out to be a significant bottleneck.
@aazuspan, this is bonkers! So cool to see. The refactoring really didn't look too painful - `with_columns` definitely looks like a "go-to" method!
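For context, a rough sketch of the `with_columns` chaining style being referred to, written against an invented toy frame rather than the actual `polars` branch (the window expression with `.over()` stands in for the per-plot weight normalization):

```python
import polars as pl

# Toy stand-in for the melted neighbor/distance/tree table
df = pl.DataFrame(
    {
        "SYNTHETIC_PLOT_ID": [1, 1, 2, 2],
        "DISTANCE": [0.5, 1.0, 0.25, 0.75],
        "TPH_FC": [100.0, 50.0, 200.0, 80.0],
    }
)

d = 1  # inverse-distance weighting exponent
weighted = (
    df.with_columns(((1.0 / pl.col("DISTANCE")) ** d).alias("WEIGHT"))
    # window expression: normalize weights within each synthetic plot
    .with_columns(
        (pl.col("WEIGHT") / pl.col("WEIGHT").sum().over("SYNTHETIC_PLOT_ID")).alias("WEIGHT")
    )
    .with_columns((pl.col("TPH_FC") * pl.col("WEIGHT")).alias("TPH_FC_K10"))
)
print(weighted)
```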
I tested on the full 100,000 synthetic plot IDs and my timings from the `no-apply` branch to the `polars` branch went from 41.9s to 6.2s, so between a 6x and 7x speedup. What are your thoughts on using `polars`? To me, the code looks fairly straightforward to understand and obviously relies quite a bit on chaining. But, as you say, that's also a big move to switch/mix the libraries, so it would be good to think through the ramifications.

We might be running this code hundreds of thousands of times with one of our projects, so that might factor into the decision as well. I think next steps are to get `synthetic_knn.datasets` ready with the FIADB plots and then move this work over into `synthetic-knn`?

Thanks so much for pushing on this!
> What are your thoughts on using `polars`? To me, the code looks fairly straightforward to understand and obviously relies quite a bit on chaining. But, as you say, that's also a big move to switch/mix the libraries, so it would be good to think through the ramifications.
>
> We might be running this code hundreds of thousands of times with one of our projects, so that might factor into the decision as well.
My first impression of `polars` is really positive - nice API, great docs, and obviously the performance speaks for itself. It automatically does a lot of the same parallelization and task graph optimization as Dask, and has at least partial support (so far) for streaming larger-than-memory data. I think for an internal tool or end-to-end applications like this (i.e. files go in, files come out), I would 100% support switching, given how those performance improvements will potentially scale as you mentioned.
Aside from the time we would need to spend learning a new package, my only reservation to switching to `polars` would be cases where users would need to pass in dataframes or work with intermediate outputs, just because `pandas` is so ubiquitous that everyone's already familiar with it. It's easy to move between the two formats (with `pl.from_pandas` or `pl.DataFrame.to_pandas`), but it obviously adds a little complexity and overhead if we're switching back and forth.
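For example, a minimal round-trip between the two libraries (the frame here is invented; `pl.from_pandas` and `DataFrame.to_pandas` are the conversion calls mentioned above, and both typically rely on pyarrow being installed):

```python
import pandas as pd
import polars as pl

pd_df = pd.DataFrame({"FCID": [1, 2, 3], "TPH_FC": [120.0, 80.5, 42.0]})

pl_df = pl.from_pandas(pd_df)   # pandas -> polars (Arrow-backed copy)
back_to_pd = pl_df.to_pandas()  # polars -> pandas for downstream consumers
```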
I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.
> I think next steps are to get `synthetic_knn.datasets` ready with the FIADB plots and then move this work over into `synthetic-knn`?

Sounds like a good plan!
> I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.
I can definitely see the use case for `polars` here if for no other reason than speed. I haven't yet hit you with the second part of this, which is to take the tree-level data and calculate stand-level attributes from it, but I'm guessing it will similarly benefit from `polars`. For now, I think `synthetic-knn` will be mostly an internal package, but perhaps it will be worth bringing in the crosswalking to `pandas` at some later date.

For now, I think we can sit on this until I get the `synthetic_knn.datasets` PR in place.
@grovduck, I spent some time optimizing things at the Pandas level and managed to get the time down from ~45s with Dask to ~4s without Dask. The main change was avoiding the `apply` call to `weight_and_censor` by vectorizing things as much as possible. My understanding is that `apply` is inherently slow, and Dask can't do much to improve that.

I don't think there's a way to make a PR against a Gist, so I made a private repo to track the changes (didn't push any data, just code!). There are 4 branches:

- `main`: Your reference implementation that I used for comparison, with a few changes for repeatability[^1]. Runs in about 45s.
- `no-dask`: Same as `main` but without any Dask, just to keep things simple. Runs in about 76s.
- `no-apply`: Same as `no-dask` but vectorized. Runs in about 4s.
- `polars`: Re-implemented `no-apply` in Polars. Runs in about 0.8s.

The `no-apply` implementation is very rough around the edges, and was just a quick proof of concept. I tried to be very granular with the commits though, so it should hopefully be easy to track what I changed and why. Fair warning, I only tested with the `QUANTILE_K10_D1_R1` config, so it's totally possible that it will break with other configs. I didn't try adding Dask back in yet, but that would be good to test.

If you want to try running these directly, I put the data files in a `./data` directory in the repo root, and ran `main` first to make the reference file.

p.s. I'm still curious to try this in Polars, which claims to be substantially faster than Dask. I may poke around with that...

[^1]: I made a couple changes to your code to set up the testing, like fully sorting the output, not timing the CSV write, and using a new linear distance weighted config just to make sure I didn't screw up the weighting.
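To make the vectorization described in the comment above concrete, here is a rough sketch of how the per-plot weighting in `get_weights_df` could move to whole-frame operations. This is not the code from the private repo (which isn't shown in this gist); it only illustrates the general pattern using the gist's own column names.

```python
import pandas as pd


def get_weights_all_plots(syn_df: pd.DataFrame, k: int, d: float) -> pd.DataFrame:
    """Compute normalized neighbor weights for every synthetic plot at once.

    Expects one row per (SYNTHETIC_PLOT_ID, NEIGHBOR) pair with a DISTANCE
    column, i.e. the melted network before joining the tree records.
    """
    # NEIGHBOR already encodes neighbor rank (NN1..NN25), so the per-group
    # head(k) becomes a single boolean filter over the whole frame
    weights = syn_df[syn_df["NEIGHBOR"] <= k].copy()
    weights["WEIGHT"] = (1.0 / weights["DISTANCE"]) ** d
    # normalize within each synthetic plot via a grouped transform
    weights["WEIGHT"] /= weights.groupby("SYNTHETIC_PLOT_ID")["WEIGHT"].transform("sum")
    return weights[["SYNTHETIC_PLOT_ID", "NEIGHBOR", "WEIGHT"]]


# The weights can then be merged onto the tree join once, instead of being
# recomputed inside a per-plot apply:
#   syn_tree_df = syn_tree_df.merge(weights, on=["SYNTHETIC_PLOT_ID", "NEIGHBOR"])
```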