from dataclasses import dataclass
from pathlib import Path


@dataclass
class NetworkFiles:
    neighbor_filename: Path
    distance_filename: Path


@dataclass
class Config:
    tree_data_filename: str
    tree_columns: list[str]
    network_files: NetworkFiles
    tree_id_field: str
    plot_id_field: str
    synthetic_id_field: str
    k: int
    d: int
    r: float
    output_filename: str
    min_synthetic_id: int | None = None
    max_synthetic_id: int | None = None

ROOT = Path("M:/research/synthetic_plots/aa")

QUANTILE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "quantile_mesh/quantile_mesh_k25_neighbors.csv",
    distance_filename=ROOT / "quantile_mesh/quantile_mesh_k25_distances.csv",
)
REFERENCE_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "reference_network/reference_network_k25_neighbors.csv",
    distance_filename=ROOT / "reference_network/reference_network_k25_distances.csv",
)
EQUAL_INTERVAL_NETWORK = NetworkFiles(
    neighbor_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_neighbors.csv",
    distance_filename=ROOT
    / "equal_interval_mesh/equal_interval_mesh_k25_distances.csv",
)
FUZZED_NETWORK = NetworkFiles(
    neighbor_filename=ROOT / "fuzzed_network/fuzzed_network_k25_neighbors.csv",
    distance_filename=ROOT / "fuzzed_network/fuzzed_network_k25_distances.csv",
)

QUANTILE_K10_D0_R1 = Config(
    tree_data_filename="test_tree_data.parquet",
    tree_columns=["LIVE_ID", "FCID", "SPP_SYMBOL", "DBH_CM", "TPH_FC"],
    network_files=QUANTILE_NETWORK,
    tree_id_field="LIVE_ID",
    plot_id_field="FCID",
    synthetic_id_field="SYNTHETIC_PLOT_ID",
    k=10,
    d=0,
    r=1.0,
    output_filename="quantile_k10_d0_r10.csv",
    min_synthetic_id=1,
    max_synthetic_id=10_000,
)

CONFIG_LOOKUP = {
    "quantile_k10_d0_r10": QUANTILE_K10_D0_R1,
}

# from dask.distributed import Client
# client = Client("tcp://127.0.0.1:60183")
import math

import click
import dask.dataframe as dd
import numpy as np
import pandas as pd

from config import CONFIG_LOOKUP, Config, NetworkFiles

# Trees-per-hectare thresholds for censoring trees, corresponding to the
# radii (in feet) of the microplot, subplot, and macroplot, with the area
# summed across all four subplots.
THRESHOLDS = {
    0: 1.0 / (4.0 * math.pi * ((6.8 * 0.3049) ** 2) / 10000),
    1: 1.0 / (4.0 * math.pi * ((24.0 * 0.3049) ** 2) / 10000),
    2: 1.0 / (4.0 * math.pi * ((58.9 * 0.3049) ** 2) / 10000),
}

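# Each threshold works out to one tree spread over the combined area of four
# circles of the given radius: roughly 185 TPH for the microplot-sized group
# (6.8 ft), ~14.9 TPH for the subplot-sized group (24.0 ft), and ~2.5 TPH for
# the macroplot-sized group (58.9 ft).
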
def get_synthetic_network_data(
    network: NetworkFiles,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read in neighbors and distances from the synthetic network."""
    neighbor_df = pd.read_csv(network.neighbor_filename)
    distance_df = pd.read_csv(network.distance_filename)
    return neighbor_df, distance_df


def melt_and_merge_network(
    neighbor_df: pd.DataFrame,
    distance_df: pd.DataFrame,
    synthetic_id_field: str,
    plot_id_field: str,
) -> dd.DataFrame:
    """Melt synthetic neighbors and distances and merge together."""

    def _melt(df: pd.DataFrame, value_name: str) -> pd.DataFrame:
        melted_df = df.melt(
            id_vars=[synthetic_id_field], var_name="NEIGHBOR", value_name=value_name
        )
        melted_df["NEIGHBOR"] = melted_df["NEIGHBOR"].str.replace("NN", "").astype(int)
        return melted_df

    neighbor_df = _melt(neighbor_df, value_name=plot_id_field)
    distance_df = _melt(distance_df, value_name="DISTANCE")
    return dd.from_pandas(
        neighbor_df.merge(distance_df, on=[synthetic_id_field, "NEIGHBOR"])
    )

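# Sketch of the reshape above: the wide neighbor file has one row per synthetic
# plot with columns NN1..NN25; after melting and merging with the distances we
# get one row per (synthetic plot, neighbor) pair, e.g.
#
#   SYNTHETIC_PLOT_ID  NN1  NN2  ...      SYNTHETIC_PLOT_ID  NEIGHBOR  FCID  DISTANCE
#   1                  512  37   ...  ->  1                  1         512   0.12
#                                         1                  2         37    0.18
#
# (IDs and distances here are made up purely for illustration.)
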
def get_weights_df(df: dd.DataFrame, k: int, d: int) -> dd.DataFrame:
    """Return neighbor weights based on gradient distance and weighting scheme (d)."""
    weights = df.groupby("NEIGHBOR", as_index=False).DISTANCE.first().head(k)
    weights["WEIGHT"] = (1.0 / weights.DISTANCE) ** d
    weights["WEIGHT"] = weights.WEIGHT / weights.WEIGHT.sum()
    return weights

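# The weights are inverse-distance weights normalized to sum to 1: with d=0
# every one of the k neighbors gets an equal weight of 1/k, while larger d
# values (e.g. d=2) increasingly favor the nearest neighbors.
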
def censor_by_threshold(df: dd.DataFrame, tph_field: str, r: float) -> dd.DataFrame:
    """Censor the tree data based on the TPH thresholds."""
    # TODO: Calculating each time - could be more efficient
    r_thresholds = {k: v * r for k, v in THRESHOLDS.items()}
    # Bin the DBH values into groups based on the microplot, subplot, and
    # macroplot diameter thresholds (group 0: DBH < 12.7 cm,
    # group 1: 12.7 <= DBH < 54.0 cm, group 2: DBH >= 54.0 cm)
    df["DBH_GROUP"] = np.digitize(df.DBH_CM, [12.7, 54.0])
    # Calculate the sum of TPH for each species and DBH group
    grouped = (
        df.groupby(["SPP_SYMBOL", "DBH_GROUP"], as_index=False)[tph_field]
        .agg("sum")
        .rename(columns={tph_field: "SUM_TPH"})
    )
    # Add a new field to indicate whether the species / DBH group met the
    # censoring threshold
    grouped["INCLUDE"] = np.where(
        grouped.SUM_TPH >= grouped.DBH_GROUP.map(r_thresholds), 1, 0
    )
    # Identify all species where at least one DBH group met the threshold
    species = grouped.groupby("SPP_SYMBOL", as_index=False).INCLUDE.agg("max")
    df = df.merge(species, on=["SPP_SYMBOL"])
    # Zero out the TPH associated with species that did not meet the threshold
    df[f"{tph_field}_CENSORED"] = np.where(df.INCLUDE == 1, df[tph_field], 0.0)
    return df

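# For example, with r=1.0 a species whose subplot-sized trees (group 1) sum to
# 20 TPH clears the ~14.9 TPH threshold, so all of that species' records are
# retained; a species that falls below the threshold in every DBH group has its
# TPH zeroed out and is dropped downstream in weight_and_censor.
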
def weight_and_censor(df: dd.DataFrame, k: int, d: float, r: float) -> dd.DataFrame:
    """Weight and censor the tree data based on the k, d, and r parameters."""
    # Assign weights to records based on the k and d parameters
    weight_df = get_weights_df(df, k, d)
    df = df.merge(weight_df[["NEIGHBOR", "WEIGHT"]], on="NEIGHBOR")
    # For plot neighbors that do not have treelists, fill in null values
    #
    # TODO: We're treating SPP_SYMBOL separately as the only string column.
    # This needs to be generalized
    df["SPP_SYMBOL"] = df["SPP_SYMBOL"].fillna("")
    df = df.fillna(0)
    # Calculate the modified TPH based on the weight and only retain
    # records with a positive TPH. Note that the original network lists
    # may contain neighbors all the way up to k=25; neighbors beyond k are
    # dropped by the inner merge with the weights above.
    df[f"TPH_FC_K{k}"] = df.TPH_FC * df.WEIGHT
    df = df[df[f"TPH_FC_K{k}"] > 0.0]
    df = df.drop(columns="WEIGHT")
    # Using the modified TPH, censor the data based on the r parameter
    df = censor_by_threshold(df, f"TPH_FC_K{k}", r)
    # Return only the records that are above the censoring threshold
    return df[df[f"TPH_FC_K{k}_CENSORED"] > 0.0].drop(columns=["DBH_GROUP", "INCLUDE"])

def _main(config_key: str) -> None:
    # Set the configuration for this run
    config: Config = CONFIG_LOOKUP[config_key]
    # Read in the tree data and subset to just the columns we need.
    # Note that all plots that might be candidate neighbors are represented
    # in this list, including plots with no tree records. In this case,
    # all fields except the plot_id_field will be null.
    tree_df = dd.read_parquet(config.tree_data_filename)
    drop_columns = list(set(tree_df.columns) - set(config.tree_columns))
    tree_df = tree_df.drop(columns=drop_columns)
    # Read in neighbors and distances from the synthetic network
    syn_nn_df, syn_dist_df = get_synthetic_network_data(config.network_files)
    # Subset the data if min_synthetic_id or max_synthetic_id are set
    if config.min_synthetic_id or config.max_synthetic_id:
        min_synthetic_id = config.min_synthetic_id or 1
        max_synthetic_id = (
            config.max_synthetic_id or syn_nn_df[config.synthetic_id_field].max()
        )

        def subset_df(df: pd.DataFrame) -> pd.DataFrame:
            conds = (df[config.synthetic_id_field] >= min_synthetic_id) & (
                df[config.synthetic_id_field] <= max_synthetic_id
            )
            return df[conds]

        syn_nn_df = subset_df(syn_nn_df)
        syn_dist_df = subset_df(syn_dist_df)
    # Melt and merge the neighbors and distances
    syn_df = melt_and_merge_network(
        syn_nn_df, syn_dist_df, config.synthetic_id_field, config.plot_id_field
    )
    # Join the synthetic plot data and the tree data
    syn_tree_df = syn_df.merge(tree_df, on=config.plot_id_field)
    syn_tree_df = syn_tree_df.repartition(npartitions=30)
    # Describe the data types for the apply function
    meta = {
        "NEIGHBOR": "int64",
        config.plot_id_field: "int64",
        "DISTANCE": "float64",
        config.tree_id_field: "int64",
        "SPP_SYMBOL": "string",
        "DBH_CM": "float64",
        "TPH_FC": "float64",
        f"TPH_FC_K{config.k}": "float64",
        f"TPH_FC_K{config.k}_CENSORED": "float64",
    }
    # Apply the k, d, and r parameters to the tree data
    syn_tree_df = (
        syn_tree_df.groupby(config.synthetic_id_field)
        .apply(
            weight_and_censor,
            k=config.k,
            d=config.d,
            r=config.r,
            include_groups=False,
            meta=meta,
        )
        .compute()
        .reset_index()
        .drop(columns="level_1")
    )
    # Change types before export
    syn_tree_df = syn_tree_df.astype({config.tree_id_field: "int32"})
    # Set columns for export
    columns = [
        config.synthetic_id_field,
        config.tree_id_field,
        f"TPH_FC_K{config.k}_CENSORED",
    ]
    syn_tree_df.sort_values([config.synthetic_id_field])[columns].to_csv(
        config.output_filename, index=False, float_format="%.4f"
    )

@click.command()
@click.argument("config-key", type=str)
def main(config_key: str):
    import time

    start = time.time()
    _main(config_key)
    print(f"Elapsed time: {time.time() - start}")


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])

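# Example invocation (a sketch; the gist does not name this file, so assume it
# is saved as run_weighting.py alongside config.py):
#
#   python run_weighting.py quantile_k10_d0_r10
#
# which writes quantile_k10_d0_r10.csv as configured above.
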
What are your thoughts on using `polars`? To me, the code looks fairly straightforward to understand and obviously relies quite a bit on chaining. But, as you say, that's also a big move to switch/mix the libraries, so it would be good to think through the ramifications.

We might be running this code hundreds of thousands of times with one of our projects, so that might factor into the decision as well.
My first impression of `polars` is really positive - nice API, great docs, and obviously the performance speaks for itself. It automatically does a lot of the same parallelization and task graph optimization as Dask, and has at least partial support (so far) for streaming larger-than-memory data. I think for an internal tool or end-to-end applications like this (i.e. files go in, files come out), I would 100% support switching, given how those performance improvements will potentially scale as you mentioned.

Aside from the time we would need to spend learning a new package, my only reservation about switching to `polars` would be cases where users would need to pass in dataframes or work with intermediate outputs, just because `pandas` is so ubiquitous that everyone's already familiar with it. It's easy to move between the two formats (with `pl.from_pandas` or `pl.DataFrame.to_pandas`), but it obviously adds a little complexity and overhead if we're switching back and forth.
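For concreteness, the round trip is a one-liner in each direction (a minimal sketch, with a throwaway frame just for illustration):

```python
import pandas as pd
import polars as pl

pd_df = pd.DataFrame({"FCID": [1, 2], "TPH_FC": [100.0, 250.0]})

pl_df = pl.from_pandas(pd_df)  # pandas -> polars
back = pl_df.to_pandas()       # polars -> pandas
```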
I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.
I think next steps are to get `synthetic_knn.datasets` ready with the FIADB plots and then move this work over into `synthetic-knn`?
Sounds like a good plan!
> I guess I lean towards `polars`, especially if this is primarily going to be used by us, but I'm not opposed to sticking with `pandas` either, since there are some strong arguments there in terms of compatibility.
I can definitely see the use case for `polars` here if for no other reason than speed. I haven't yet hit you with the second part of this, which is to take the tree-level data and calculate stand-level attributes from it, but I'm guessing it will similarly benefit from `polars`. For now, I think `synthetic-knn` will be mostly an internal package, but perhaps it will be worth bringing in the crosswalking to `pandas` at some later date.

For now, I think we can sit on this until I get the `synthetic_knn.datasets` PR in place.
@aazuspan, this is bonkers! So cool to see. The refactoring really didn't look too painful - `with_columns` definitely looks like a "go-to" method!

I tested on the full 100,000 synthetic plot IDs and my timings from the `no-apply` branch to the `polars` branch went from 41.9s to 6.2s, so between a 6x and 7x speedup.

Thanks so much for pushing on this!
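For readers who haven't used `polars`, here is a rough, hypothetical sketch (not the actual code from the `polars` branch, which isn't shown here) of how the neighbor-weighting step in `get_weights_df` above might look with `with_columns`-style chaining:

```python
import polars as pl


def get_weights_df(df: pl.DataFrame, k: int, d: int) -> pl.DataFrame:
    # One row per neighbor with its (first) distance, limited to the k nearest
    return (
        df.group_by("NEIGHBOR")
        .agg(pl.col("DISTANCE").first())
        .sort("NEIGHBOR")
        .head(k)
        # Inverse-distance weights, normalized to sum to 1
        .with_columns(((1.0 / pl.col("DISTANCE")) ** d).alias("WEIGHT"))
        .with_columns(pl.col("WEIGHT") / pl.col("WEIGHT").sum())
    )
```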