Skip to content

Instantly share code, notes, and snippets.

@sixy6e
Last active May 16, 2024 23:39
Show Gist options
  • Save sixy6e/c963d98391085dff932b1d9ab51b0e4c to your computer and use it in GitHub Desktop.
Save sixy6e/c963d98391085dff932b1d9ab51b0e4c to your computer and use it in GitHub Desktop.
blockmedian
A toy script to replicate (somewhat) the functionality of the
[GMT Blockmedian](https://docs.generic-mapping-tools.org/latest/blockmedian.html) utility.
The main difference is that a grid is pre-determined, which avoids the grid edge "buffering" GMT does.
"""
A toy prototype that mimics the GMT blockmedian utility.
Rationale:
* GMT, currently, doesn't support TileDB for reading
* GMT defines a grid spec based on the input data, which often results in extent buffering
* Potential to parallelise the actual processing
* Simplify the utility and better streamlined workflows
* Potential to establish as a TileDB UDF and share this functionality (and eventually the entire workflow, that this blockmedian is just one piece of)
* Efficiency gains
* Moving from CSV -> TileDB
* Compression
* I/O speed
* Parallel I/O for distributed processing
The sample data consisted of a CSV, of X,Y,Z with horizontal and vertical uncertainty.
Source data was from MBES, but no idea from what sensor or survey.
"""
import numpy
import rasterio
import tiledb
import click
import structlog
LOG = structlog.get_logger()
# Array URIs; fill these in before running the script.
DSM_URI = ""  # raster DSM whose block/grid layout drives the decimation cells
PC_URI = ""  # input point-cloud TileDB array (sparse X/Y with Z + TPU attributes)
OUT_PC_URI = ""  # output decimated point-cloud TileDB array (recreated each run)
def lonlat_domain():
    """Build the sparse array domain over the X/Y (lon/lat) dimensions."""
    # Both coordinate dimensions share the same compression filter.
    coord_filters = tiledb.FilterList([tiledb.ZstdFilter(level=16)])
    dims = [
        tiledb.Dim(
            name,
            domain=(None, None),
            tile=1000,
            dtype=numpy.float64,
            filters=coord_filters,
        )
        for name in ("X", "Y")
    ]
    return tiledb.Domain(*dims)
def xyz_schema(ctx=None):
    """Data used for testing only had XYZ, and horizontal/vertical uncertainty"""
    attr_filters = [tiledb.ZstdFilter(level=16)]
    # One float32 attribute per non-coordinate column in the source data.
    attributes = [
        tiledb.Attr(name, dtype=numpy.float32, filters=attr_filters)
        for name in ("Z", "POSITION_TPU", "VERTICAL_TPU")
    ]
    return tiledb.ArraySchema(
        domain=lonlat_domain(),  # only 2 dims for the GMRT project
        sparse=True,
        attrs=attributes,
        cell_order="hilbert",
        tile_order="row-major",
        capacity=100_000,
        allows_duplicates=True,
        ctx=ctx,
    )
def append_ping_dataframe(dataframe, array_uri, ctx=None):
    """
    Append the ping dataframe read from a GSF file.
    Only to be used with sparse arrays.
    """
    tiledb.dataframe_.from_pandas(
        array_uri,
        dataframe,
        mode="append",
        sparse=True,
        ctx=ctx,
    )
def _median(dataframe, transform):
xcol, ycol = ~transform * (dataframe.X, dataframe.Y)
dataframe["X_Col"] = xcol.astype("int64") # int64 might be overkill
dataframe["Y_Col"] = ycol.astype("int64") # int64 might be overkill
dataframe.set_index(["X_Col", "Y_Col"], inplace=True)
decimated = dataframe.groupby(level=[0, 1]).median()
return decimated
def block_median(ras_ds, pc_ds, out_uri, ctx, attrs):
    """
    Decimate a point cloud to one median point per raster cell, block by block.

    :param ras_ds: opened rasterio dataset whose grid defines the cells
    :param pc_ds: opened TileDB sparse point-cloud array
    :param out_uri: URI of the TileDB array the decimated points are appended to
    :param ctx: TileDB context used for the appends
    :param attrs: list of attribute names to read, or None for all attributes
    :return: tuple of (total_points read, remaining_points after decimation)
    """
    total_points = 0
    remaining_points = 0
    for _, window in ras_ds.block_windows(1):
        bounds = ras_ds.window_bounds(window)  # (xmin, ymin, xmax, ymax)
        xstart, ystart = bounds[0], bounds[1]
        xend, yend = bounds[2], bounds[3]
        # just use full dataset transform for time being
        # the window transform could benefit from lower memory footprint
        # where max X_Col and max Y_Col is defined by the window size
        # transform = ras_ds.window_transform(window)
        # potential here to hit memory limits
        # eg large window, and lots of points
        # for production it might be better to check for incomplete query
        # and resize the block window (recurse to smaller again if still fails)
        df = pc_ds.query(attrs=attrs).df[xstart:xend, ystart:yend]
        # nothing to process
        if len(df) == 0:
            continue
        total_points += len(df)
        # unfortunately we have side effects in the function
        # it's only 6 lines, but means we can test the specific func
        decimated = _median(df, ras_ds.transform)
        remaining_points += len(decimated)
        # serialise to decimated pc
        # fix: append to the out_uri parameter; previously this used the
        # module-global OUT_PC_URI, silently ignoring the caller's argument
        append_ping_dataframe(decimated, out_uri, ctx)
    return total_points, remaining_points
@click.command()
@click.option(
    "--tiledb-config",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
    required=True,
    help="Path to TileDB configuration document.",
)
def main(tiledb_config):
    # Entry point: recreate the output array, then run the block-median
    # decimation over the whole point cloud.
    cfg = tiledb.Config.load(tiledb_config)
    ctx = tiledb.Ctx(cfg)

    # Start from a clean slate on every run.
    vfs = tiledb.vfs.VFS(ctx=ctx)
    if vfs.is_dir(OUT_PC_URI):
        LOG.info("Removing previous array", array_uri=OUT_PC_URI)
        vfs.remove_dir(OUT_PC_URI)

    LOG.info("Creating array", array_uri=OUT_PC_URI)
    tiledb.Array.create(OUT_PC_URI, xyz_schema(ctx))

    with rasterio.open(DSM_URI, tiledb_config=tiledb_config) as ras_ds:
        with tiledb.open(PC_URI, ctx=ctx) as pc_ds:
            # attrs = ["Z"]
            attrs = None  # all attrs
            LOG.info("Processing block median")
            n_in, n_out = block_median(ras_ds, pc_ds, OUT_PC_URI, ctx, attrs)
            LOG.info(
                "Finished processing block median",
                input_points=f"{n_in:_}",
                output_points=f"{n_out:_}",
            )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment