@peterm790
Created January 11, 2024 14:53
Calculate Percentiles Using Bootstrapping
import xarray as xr
import intake
import numpy as np
from pathlib import Path
import dask
import pandas as pd
import fsspec

def shuffle_completely(ds, b, resample_dim="z"):
    """Draw one bootstrap sample: resample ds along resample_dim with replacement.

    b is the bootstrap iteration index; it is unused here, but passing it keeps
    the delayed task signatures distinct.
    """
    resample_dim_values = ds[resample_dim].values
    smp_resample_dim = np.random.choice(resample_dim_values, len(resample_dim_values))
    smp_ds = ds.sel({resample_dim: smp_resample_dim})
    smp_ds[resample_dim] = resample_dim_values  # restore the original labels so draws align
    return smp_ds
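
# A minimal usage sketch (illustrative only; the toy DataArray below is an
# assumption, not part of the pipeline):
#
#   da = xr.DataArray(np.arange(5), dims="z", coords={"z": np.arange(5)})
#   draw = shuffle_completely(da, b=0)
#   # `draw` has the same shape and "z" labels as `da`, but its values are a
#   # with-replacement resample of the originals.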

def get_percentiles(ds_resampled):
    """Compute a fixed set of quantiles along the stacked sample dimension z."""
    percentiles = []
    for i in [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.98, 0.99]:
        percentiles.append(ds_resampled.quantile(i, dim='z', skipna=False))
    return xr.concat(percentiles, dim='quantile')
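
# Note: DataArray.quantile also accepts a sequence of quantiles, so an
# equivalent one-call form (not used above) would be:
#
#   ds_resampled.quantile([0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9,
#                          0.95, 0.98, 0.99], dim='z', skipna=False)
#
# which returns the same values along a "quantile" dimension.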

def parallelize_bootstrap_delayed(ds, n_bootstrap, func, shuffle_func):
    """Build n_bootstrap delayed (shuffle -> func) tasks and compute them in parallel."""
    ds = dask.delayed(ds)  # wrap once so the loaded array is shared across tasks
    results = []
    for b in range(n_bootstrap):
        ds_resampled = dask.delayed(shuffle_func)(ds, b)
        results.append(dask.delayed(func)(ds_resampled))
    results = dask.compute(*results)
    return results
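
# Usage sketch (hypothetical toy input, assuming any running dask scheduler):
#
#   da = xr.DataArray(np.random.rand(100), dims="z", coords={"z": np.arange(100)})
#   draws = parallelize_bootstrap_delayed(da, 10, get_percentiles, shuffle_completely)
#   # `draws` is a tuple of 10 DataArrays, one set of quantiles per bootstrap draw.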

from dask_jobqueue import SLURMCluster
from distributed import Client

cluster = SLURMCluster(
    cores=1,                  # cores per job; becomes --cpus-per-task unless job_cpu overrides it
    # processes=10,           # split each job into this many worker processes (Dask-side; no SLURM equivalent)
    memory="576GB",           # memory Dask assumes per job, used to set worker memory limits
    job_mem='12GB',           # memory actually reserved in SLURM (--mem); overrides `memory` in the job script
    # job_cpu=10,             # cores actually reserved in SLURM (--cpus-per-task); overrides `cores` in the job script
    walltime="48:00:00",
    queue="base",
    n_workers=50,             # workers to start immediately; use cluster.scale() to adjust later
    interface='ib0',          # communicate over the InfiniBand interface
    death_timeout=180,        # seconds a worker waits for the scheduler before shutting down
    local_directory='/tmp/',  # local disk for worker spill files
    job_script_prologue=['export MALLOC_TRIM_THRESHOLD_=0',
                         'export BLOSC_NTHREADS=1'],
    scheduler_options={
        'interface': 'ib0',
        'dashboard_address': 'localhost:8787',
    },
)
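
# To sanity-check what will be submitted to SLURM, dask-jobqueue can render the
# generated batch script:
#
#   print(cluster.job_script())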

if __name__ == "__main__":
    client = Client(cluster, timeout=180)
    try:
        client.wait_for_workers(n_workers=50, timeout=180)
        print('all workers available')
    except Exception:
        print('not all workers available, continuing anyway')
    for var in ['tasdp', 'tas', 'tasmin', 'tasmax']:
        # open the ERA5 daily dataset for this variable from the intake catalog
        catalog = intake.open_catalog(Path(Path.home(), 'heat_center/data/climate/reanalysis/intake/reanalysis.yaml'))
        ds = catalog['ERA5']['day'][var].to_dask()
        # pad the 1981-2010 reference period by a few days at each end so the
        # +/-2-day windows are complete at the period boundaries
        ds = ds.sel(time=slice('1980-12-24', '2011-01-06'))[var]
        # subselect to the same Africa-only domain as ERA5-Land
        ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180))  # wrap 0..360 to -180..180
        ds = ds.sortby(ds.longitude)
        ds = ds.sel(latitude=slice(38, -36), longitude=slice(-20, 52))
        # build a 5-day sampling window: each day plus shifted copies at +/-1 and +/-2 days
        ds = ds.to_dataset(name='zero')  # convert the data array to a dataset
        names = {-2: 'minus two', -1: 'minus one', 1: 'plus one', 2: 'plus two'}
        for i in range(-2, 3):
            if i != 0:
                ds[names[i]] = ds.zero.shift(time=i)
        ds = xr.concat([ds[name] for name in list(ds)], pd.Index(list(range(5)), name='shift'))
        # now slice to the actual reference period
        ds = ds.sel(time=slice('1981', '2010'))
        # local scratch directory for per-day interim files; skip days already written
        fs = fsspec.filesystem("")
        fs.mkdirs(f'{var}_files/', exist_ok=True)
        files = [f.split('/')[-1] for f in fs.ls(f'{var}_files/')]
        for DOY in range(1, 366):
            if f'{var}_{DOY}.nc' not in files:
                # select every timestep whose day-of-year matches DOY
                ds_doy = ds.isel(time=np.where(ds.time.dt.dayofyear.values == DOY)[0])
                ds_doy = ds_doy.load()  # load once so workers receive the in-memory array
                # stack window position and year into a single sample dimension z
                ds_doy = ds_doy.stack(z=['shift', 'time'])
                results = parallelize_bootstrap_delayed(ds_doy, 200, get_percentiles, shuffle_completely)
                ds_bs = xr.concat(list(results), pd.Index(list(range(200)), name='concat_dim'))
                # summarise the bootstrap distribution: median and variance over the 200 draws
                ds_final = ds_bs.median(dim='concat_dim').rename('bs_quantile_median').to_dataset()
                ds_final['bs_quantile_variance'] = ds_bs.var(dim='concat_dim')
                ds_final = ds_final.assign_coords(dayofyear=DOY)
                ds_final = ds_final.expand_dims(dim="dayofyear")
                ds_final.to_netcdf(path=f'/home/pmarsh/percentile_boot_strapping/{var}_files/{var}_{DOY}.nc', format='NETCDF4')
        # combine the per-day files and write the final climatology, then clean up
        ds = xr.open_mfdataset(f'{var}_files/*')
        ds.to_netcdf(f'/terra/data/reanalysis/global/reanalysis/ECMWF/ERA5-HEAT/clim/native/{var}_bs_perc_ECMWF_ECMWF_ERA5_19810101-20101231.nc', format='NETCDF4')
        fs.rm(f'{var}_files/', recursive=True)  # delete the per-day interim files
        del ds
    client.close()
    cluster.close()