Skip to content

Instantly share code, notes, and snippets.

@GMoncrieff
Created October 6, 2022 15:03
Show Gist options
  • Save GMoncrieff/624fc44be9adfb06fa4b6b90938ebbe8 to your computer and use it in GitHub Desktop.
Save GMoncrieff/624fc44be9adfb06fa4b6b90938ebbe8 to your computer and use it in GitHub Desktop.
Utility functions for data extraction in multidimensional arrays with non-rectilinear grids
import xoak
import xarray as xr
import numpy as np
import pandas as pd
def select_points(data: xr.Dataset,
                  xc: str, yc: str,
                  xdat: np.array, ydat: np.array,
                  name: str = "points") -> xr.Dataset:
    """Look up a set of query points in an xarray dataset via xoak.

    A xoak ``'sklearn_geo_balltree'`` index is set on the coordinates
    ``xc`` and ``yc`` of ``data``; the query points are then selected
    through that index with ``data.xoak.sel``.

    Args:
        data (xr.Dataset): dataset to select from; it must carry the
            coordinates named by ``xc`` and ``yc``.
        xc (str): name of the first index coordinate.
        yc (str): name of the second index coordinate.
        xdat (np.array): query values for ``xc``, one per point.
        ydat (np.array): query values for ``yc``, one per point.
        name (str, optional): dimension along which the query points are
            laid out. Defaults to "points".

    Returns:
        xr.Dataset: the selection returned by ``data.xoak.sel`` for the
        query points.
    """
    # NOTE(review): xoak documents this index type as expecting
    # (latitude, longitude) order, in degrees -- confirm that callers pass
    # xc/yc (and the matching xdat/ydat) accordingly.
    data.xoak.set_index([xc, yc], 'sklearn_geo_balltree')

    # The query is itself a small dataset: both coordinates share `name`.
    query = xr.Dataset({xc: (name, xdat), yc: (name, ydat)})

    return data.xoak.sel({coord: query[coord] for coord in (xc, yc)})
def add_neighbour_pixels(data: xr.Dataset,
                         xc: str, yc: str,
                         xlen: int = 3, ylen: int = 3,
                         zdim: str = 'z') -> xr.Dataset:
    """Attach each pixel's neighbourhood as an extra stacked dimension.

    A centred rolling window is taken over ``xc`` and ``yc``, the window
    contents are materialised as two new dimensions (``'x_roll'`` and
    ``'y_roll'``), and those two are stacked into a single dimension
    ``zdim`` of length ``xlen * ylen``.

    Args:
        data (xr.Dataset): xarray input dataset.
        xc (str): name of the x dimension to roll over.
        yc (str): name of the y dimension to roll over.
        xlen (int, optional): window size along ``xc``. Defaults to 3.
        ylen (int, optional): window size along ``yc``. Defaults to 3.
        zdim (str, optional): name of the stacked neighbourhood dimension.
            Defaults to 'z'.

    Returns:
        xr.Dataset: ``data`` with the extra ``zdim`` dimension holding the
        neighbouring pixels of every (xc, yc) position.
    """
    window_sizes = {xc: xlen, yc: ylen}
    window_dims = {xc: 'x_roll', yc: 'y_roll'}

    # No explicit edge padding is applied here; windows that overhang the
    # array border are presumably filled by construct()'s default fill
    # value (NaN per the xarray docs) -- TODO confirm if edges matter.
    rolled = data.rolling(window_sizes, min_periods=1, center=True)
    windows = rolled.construct(window_dims)

    # create_index=False: the stacked dimension gets no index built for it.
    return windows.stack({zdim: ('x_roll', 'y_roll')}, create_index=False)
def extract_and_label(data: xr.Dataset,
                      query: pd.DataFrame) -> xr.Dataset:
    """Extract and label points from an xarray dataset.

    For every row of ``query`` the matching grid cell of ``data`` is
    selected together with its 3x3 pixel neighbourhood (stacked along a
    new ``'z'`` dimension), and the row's label is merged into the result.

    Args:
        data (xr.Dataset): dataset to extract from. It is rolled over its
            ``'x'`` and ``'y'`` dimensions and indexed on its
            ``'latitude'`` and ``'longitude'`` coordinates.
        query (pd.DataFrame): points to extract; the columns
            ``longitude``, ``latitude`` and ``lab`` are read.

    Returns:
        xr.Dataset: the extracted points along the ``'index'`` dimension,
        with the ``lab`` variable from ``query`` merged in.
    """
    # convert to xarray so the labels can be merged back in afterwards
    query = query.to_xarray()

    # get lat lon of the query points
    lon_q = query.longitude.values
    lat_q = query.latitude.values

    # expand dims by adding neighbour pixels
    data = add_neighbour_pixels(data, 'x', 'y', 3, 3, 'z')

    # extract at points. The query values must be passed in the same order
    # as the coordinate names: latitude values for 'latitude', longitude
    # values for 'longitude'. (Previously the two arrays were swapped, so
    # each point was looked up with its lat/lon exchanged.)
    data = select_points(data, 'latitude', 'longitude', lat_q, lon_q, 'index')

    # merge with query labels
    data = data.merge(query['lab'])
    return data
def test():
    """Smoke-test ``extract_and_label`` on a 2-band, 5x5 curvilinear grid.

    Only the resulting dimension names, dimension lengths and the merged
    labels are checked; the extracted ``SPEED`` values are not asserted.
    """
    # query points with their labels
    df = pd.DataFrame({
        'longitude': [-122.666, -122.669, -122.721],
        'latitude': [21.1519, 21.258, 21.139],
        'lab': [1, 1, 7],
    })

    # 2-D latitude / longitude of every grid cell (non-rectilinear grid)
    lats = np.array([
        [21.138, 21.14499, 21.15197, 21.15894, 21.16591],
        [21.16287, 21.16986, 21.17684, 21.18382, 21.19079],
        [21.18775, 21.19474, 21.20172, 21.2087, 21.21568],
        [21.21262, 21.21962, 21.22661, 21.23359, 21.24056],
        [21.2375, 21.2445, 21.25149, 21.25848, 21.26545],
    ])
    lons = np.array([
        [-122.72, -122.69333, -122.66666, -122.63999, -122.61331],
        [-122.7275, -122.70082, -122.67415, -122.64746, -122.62078],
        [-122.735, -122.70832, -122.68163, -122.65494, -122.62825],
        [-122.7425, -122.71582, -122.68912, -122.66243, -122.63573],
        [-122.75001, -122.72332, -122.69662, -122.66992, -122.64321],
    ])
    band = np.array([1, 2])

    # band 1 holds 1..25 row by row; band 2 holds the same values times 100
    first_band = np.arange(1, 26).reshape(5, 5)
    speed = np.stack([first_band, first_band * 100])

    ds = xr.Dataset(
        {'SPEED': (('band', 'x', 'y'), speed)},
        coords={
            'latitude': (('x', 'y'), lats),
            'longitude': (('x', 'y'), lons),
            'band': ('band', band),
        },
        attrs={'variable': 'Wind Speed'},
    )

    ds = extract_and_label(data=ds, query=df)

    dim_sizes = dict(ds.dims)
    assert list(dim_sizes) == ['index', 'band', 'z'], "dim names do not match expectation"
    assert list(dim_sizes.values()) == [3, 2, 9], "dim lengths do not match expectation"
    assert ds['lab'].values.tolist() == [1, 1, 7], "labels do not match expectation"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment