Last active
June 1, 2023 00:41
-
-
Save aulemahal/4873db65992369420a96834026b33470 to your computer and use it in GitHub Desktop.
Spatial intersection of an xarray-defined grid and shapely polygons, using pyproj to convert to an Equal Earth projection for more accurate areas. Includes tools to compute the grid corners and rotated-pole projections.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Spatial intersection and average
# Compute weights corresponding to the intersection between grid cells and a polygon.
# Pascal Bourgault, Sept 2020
from cartopy import crs | |
from shapely.geometry import Polygon | |
from shapely.ops import transform | |
import xarray as xr | |
import numpy as np | |
from functools import partial | |
import pyproj | |
def add_bounds(da, xdim='lon', ydim='lat', edge='reflect'):
    """Add bounds coordinates to a dataset. Only regular grid supported (1D coordinates).

    Parameters
    ----------
    da : xr.Dataset
        Xarray dataset defining grid centers.
    xdim, ydim : str
        Name of the two coordinates.
    edge : {'reflect', 'mean', float}
        How the points at the edge are extrapolated.
        If 'reflect', the boundary grid steps are the same as their immediate neighbors.
        If 'mean', the boundary grid steps are the mean of all grid steps.
        If a number, it is used as the boundary grid step.

    Returns
    -------
    xr.Dataset
        A copy of da with new grid corners coordinates, they have the same name as the center
        coordinates with a '_b' suffix. Ex: lat_b and lon_b. Each has one more element than
        the corresponding center coordinate.

    Raises
    ------
    ValueError
        If a coordinate has fewer than 2 points and `edge` is 'reflect' or 'mean'
        (no grid step can be inferred).
    """
    xdimb = xdim + '_b'
    ydimb = ydim + '_b'
    return da.assign(**{
        xdimb: ((xdimb,), _edges_from_centers(da[xdim].values, edge)),
        ydimb: ((ydimb,), _edges_from_centers(da[ydim].values, edge)),
    })


def _edges_from_centers(centers, edge='reflect'):
    """Return the n + 1 cell edges surrounding n 1D cell centers (`edge` as in `add_bounds`)."""
    centers = np.asarray(centers)
    if centers.size < 2 and isinstance(edge, str):
        raise ValueError(
            f"Cannot infer a grid step from {centers.size} point(s) with edge={edge!r}; "
            "pass a numeric edge step instead."
        )
    steps = np.diff(centers)
    if edge == 'reflect':
        step_left, step_right = steps[0], steps[-1]
    elif edge == 'mean':
        step_left = step_right = steps.mean()
    else:
        step_left = step_right = edge
    # Interior edges sit halfway between consecutive centers; the outer edges sit half a
    # boundary step beyond the first and last centers (both sides use step / 2).
    return np.concatenate((
        [centers[0] - step_left / 2],
        centers[:-1] + steps / 2,
        [centers[-1] + step_right / 2],
    ))
def get_proj_rotated_pole(da):
    """Build a pyproj projection from the `rotated_pole` variable of a dataset.

    The `grid_north_pole_longitude` and `grid_north_pole_latitude` attributes of
    `da.rotated_pole` are given to cartopy's `RotatedPole`, whose proj4 parameters
    are then used to construct the returned `pyproj.Proj`.
    """
    pole_attrs = da.rotated_pole
    rotated_pole_crs = crs.RotatedPole(
        pole_longitude=pole_attrs.grid_north_pole_longitude,
        pole_latitude=pole_attrs.grid_north_pole_latitude,
    )
    proj4_params = rotated_pole_crs.proj4_params
    return pyproj.Proj(proj4_params)
def subset_box(da, poly, pad=1):
    """Subset an xarray object to the smallest box including the polygon(s).

    The polygon(s) are assumed to be defined in lon/lat, and `da` must have `lon` and
    `lat` variables (either 1D coordinates or 2D fields). A mask is first constructed for
    all grid centers falling within the bounds of the polygon(s). It is then grown by
    `pad` cells along each grid dimension, so that cells only partially covered are included.

    Parameters
    ----------
    da : xr.Dataset or xr.DataArray
        Object holding `lon` and `lat`.
    poly : shapely geometry, geopandas.GeoSeries or geopandas.GeoDataFrame
        Read through its `total_bounds` attribute if present, otherwise `bounds`.
    pad : int
        Number of grid cells added around the mask along each dimension.

    Returns
    -------
    Same type as `da`
        `da.where(mask, drop=True)`: reduced to the extent of the padded mask, with values
        inside that extent but outside the mask set to missing by `where`.

    Raises
    ------
    ValueError
        If no grid center falls within the bounds of the polygon(s).
    """
    if da.lon.ndim == 1:
        # 1D lon/lat coordinates: the grid dimensions are their own dimensions.
        dims = [da.lon.dims[0], da.lat.dims[0]]
    else:
        dims = list(da.lon.dims)

    if hasattr(poly, 'total_bounds'):
        min_lon, min_lat, max_lon, max_lat = poly.total_bounds
    else:
        min_lon, min_lat, max_lon, max_lat = poly.bounds

    mask = (da.lon >= min_lon) & (da.lon <= max_lon) & (da.lat >= min_lat) & (da.lat <= max_lat)
    if not mask.any():
        raise ValueError(('The returned mask is empty. Either the polygons do not overlap with '
                          'the grid or they are too small. In the latter case, try adding a '
                          'buffer: subset_box(da, poly.buffer(0.5)).'))

    # A centered moving maximum of width 2 * pad + 1 grows the mask by `pad` cells per side.
    for dim in dims:
        mask = mask.rolling(**{dim: 2 * pad + 1}, min_periods=1, center=True).max()
    # The rolling reduction does not return a boolean array; `where` expects a boolean
    # condition, so convert explicitly.
    return da.where(mask > 0, drop=True)
def compute_area_weights(da, polys, dims=('lon', 'lat'), grid_crs=None, mode='fracpoly', cartesian=True):
    """Compute the area weights of each polygon on each gridcell.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        A xarray object defining the centers (and optionally the corners) of a grid.
    polys : geopandas.GeoSeries or geopandas.GeoDataFrame
        A Series of Polygons, with a defined crs.
    dims : Sequence of str
        The names of the two 1D coordinates defining the grid centers in da; the corners
        are then computed with `add_bounds(..., edge='reflect')`.
        If 4 names are passed, the last two are used as the (existing, 1D) grid corners.
    grid_crs : pyproj.Proj, pyproj.CRS or any input accepted by pyproj, optional
        The projection of the grid in da. Defaults to "epsg:4326".
    mode : {'fracpoly', 'fraccell', 'area'}
        The type of output.
        fracpoly returns the fraction of polygon covered by each grid cell,
        fraccell returns the fraction of each grid cell covered by the polygon,
        area returns the area of the intersection between the grid cell and the polygon.
    cartesian : bool
        If True, the areas are computed in 'm' using the Equal Earth Greenwich projection
        (EPSG:8857); if False, the crs of the grid is used. The output's `units` attribute
        is 'm^2' only when mode='area' and cartesian is True, and '' otherwise.

    Returns
    -------
    xr.DataArray
        The weights defined along dims[0] and dims[1], according to method mode, for each polygon.
        The first dimension ('poly') is the same index as in polys.

    Raises
    ------
    ValueError
        If `mode` is not one of 'fracpoly', 'fraccell' or 'area'.
    """
    descriptions = {
        'fracpoly': 'The fraction of the polygon that is covered by each grid cell.',
        'fraccell': 'The fraction of the gridcell that covers the polygon.',
        'area': 'The area of intersection between the gridcell and the polygon.'
    }
    # Validate before any (potentially long) computation.
    if mode not in descriptions:
        raise ValueError(f'mode must be one of fracpoly, fraccell or area, got {mode}.')

    if len(dims) == 2:
        xdim, ydim = dims
        xdimb, ydimb = xdim + '_b', ydim + '_b'
        da = add_bounds(da, xdim=xdim, ydim=ydim, edge='reflect')
    else:
        xdim, ydim, xdimb, ydimb = dims

    # A pyproj.Proj exposes its CRS as `.crs`; other inputs (CRS, EPSG code, string) pass through.
    src_crs = pyproj.CRS.from_epsg(4326) if grid_crs is None else getattr(grid_crs, 'crs', grid_crs)

    if cartesian:
        # Built once; always_xy because grid coordinates are given as (x/lon, y/lat).
        project = pyproj.Transformer.from_crs(src_crs, 'EPSG:8857', always_xy=True).transform
        polys = polys.to_crs(epsg=8857)
    else:
        project = None
        polys = polys.to_crs(src_crs)

    xs = da[xdimb].values
    ys = da[ydimb].values
    nx, ny = da[xdim].size, da[ydim].size
    # Grid cells do not depend on the polygon: build (and project) them only once.
    cells = [[_cell_polygon(xs, ys, i, j, project) for j in range(ny)] for i in range(nx)]

    weights = np.empty((polys.shape[0], nx, ny), dtype=float)
    for k, poly in enumerate(polys.geometry):
        poly_area = poly.area
        for i in range(nx):
            for j in range(ny):
                cell = cells[i][j]
                value = cell.intersection(poly).area
                if mode == 'fracpoly':
                    value = value / poly_area
                elif mode == 'fraccell':
                    value = value / cell.area
                weights[k, i, j] = value

    # Keep the coordinates that live only on the grid dimensions (including scalar coords).
    coords = {crd: crdda
              for crd, crdda in da.coords.items()
              if all(dim in (xdim, ydim) for dim in crdda.dims)}
    coords['poly'] = polys.index
    return xr.DataArray(
        weights,
        coords=coords,
        dims=('poly', xdim, ydim),
        name='weights',
        attrs={'units': 'm^2' if mode == 'area' and cartesian else '',
               'description': descriptions[mode]}
    )


def _cell_polygon(xs, ys, i, j, project=None):
    """Return grid cell (i, j) as a polygon built from 1D corner arrays, optionally reprojected."""
    cell = Polygon(
        [(xs[i], ys[j]),
         (xs[i], ys[j + 1]),
         (xs[i + 1], ys[j + 1]),
         (xs[i + 1], ys[j]),
         (xs[i], ys[j])]
    )
    if project is not None:
        cell = transform(project, cell)
    return cell
if __name__ == '__main__':
    # Usage example; the file paths below are placeholders.
    import geopandas as gpd

    ds = xr.open_dataset('a dataset defining a grid')
    df = gpd.read_file('a list of polygons')

    # Simple case: large grid, polygons concentrated on a region smaller than the whole dataset.
    ds_sub = subset_box(ds, df.to_crs(epsg=4326))  # Accelerates the computation
    w = compute_area_weights(ds_sub, df, dims=['lon', 'lat'], cartesian=True)

    # Rotated pole case: ds is in rlon, rlat and has a "rotated_pole" variable.
    rp_crs = get_proj_rotated_pole(ds)
    w = compute_area_weights(ds_sub, df, dims=['rlon', 'rlat'], grid_crs=rp_crs, cartesian=True)

    # Small poly case: polygons are too small, they all fall in one grid cell.
    # Regular Plate Carree dataset.
    ds_sub = subset_box(ds, df.to_crs(epsg=4326).buffer(1))  # Add buffer to subset a larger box
    w = compute_area_weights(ds_sub, df, dims=['lon', 'lat'], cartesian=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment