Skip to content

Instantly share code, notes, and snippets.

@emileten
Created February 8, 2022 04:32
Show Gist options
  • Save emileten/cd28b028b3262dfe76fcb681eed129a3 to your computer and use it in GitHub Desktop.
Save emileten/cd28b028b3262dfe76fcb681eed129a3 to your computer and use it in GitHub Desktop.
Impose a cell-specific temporal cap on a [time, lat, lon] xarray dataset.
import xarray as xr
import numpy as np
### Fake data ###
def spatio_temporal_gcm_factory(
    x=None,
    start_date="1995-01-01",
    lat=None,
    lon=None,
    units="someunit",
):
    """Build a fake gridded ``xr.DataArray`` with dims ``("time", "lat", "lon")``.

    Parameters
    ----------
    x : array-like, optional
        Data values; its first axis is the time axis. Defaults to a fresh
        ``np.random.rand(1, 361, 721)`` array generated on each call.
    start_date : str
        First date of the daily, standard-calendar cftime time axis.
    lat : array-like, optional
        Latitude coordinate values. Defaults to ``np.arange(-90, 90.5, 0.5)``.
    lon : array-like, optional
        Longitude coordinate values. Defaults to ``np.arange(-180, 180.5, 0.5)``.
    units : str
        Stored in the ``units`` attribute of the result.

    Returns
    -------
    xr.DataArray
        ``x`` wrapped with ``time``/``lat``/``lon`` coordinates; the time
        axis has ``len(x)`` daily steps.
    """
    # Array defaults are created per call instead of in the signature: a
    # signature default is evaluated once at definition time (allocating the
    # full grid on import) and would be shared by every call that relies on it.
    if x is None:
        x = np.random.rand(1, 361, 721)
    if lat is None:
        lat = np.arange(-90, 90.5, 0.5)
    if lon is None:
        lon = np.arange(-180, 180.5, 0.5)

    # One daily time step per entry along the first axis of ``x``.
    time = xr.cftime_range(
        start=start_date, freq="D", periods=len(x), calendar="standard"
    )
    out = xr.DataArray(
        data=x,
        coords={"time": time, "lat": lat, "lon": lon},
        dims=["time", "lat", "lon"],
        attrs={"units": units},
    )
    return out
# Keep it fast: 2 time steps, 1 latitude value, 2 longitude values.
def tiny_factory(start_date):
    """Return a minimal random test array whose time axis starts at ``start_date``.

    The data has shape (2, 1, 2) along (time, lat, lon), with latitude
    ``[1.0]`` and longitudes ``[1.0, 2.0]``.
    """
    values = np.random.rand(2, 1, 2)
    return spatio_temporal_gcm_factory(
        x=values,
        start_date=start_date,
        lat=[1.0],
        lon=[1.0, 2.0],
    )
# To check whether dask is happy with this workflow.
def chunked_tiny_factory(start_date):
    """Return the same tiny test array as ``tiny_factory``, but chunked (dask-backed).

    Parameters
    ----------
    start_date : str
        First date of the time axis, forwarded to ``tiny_factory``.

    Returns
    -------
    xr.DataArray
        The tiny array chunked with a single chunk along ``time`` and ``lat``
        and chunks of size 2 along ``lon``.
    """
    # Reuse the in-memory builder rather than repeating its arguments here.
    non_chunked = tiny_factory(start_date)
    # -1 means one chunk spanning the whole dimension; per the original
    # author's note, the data must not be split across time.
    chunked = non_chunked.chunk({"time": -1, "lat": -1, "lon": 2})
    return chunked
# Every dataset below comes from the same builder; rebind `factory` to switch
# all of them (e.g. to the non-chunked variant) at once.
factory = chunked_tiny_factory
# Two independent random arrays on a time axis starting 1950-01-01 ...
era, clean_hist = (factory(start_date="1950-01-01") for _ in range(2))
# ... and two more on a time axis starting 2050-01-01.
clean_future, ds_future = (factory(start_date="2050-01-01") for _ in range(2))
### Compute ###
def cell_max(da):
    """Return the maximum of ``da`` over its ``time`` dimension for each grid cell.

    Parameters
    ----------
    da : xr.DataArray
        Array with a ``time`` dimension (here also ``lat`` and ``lon``).

    Returns
    -------
    xr.DataArray
        ``da`` reduced over ``time``; one value per remaining cell.
    """
    # Reducing over "time" directly already yields one value per (lat, lon)
    # cell, so stacking the grid, grouping by cell and unstacking again is
    # unnecessary work (and particularly costly on dask-backed data).
    return da.max("time")
# Per-cell maxima over time for each input.
era_cell_max = cell_max(era)
clean_hist_cell_max = cell_max(clean_hist)
clean_future_cell_max = cell_max(clean_future)
# Cap = `era` maximum scaled by the ratio of the `clean_future` maximum to
# the `clean_hist` maximum, cell by cell.
bias_corrected_cap = era_cell_max * (clean_future_cell_max / clean_hist_cell_max)
# Force the first cell's cap to 0.1 so that the behavior is obvious in the output.
# Writing into `.values` does not stick when the data is dask-backed (it is a
# freshly computed copy), so the override is applied with a lazy mask instead.
first_lat = bias_corrected_cap["lat"].values[0]
first_lon = bias_corrected_cap["lon"].values[0]
is_first_cell = (bias_corrected_cap["lat"] == first_lat) & (
    bias_corrected_cap["lon"] == first_lon
)
bias_corrected_cap = bias_corrected_cap.where(~is_first_cell, 0.1)
time_steps = ds_future["time"].values  # to expand dims
### Swap values for replacement where needed
# Add a time dimension to the per-cell cap (duplicating it at every time step).
cell_specific_cap_expanded = bias_corrected_cap.expand_dims(dim=dict(time=time_steps))
# Keep values strictly below the cap; everything else is replaced by the cap.
final_output = ds_future.where(
    ds_future < cell_specific_cap_expanded, cell_specific_cap_expanded
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment