Skip to content

Instantly share code, notes, and snippets.

@josephhardinee
Created August 25, 2021 16:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save josephhardinee/5e1b8da4764239a029c16cf4ceaaca8e to your computer and use it in GitHub Desktop.
Save josephhardinee/5e1b8da4764239a029c16cf4ceaaca8e to your computer and use it in GitHub Desktop.
climatology_futures
#!/usr/bin/env python
import xarray as xr
import gcsfs
import numpy as np
import dask as da
import logging
from dask.distributed import Client, as_completed
logger = logging.getLogger(__name__)


def process_latitude_slice(ds_all, lat_idx, lon_idx, step_size=2.5, resolution=0.25):
    """Compute day-of-year mean and standard deviation for one lat/lon tile.

    The globe is split into square tiles of ``step_size`` degrees. Tile
    ``(lat_idx, lon_idx)`` starts at latitude ``90 - lat_idx * step_size``
    (moving south) and longitude ``-180 + lon_idx * step_size`` (moving east).
    To cover the whole globe a caller must iterate
    ``range(int(180 / step_size))`` latitude tiles and
    ``range(int(360 / step_size))`` longitude tiles with the SAME
    ``step_size`` it passes here.

    Parameters
    ----------
    ds_all : xarray.Dataset
        Source data with ``time``, ``latitude`` and ``longitude`` coordinates.
        Assumes latitude descends from 90 to -90 and longitude ascends from
        -180 on a ``resolution``-degree grid (as assigned in ``main``) --
        TODO confirm for other callers.
    lat_idx, lon_idx : int
        Tile indices along latitude and longitude.
    step_size : float, optional
        Tile edge length in degrees.
    resolution : float, optional
        Grid spacing in degrees. Used so each label slice stops one grid cell
        before the next tile begins.

    Returns
    -------
    tuple of xarray.Dataset
        ``(mean, std)`` of the tile, grouped by ``time.dayofyear``.
    """
    longitude_r = np.arange(-180, 180, step_size)
    latitude_r = np.arange(90, -90, -step_size)

    lon_start = longitude_r[lon_idx]
    lat_start = latitude_r[lat_idx]
    # ``sel`` label slices include both endpoints, so stop one grid cell short
    # of the next tile to avoid duplicated rows/columns between tiles.
    lon_end = lon_start + step_size - resolution
    lat_end = lat_start - step_size + resolution
    # The latitude grid includes -90, which is one row more than a whole number
    # of tiles, so the southernmost tile must extend down to the pole.
    if lat_start - step_size < -90 + resolution:
        lat_end = -90.0

    ds_subset = ds_all.sel(
        longitude=slice(lon_start, lon_end),
        latitude=slice(lat_start, lat_end),
    ).compute()

    # ds_subset is already loaded in memory, so both reductions run eagerly.
    # NOTE(review): grouping by dayofyear shifts dates after Feb 28 by one day
    # in leap years -- confirm that is acceptable for this climatology.
    by_day = ds_subset.groupby("time.dayofyear")
    return by_day.mean(), by_day.std()


def main(step_size=5):
    """Build the day-of-year t2m climatology tile by tile on a local dask cluster.

    Opens the Zarr store, scatters it to the workers, and submits one
    ``process_latitude_slice`` task per ``step_size``-degree tile. The
    per-tile results are then combined, and a dataset with ``t2m_mean`` and
    ``t2m_std`` variables is written to a NetCDF file.

    Parameters
    ----------
    step_size : float, optional
        Tile edge length in degrees. The same value is forwarded to every task,
        so the loop bounds and the tile geometry always agree.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    steps_lat = int(180 / step_size)
    steps_lon = int(360 / step_size)

    ds_list_mean = []
    ds_list_std = []
    # The context manager closes the client (and the local cluster it started)
    # even if a task fails.
    with Client(n_workers=5, threads_per_worker=1) as client:
        ds_all = xr.open_zarr(
            "gcs://era_5_bucket/2m_temperature/2m_temperature_1979_2019_v2.zarr",
            consolidated=True,
        )
        # Replace the coordinates with an exact 0.25-degree grid, presumably so
        # the float label slices in process_latitude_slice hit grid points
        # exactly. Assumes the store has 721 latitudes x 1440 longitudes --
        # TODO confirm.
        ds_all = ds_all.assign_coords(
            latitude=np.arange(90, -90.25, -0.25),
            longitude=np.arange(-180, 180, 0.25),
        )
        data_future = client.scatter(ds_all)

        futures = [
            client.submit(
                process_latitude_slice,
                data_future,
                lat_idx,
                lon_idx,
                step_size=step_size,
            )
            for lon_idx in range(steps_lon)
            for lat_idx in range(steps_lat)
        ]
        total = len(futures)
        for n_done, future in enumerate(as_completed(futures), start=1):
            tile_mean, tile_std = future.result()
            ds_list_mean.append(tile_mean)
            ds_list_std.append(tile_std)
            logger.info("Result received (%d/%d)", n_done, total)

    # The results are in-memory datasets, so combining needs no cluster.
    ds_climatology_mean = xr.combine_by_coords(ds_list_mean).rename({"t2m": "t2m_mean"})
    ds_climatology_std = xr.combine_by_coords(ds_list_std).rename({"t2m": "t2m_std"})
    ds_climatology = xr.merge((ds_climatology_mean, ds_climatology_std))
    ds_climatology.to_netcdf("/home/jupyter/data/subseasonal/t2m_climatology_dask.nc")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment