config.py:

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class NetworkFiles:
    neighbor_filename: str
    distance_filename: str


@dataclass
class Config:
    tree_data_filename: str
    tree_columns: list[str]
    network_files: NetworkFiles
    tree_id_field: str
    plot_id_field: str
    synthetic_id_field: str
    k: int
    d: int
    r: float
    output_filename: str
    min_synthetic_id: int | None = None
    max_synthetic_id: int | None = None


ROOT = Path("M:/research/synthetic_plots/aa")

QUANTILE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "quantile_mesh/quantile_mesh_k25_neighbors.csv",
    distance_filename=ROOT / "quantile_mesh/quantile_mesh_k25_distances.csv",
)
REFERENCE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "reference_network/reference_network_k25_neighbors.csv",
    distance_filename=ROOT / "reference_network/reference_network_k25_distances.csv",
)
EQUAL_INTERVAL_NETWORK = NetworkFiles(
    neighbor_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_neighbors.csv",
    distance_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_distances.csv",
)
FUZZED_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "fuzzed_network/fuzzed_network_k25_neighbors.csv",
    distance_filename=ROOT / "fuzzed_network/fuzzed_network_k25_distances.csv",
)

QUANTILE_K10_D0_R1 = Config(
    tree_data_filename="test_tree_data.parquet",
    tree_columns=["LIVE_ID", "FCID", "SPP_SYMBOL", "DBH_CM", "TPH_FC"],
    network_files=QUANTILE_NETWORK,
    tree_id_field="LIVE_ID",
    plot_id_field="FCID",
    synthetic_id_field="SYNTHETIC_PLOT_ID",
    k=10,
    d=0,
    r=1.0,
    output_filename="quantile_k10_d0_r10.csv",
    min_synthetic_id=1,
    max_synthetic_id=10_000,
)

CONFIG_LOOKUP = {
    "quantile_k10_d0_r10": QUANTILE_K10_D0_R1,
}
```
The standalone script:

```python
# from dask.distributed import Client
# client = Client("tcp://127.0.0.1:60183")

import math

import click
import dask.dataframe as dd
import numpy as np
import pandas as pd
from config import CONFIG_LOOKUP, Config, NetworkFiles

# Trees-per-hectare thresholds for censoring trees, corresponding
# to the radii (in feet) of the microplot, subplot, and macroplot
# and across all four subplots.
THRESHOLDS = {
    0: 1.0 / (4.0 * math.pi * ((6.8 * 0.3049) ** 2) / 10000),
    1: 1.0 / (4.0 * math.pi * ((24.0 * 0.3049) ** 2) / 10000),
    2: 1.0 / (4.0 * math.pi * ((58.9 * 0.3049) ** 2) / 10000),
}


def get_synthetic_network_data(
    network: NetworkFiles,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read in neighbors and distances from the synthetic network."""
    neighbor_df = pd.read_csv(network.neighbor_filename)
    distance_df = pd.read_csv(network.distance_filename)
    return neighbor_df, distance_df


def melt_and_merge_network(
    neighbor_df: pd.DataFrame,
    distance_df: pd.DataFrame,
    synthetic_id_field: str,
    plot_id_field: str,
) -> dd.DataFrame:
    """Melt synthetic neighbors and distances and merge together."""

    def _melt(df: pd.DataFrame, value_name: str) -> pd.DataFrame:
        melted_df = df.melt(
            id_vars=[synthetic_id_field], var_name="NEIGHBOR", value_name=value_name
        )
        melted_df["NEIGHBOR"] = melted_df["NEIGHBOR"].str.replace("NN", "").astype(int)
        return melted_df

    neighbor_df = _melt(neighbor_df, value_name=plot_id_field)
    distance_df = _melt(distance_df, value_name="DISTANCE")
    return dd.from_pandas(
        neighbor_df.merge(distance_df, on=[synthetic_id_field, "NEIGHBOR"])
    )


def get_weights_df(df: dd.DataFrame, k: int, d: int) -> dd.DataFrame:
    """Return neighbor weights based on gradient distance and weighting scheme (d)."""
    weights = df.groupby("NEIGHBOR", as_index=False).DISTANCE.first().head(k)
    weights["WEIGHT"] = (1.0 / weights.DISTANCE) ** d
    weights["WEIGHT"] = weights.WEIGHT / weights.WEIGHT.sum()
    return weights


def censor_by_threshold(df: dd.DataFrame, tph_field: str, r: float) -> dd.DataFrame:
    """Censor the tree data based on the TPH thresholds."""
    # TODO: Calculating each time - could be more efficient
    r_thresholds = {k: v * r for k, v in THRESHOLDS.items()}

    # Bin the DBH values into groups based on the microplot, subplot, and
    # macroplot diameter thresholds
    df["DBH_GROUP"] = np.digitize(df.DBH_CM, [12.7, 54.0])

    # Calculate the sum of TPH for each species and DBH group
    grouped = (
        df.groupby(["SPP_SYMBOL", "DBH_GROUP"], as_index=False)[tph_field]
        .agg("sum")
        .rename(columns={tph_field: "SUM_TPH"})
    )

    # Add a new field to indicate whether the species / DBH group met the
    # censoring threshold
    grouped["INCLUDE"] = np.where(
        grouped.SUM_TPH >= grouped.DBH_GROUP.map(r_thresholds), 1, 0
    )

    # Identify all species where at least one DBH group met the threshold
    species = grouped.groupby("SPP_SYMBOL", as_index=False).INCLUDE.agg("max")
    df = df.merge(species, on=["SPP_SYMBOL"])

    # Zero out the TPH associated with species that did not meet the threshold
    df[f"{tph_field}_CENSORED"] = np.where(df.INCLUDE == 1, df[tph_field], 0.0)
    return df


def weight_and_censor(df: dd.DataFrame, k: int, d: float, r: float) -> dd.DataFrame:
    """Weight and censor the tree data based on the k, d, and r parameters."""
    # Assign weights to records based on the k and d parameters
    weight_df = get_weights_df(df, k, d)
    df = df.merge(weight_df[["NEIGHBOR", "WEIGHT"]], on="NEIGHBOR")

    # For plot neighbors that do not have treelists, fill in null values
    #
    # TODO: We're treating SPP_SYMBOL separately as the only string column.
    # This needs to be generalized
    df["SPP_SYMBOL"] = df["SPP_SYMBOL"].fillna("")
    df = df.fillna(0)

    # Calculate the modified TPH based on the weight and only retain
    # records with a positive TPH. Note that the original network lists
    # may contain records with neighbors all the way up to k=25, so this
    # removes all records >= k
    df[f"TPH_FC_K{k}"] = df.TPH_FC * df.WEIGHT
    df = df[df[f"TPH_FC_K{k}"] > 0.0]
    df = df.drop(columns="WEIGHT")

    # Using the modified TPH, censor the data based on the r parameter
    df = censor_by_threshold(df, f"TPH_FC_K{k}", r)

    # Return only the records that are above the censoring threshold
    return df[df[f"TPH_FC_K{k}_CENSORED"] > 0.0].drop(columns=["DBH_GROUP", "INCLUDE"])


def _main(config_key: str) -> None:
    # Set the configuration for this run
    config: Config = CONFIG_LOOKUP[config_key]

    # Read in the tree data and subset to just the columns we need.
    # Note that all plots that might be candidate neighbors are represented
    # in this list, including plots with no tree records. In this case,
    # all fields except the plot_id_field will be null.
    tree_df = dd.read_parquet(config.tree_data_filename)
    drop_columns = list(set(tree_df.columns) - set(config.tree_columns))
    tree_df = tree_df.drop(columns=drop_columns)

    # Read in neighbors and distances from the synthetic network
    syn_nn_df, syn_dist_df = get_synthetic_network_data(config.network_files)

    # Subset the data if min_synthetic_id or max_synthetic_id are set
    if config.min_synthetic_id or config.max_synthetic_id:
        min_synthetic_id = config.min_synthetic_id or 1
        max_synthetic_id = (
            config.max_synthetic_id or syn_nn_df[config.synthetic_id_field].max()
        )

        def subset_df(df: pd.DataFrame) -> pd.DataFrame:
            conds = (df[config.synthetic_id_field] >= min_synthetic_id) & (
                df[config.synthetic_id_field] <= max_synthetic_id
            )
            return df[conds]

        syn_nn_df = subset_df(syn_nn_df)
        syn_dist_df = subset_df(syn_dist_df)

    # Melt and merge the neighbors and distances
    syn_df = melt_and_merge_network(
        syn_nn_df, syn_dist_df, config.synthetic_id_field, config.plot_id_field
    )

    # Join the synthetic plot data and the tree data
    syn_tree_df = syn_df.merge(tree_df, on=config.plot_id_field)
    syn_tree_df = syn_tree_df.repartition(npartitions=30)

    # Describe the data types for the apply function
    meta = {
        "NEIGHBOR": "int64",
        config.plot_id_field: "int64",
        "DISTANCE": "float64",
        config.tree_id_field: "int64",
        "SPP_SYMBOL": "string",
        "DBH_CM": "float64",
        "TPH_FC": "float64",
        f"TPH_FC_K{config.k}": "float64",
        f"TPH_FC_K{config.k}_CENSORED": "float64",
    }

    # Apply the k, d, and r parameters to the tree data
    syn_tree_df = (
        syn_tree_df.groupby(config.synthetic_id_field)
        .apply(
            weight_and_censor,
            k=config.k,
            d=config.d,
            r=config.r,
            include_groups=False,
            meta=meta,
        )
        .compute()
        .reset_index()
        .drop(columns="level_1")
    )

    # Change types before export
    syn_tree_df = syn_tree_df.astype({config.tree_id_field: "int32"})

    # Set columns for export
    columns = [
        config.synthetic_id_field,
        config.tree_id_field,
        f"TPH_FC_K{config.k}_CENSORED",
    ]
    syn_tree_df.sort_values([config.synthetic_id_field])[columns].to_csv(
        config.output_filename, index=False, float_format="%.4f"
    )


@click.command()
@click.argument("config-key", type=str)
def main(config_key: str):
    import time

    start = time.time()
    _main(config_key)
    print(f"Elapsed time: {time.time() - start}")


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])
```
@grovduck, I managed to get things working without any issues locally. Haven't had a chance to give the code a close look yet, but I'll start poking around!
> My understanding was that you can run Dask locally using threads and it should "just work" based on this page. But when I run for the first 10,000 sample plots as shown in `config.py`, it takes almost 6 minutes to run (without using `LocalCluster`), whereas if I uncomment lines 1-3 for a `LocalCluster` that is running on my machine, it takes about 52 seconds to run.
Theoretically, you're seeing the performance difference between using threads and a local cluster. My experience is that the latter does tend to be faster, but 6 minutes for the threaded version seems excessive! I tested different scheduler options on my machine and got:

- `threads`: 122s
- `processes`: 119s
- `single-threaded`: 97s
- `Client()`: 47s

The fact that single-threaded outperforms multi-threaded and multi-processed seems odd - maybe that's indicative of something. Not sure why my threads run 3x faster than yours, though...
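In case it's useful for comparing on your machine, here's a minimal sketch of how the schedulers can be swapped per-compute. This uses a stand-in workload rather than the Gist's pipeline, so the numbers won't match the ones above:

```python
import time

import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask.distributed import Client


def make_workload() -> dd.Series:
    """A stand-in dask computation (not the Gist's pipeline)."""
    pdf = pd.DataFrame(
        {
            "FCID": np.random.randint(0, 1_000, 1_000_000),
            "TPH_FC": np.random.rand(1_000_000),
        }
    )
    return dd.from_pandas(pdf, npartitions=30).groupby("FCID").TPH_FC.sum()


def timed(result, **compute_kwargs) -> float:
    """Time a single compute() call with the given scheduler options."""
    start = time.time()
    result.compute(**compute_kwargs)
    return time.time() - start


if __name__ == "__main__":
    result = make_workload()

    # Local schedulers are selected per-compute (or globally via dask.config.set)
    for scheduler in ["threads", "processes", "single-threaded"]:
        print(scheduler, timed(result, scheduler=scheduler))

    # Once a Client exists, the distributed scheduler becomes the default,
    # which is what uncommenting the LocalCluster lines at the top of the script does.
    with Client():
        print("Client()", timed(result))
```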
@grovduck, I spent some time optimizing things at the Pandas level and managed to get the time down from ~45s with Dask to ~4s without Dask. The main change was avoiding the `apply` call to `weight_and_censor` by vectorizing things as much as possible. My understanding is that `apply` is inherently slow, and Dask can't do much to improve that.
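As a toy illustration of the pattern (not the actual branch code), here's the difference between pushing a per-group Python function through `groupby.apply` and doing the same weight normalization with vectorized column operations:

```python
import numpy as np
import pandas as pd

# Toy data: one weight per neighbor within each synthetic plot
rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "SYNTHETIC_PLOT_ID": rng.integers(0, 10_000, 100_000),
        "WEIGHT": rng.random(100_000),
    }
)


def normalize(group: pd.DataFrame) -> pd.DataFrame:
    """Normalize weights within one group (called once per group by apply)."""
    return group.assign(WEIGHT=group["WEIGHT"] / group["WEIGHT"].sum())


# Slow: one Python function call per group
applied = df.groupby("SYNTHETIC_PLOT_ID", group_keys=False).apply(normalize)

# Fast: one vectorized expression, with per-group sums broadcast by transform
vectorized = df.assign(
    WEIGHT=df["WEIGHT"] / df.groupby("SYNTHETIC_PLOT_ID")["WEIGHT"].transform("sum")
)

assert np.allclose(applied.sort_index()["WEIGHT"], vectorized["WEIGHT"])
```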
I don't think there's a way to make a PR against a Gist, so I made a private repo to track the changes (didn't push any data, just code!). There are ~~3~~ 4 branches:

- `main`: Your reference implementation that I used for comparison, with a few changes for repeatability[^1]. Runs in about 45s.
- `no-dask`: Same as `main` but without any Dask, just to keep things simple. Runs in about 76s.
- `no-apply`: Same as `no-dask` but vectorized. Runs in about 4s.
- `polars`: Re-implemented `no-apply` in Polars. Runs in about 0.8s.
The `no-apply` implementation is very rough around the edges, and was just a quick proof of concept. I tried to be very granular with the commits though, so it should hopefully be easy to track what I changed and why. Fair warning, I only tested with the `QUANTILE_K10_D1_R1` config, so it's totally possible that it will break with other configs. I didn't try adding Dask back in yet, but that would be good to test.
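If it helps with reproducing, the new config is just a d=1 twin of the existing quantile entry in `config.py`. This is my guess at the exact values (the `main` branch has the real thing):

```python
# Hypothetical reconstruction of the linear distance-weighted test config (d=1);
# everything except d and the output filename mirrors QUANTILE_K10_D0_R1.
from config import CONFIG_LOOKUP, Config, QUANTILE_NETWORK

QUANTILE_K10_D1_R1 = Config(
    tree_data_filename="test_tree_data.parquet",
    tree_columns=["LIVE_ID", "FCID", "SPP_SYMBOL", "DBH_CM", "TPH_FC"],
    network_files=QUANTILE_NETWORK,
    tree_id_field="LIVE_ID",
    plot_id_field="FCID",
    synthetic_id_field="SYNTHETIC_PLOT_ID",
    k=10,
    d=1,
    r=1.0,
    output_filename="quantile_k10_d1_r10.csv",
    min_synthetic_id=1,
    max_synthetic_id=10_000,
)
CONFIG_LOOKUP["quantile_k10_d1_r10"] = QUANTILE_K10_D1_R1
```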
If you want to try running these directly, I put the data files in a `./data` directory in the repo root, and ran `main` first to make the reference file.
p.s. I'm still curious to try this in Polars, which claims to be substantially faster than Dask. I may poke around with that...
[^1]: I made a couple of changes to your code to set up the testing, like fully sorting the output, not timing the CSV write, and using a new linear distance-weighted config just to make sure I didn't screw up the weighting.
@aazuspan, that is completely bonkers that you reduced the time so much! Thank you so much for looking at this. I'm just starting up the review of the estimator wrapper, but will see what you've done here and hopefully follow along!
@aazuspan, just got a chance to go through the code commit by commit. So cool. All of your incremental changes made really good sense, but I'm blown away that more or less taking all the logic that was in the `apply` out of the `apply` makes it this much faster. I don't think I've ever really used `groupby.transform` much, so it was cool to see how you did that, even on something like normalizing the weights.
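If I'm reading the approach right, the same trick could also replace the groupby-then-merge in `censor_by_threshold`. Here's an untested sketch (not the branch code) that assumes `DBH_GROUP` has already been added and takes the r-scaled thresholds as an argument:

```python
import pandas as pd


def censor_by_threshold_v2(
    df: pd.DataFrame, tph_field: str, r_thresholds: dict[int, float]
) -> pd.DataFrame:
    """Censor TPH like censor_by_threshold, but without a separate groupby + merge."""
    # Sum of TPH per species / DBH group, broadcast back onto every tree record
    group_tph = df.groupby(["SPP_SYMBOL", "DBH_GROUP"])[tph_field].transform("sum")
    meets = group_tph >= df["DBH_GROUP"].map(r_thresholds)
    # Keep a species if any of its DBH groups met the threshold
    include = meets.groupby(df["SPP_SYMBOL"]).transform("max")
    return df.assign(**{f"{tph_field}_CENSORED": df[tph_field].where(include, 0.0)})
```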
> If you want to try running these directly, I put the data files in a `./data` directory in the repo root, and ran `main` first to make the reference file.
I don't actually see this directory in the repo, but maybe it's better that we don't put it up there. I've just created a `data` directory locally and will continue to test with that.
@grovduck I had some down time while waiting for models to train, so I took a stab at re-writing this with Polars, which you can find on the `polars` branch. That turned out to be a pretty painless process considering I'd never used it before, and it brings the execution time down from 4s to 0.8s (EDIT: 0.7s with some optimizations!).

I'm not sure that's enough of a speedup to justify mixing or switching dataframe libraries, but it might be worth a closer look if the `pandas` code turned out to be a significant bottleneck.
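To give a flavor of the API, here's a simplified sketch in the spirit of the branch (toy data and column names, not the literal branch code). The per-plot weight normalization becomes a window expression inside `with_columns`:

```python
import polars as pl

# Toy neighbor/distance records for two synthetic plots
df = pl.DataFrame(
    {
        "SYNTHETIC_PLOT_ID": [1, 1, 1, 2, 2],
        "NEIGHBOR": [1, 2, 3, 1, 2],
        "DISTANCE": [0.1, 0.2, 0.4, 0.3, 0.3],
        "TPH_FC": [120.0, 80.0, 40.0, 200.0, 0.0],
    }
)

d = 1  # distance-weighting exponent
weighted = (
    df.with_columns(WEIGHT=(1.0 / pl.col("DISTANCE")) ** d)
    # Normalize weights within each synthetic plot via a window expression
    .with_columns(
        WEIGHT=pl.col("WEIGHT") / pl.col("WEIGHT").sum().over("SYNTHETIC_PLOT_ID")
    )
    # Weighted TPH, keeping only positive records
    .with_columns(TPH_WEIGHTED=pl.col("TPH_FC") * pl.col("WEIGHT"))
    .filter(pl.col("TPH_WEIGHTED") > 0.0)
)
print(weighted)
```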
@aazuspan, this is bonkers! So cool to see. The refactoring really didn't look too painful - `with_columns` definitely looks like a "go-to" method!

I tested on the full 100,000 synthetic plot IDs, and my timings going from the `no-apply` branch to the `polars` branch dropped from 41.9s to 6.2s, so between a 6x and 7x speedup. What are your thoughts on using `polars`? To me, the code looks fairly straightforward to understand and obviously relies quite a bit on chaining. But, as you say, that's also a big move to switch/mix the libraries, so it would be good to think through the ramifications.

We might be running this code hundreds of thousands of times with one of our projects, so that might factor into the decision as well. I think the next steps are to get `synthetic_knn.datasets` ready with the FIADB plots and then move this work over into `synthetic-knn`?

Thanks so much for pushing on this!
> What are your thoughts on using `polars`? To me, the code looks fairly straightforward to understand and obviously relies quite a bit on chaining. But, as you say, that's also a big move to switch/mix the libraries, so it would be good to think through the ramifications.
>
> We might be running this code hundreds of thousands of times with one of our projects, so that might factor into the decision as well.
My first impression of `polars` is really positive - nice API, great docs, and obviously the performance speaks for itself. It automatically does a lot of the same parallelization and task graph optimization as Dask, and has at least partial support (so far) for streaming larger-than-memory data. I think for an internal tool or end-to-end applications like this (i.e. files go in, files come out), I would 100% support switching, given how those performance improvements could scale as you mentioned.

Aside from the time we would need to spend learning a new package, my only reservation about switching to `polars` would be cases where users need to pass in dataframes or work with intermediate outputs, just because `pandas` is so ubiquitous that everyone's already familiar with it. It's easy to move between the two formats (with `pl.from_pandas` or `pl.DataFrame.to_pandas`), but it obviously adds a little complexity and overhead if we're switching back and forth.
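For what it's worth, the round trip itself is a one-liner in each direction (trivial example):

```python
import pandas as pd
import polars as pl

pdf = pd.DataFrame({"FCID": [1, 2, 3], "TPH_FC": [10.0, 0.0, 5.0]})

pldf = pl.from_pandas(pdf)  # pandas -> polars
back = pldf.to_pandas()     # polars -> pandas

assert back.equals(pdf)
```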
I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.
> I think next steps are to get `synthetic_knn.datasets` ready with the FIADB plots and then move this work over into `synthetic-knn`?

Sounds like a good plan!
> I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.

I can definitely see the use case for `polars` here, if for no other reason than speed. I haven't yet hit you with the second part of this, which is to take the tree-level data and calculate stand-level attributes from it, but I'm guessing it will similarly benefit from `polars`. For now, I think `synthetic-knn` will mostly be an internal package, but perhaps it will be worth bringing in the crosswalking to `pandas` at some later date.

For now, I think we can sit on this until I get the `synthetic_knn.datasets` PR in place.
@aazuspan

A week ago Friday, we touched on the `synthetic-knn` work, and I thought you'd likely be able to help me with the Dask DataFrame work. As of right now, I have this standalone script that hasn't been incorporated into the `synthetic-knn` repo, but just exists as this Gist.

It's still very rough, but I'd love some feedback on it, especially around optimization with `dask`. Chunking / npartitions are still a bit of a mystery to me - for now, I've set `npartitions` equal to 30 (there are 36 available on my machine) on line 163, where I repartition the data before applying the weighting and censoring to the treelists.
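For reference, the call in question is the one-liner below. I gather there's also a size-based alternative where Dask targets a number of bytes per partition instead of a fixed count, though I haven't tried it (untested sketch with toy data):

```python
import dask.dataframe as dd
import pandas as pd

# Toy stand-in for the synthetic-plot/tree dataframe in the script
ddf = dd.from_pandas(
    pd.DataFrame({"FCID": range(1_000), "TPH_FC": [1.0] * 1_000}), npartitions=4
)

fixed = ddf.repartition(npartitions=30)  # what the script currently does
sized = ddf.repartition(partition_size="100MB")  # "100MB" is just a placeholder target

print(fixed.npartitions, sized.npartitions)
```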
There are a few data files on `M:` that are needed to run this:

- Tree file for all FIA annual trees in parquet format: `M:/research/synthetic_plots/aa/test_tree_data.parquet`
- Neighbor file for the quantile mesh that we are experimenting with (n=100,000 neighbors): `M:/research/synthetic_plots/aa/quantile_mesh/quantile_mesh_k25_neighbors.csv`
- Distance file for the same: `M:/research/synthetic_plots/aa/quantile_mesh/quantile_mesh_k25_distances.csv`
The `config.py` file above is a hacky, hard-wired configuration while I'm initially testing. You should be able to set `ROOT` to whatever your version of `M:/research/synthetic_plots/aa` will be.

One thing that I'm not understanding: my understanding was that you can run Dask locally using threads and it should "just work" based on this page. But when I run for the first 10,000 sample plots as shown in `config.py`, it takes almost 6 minutes to run (without using `LocalCluster`), whereas if I uncomment lines 1-3 for a `LocalCluster` that is running on my machine, it takes about 52 seconds to run.

Please let me know if you need more information - admittedly this is a bit crude right now.