Created: January 11, 2024 14:53
-
-
Save peterm790/034a2e9bac0e70ceffb29e5aabfd3893 to your computer and use it in GitHub Desktop.
Calculate Percentiles Using Bootstrapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xarray as xr | |
import intake | |
import numpy as np | |
from pathlib import Path | |
import dask | |
import pandas as pd | |
import sys | |
import fsspec | |
def shuffle_completely(ds, b, resample_dim="z"):
    """Bootstrap-resample ``ds`` along ``resample_dim`` (sampling with replacement).

    Draws ``len(ds[resample_dim])`` coordinate labels with replacement,
    selects them, and then restores the original coordinate values so the
    result aligns with ``ds`` for downstream reductions.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data to resample.
    b : int
        Bootstrap iteration number, used to seed the random generator.
        (Previously unused: every delayed task drew from global NumPy
        state, which is irreproducible and can yield identical samples
        on forked dask workers.)
    resample_dim : str, optional
        Dimension to resample along. Defaults to ``"z"``.

    Returns
    -------
    Same type as ``ds``, resampled with replacement along ``resample_dim``.
    """
    # Per-iteration generator: reproducible, and independent across tasks.
    rng = np.random.default_rng(b)
    resample_dim_values = ds[resample_dim].values
    smp_resample_dim = rng.choice(resample_dim_values, len(resample_dim_values))
    smp_ds = ds.sel({resample_dim: smp_resample_dim})
    # Restore the original labels so the index looks unchanged to callers.
    smp_ds[resample_dim] = resample_dim_values
    return smp_ds
# Quantile levels evaluated for every bootstrap sample.
QUANTILE_LEVELS = [0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.98, 0.99]

def get_percentiles(ds_resampled):
    """Compute the fixed set of quantiles of ``ds_resampled`` over dim ``'z'``.

    Passing the full list of levels to ``quantile`` in a single call yields
    the same result as the previous per-level loop + ``xr.concat`` along
    ``'quantile'``, but in one vectorized reduction.

    Parameters
    ----------
    ds_resampled : xarray.DataArray or xarray.Dataset
        Bootstrap-resampled data with a ``'z'`` dimension.

    Returns
    -------
    Same type as ``ds_resampled`` with ``'z'`` replaced by a ``'quantile'``
    dimension holding ``QUANTILE_LEVELS``. ``skipna=False`` propagates NaNs.
    """
    return ds_resampled.quantile(QUANTILE_LEVELS, dim='z', skipna=False)
def parallelize_bootstrap_delayed(ds, n_bootstrap, func, shuffle_func):
    """Run ``n_bootstrap`` shuffle-then-reduce iterations in parallel with dask.

    Wraps ``ds`` once in ``dask.delayed`` (so it is serialized a single time),
    builds one lazy ``func(shuffle_func(ds, b))`` task per bootstrap iteration,
    and evaluates them all in a single ``dask.compute`` call.

    Parameters
    ----------
    ds : object
        Data handed to ``shuffle_func``; wrapped in ``dask.delayed``.
    n_bootstrap : int
        Number of bootstrap iterations.
    func : callable
        Reduction applied to each resampled dataset.
    shuffle_func : callable
        ``shuffle_func(ds, b)`` producing one resampled dataset.

    Returns
    -------
    tuple
        The computed result of ``func`` for each iteration.
    """
    delayed_ds = dask.delayed(ds)
    tasks = [
        dask.delayed(func)(dask.delayed(shuffle_func)(delayed_ds, iteration))
        for iteration in range(n_bootstrap)
    ]
    return dask.compute(*tasks)
from dask_jobqueue import SLURMCluster
from distributed import Client

# SLURM-backed dask cluster for the bootstrap workers.
# NOTE(review): this is constructed at import time, before the __main__ guard
# below — merely importing this module will submit SLURM jobs. Confirm that is
# intended; otherwise move the construction under the guard.
cluster = SLURMCluster(
    cores=1,  # allocate this many cores to a job - slurm option cpus-per-task
    # processes=10,  # cut the job into this many processes - no slurm equiv???
    memory="576GB",  # total memory for all nodes - dask-side limit
    job_mem='12GB',  # node memory to reserve in slurm
    # job_cpu=10,  # number of cpu [cores] to reserve in slurm??? - cpus-per-task again
    walltime="48:00:00",
    queue="base",
    n_workers=50,  # default worker count - can use scale() to manage dynamically
    interface='ib0',  # use the InfiniBand interface for worker traffic
    death_timeout=180,
    local_directory='/tmp/',
    # Worker env tweaks: return freed memory to the OS eagerly, and keep
    # blosc single-threaded (dask already parallelizes across workers).
    job_script_prologue=['export MALLOC_TRIM_THRESHOLD_=0',
                         'export BLOSC_NTHREADS=1', ],
    scheduler_options={
        'interface': 'ib0',
        'dashboard_address': 'localhost:8787',
    },
)
if __name__ == "__main__":
    client = Client(cluster, timeout=180)
    try:
        client.wait_for_workers(n_workers=50, timeout=180)
        print('all workers available')
    # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt.
    except Exception:
        print('not all workers available, continuing anyway')

    # Number of bootstrap iterations (was the magic number 200, repeated twice).
    N_BOOTSTRAP = 200

    for var in ['tasdp', 'tas', 'tasmin', 'tasmax']:
        # --- open dataset ---
        catalog = intake.open_catalog(
            Path(Path.home(), 'heat_center/data/climate/reanalysis/intake/reanalysis.yaml'))
        ds = catalog['ERA5']['day'][var].to_dask()
        # Select time with a 5-day pad on each end so the shifted windows
        # below are complete at the period boundaries.
        ds = ds.sel(time=slice('1980-12-24', '2011-01-06'))[var]
        # Subselect to the same Africa-only domain as ERA5-Land:
        # wrap longitudes to [-180, 180) and crop.
        ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180))
        ds = ds.sortby(ds.longitude)
        ds = ds.sel(latitude=slice(38, -36), longitude=slice(-20, 52))
        # --- build a 5-day sampling window (day +/- 2) as a 'shift' dim ---
        ds = ds.to_dataset(name='zero')  # convert DataArray to Dataset
        names = {-2: 'minus two', -1: 'minus one', 1: 'plus one', 2: 'plus two'}
        for i in range(-2, 3):
            if i != 0:
                ds[names[i]] = ds.zero.shift(time=i)
        ds = xr.concat([ds[name] for name in list(ds)],
                       pd.Index(list(range(5)), name='shift'))
        # Slice back to the reference period now that the windows exist.
        ds = ds.sel(time=slice('1981', '2010'))

        fs = fsspec.filesystem("")
        fs.mkdirs(f'{var}_files/', exist_ok=True)
        # Interim files already produced by a previous (resumed) run.
        files = [f.split('/')[-1] for f in fs.ls(f'{var}_files/')]
        for DOY in range(1, 366):
            if f'{var}_{DOY}.nc' in files:
                continue  # already computed — skip on resume
            ds_doy = ds[:, list(np.where(ds.time.dt.dayofyear.values == DOY)[0])]
            ds_doy = ds_doy.load()
            # Collapse the 5-day window and all years into one sampling dim 'z'.
            ds_doy = ds_doy.stack(z=['shift', 'time'])
            results = parallelize_bootstrap_delayed(
                ds_doy, N_BOOTSTRAP, get_percentiles, shuffle_completely)
            # Name the bootstrap dimension explicitly (was an unnamed pd.Index,
            # which xarray called 'concat_dim'); it is reduced away below either way.
            ds_bs = xr.concat(list(results), pd.Index(range(N_BOOTSTRAP), name='bootstrap'))
            ds_final = ds_bs.median(dim='bootstrap').rename('bs_quantile_median').to_dataset()
            ds_final['bs_quantile_variance'] = ds_bs.var(dim='bootstrap')
            ds_final = ds_final.assign_coords(dayofyear=DOY)
            ds_final = ds_final.expand_dims(dim="dayofyear")
            # NOTE(review): this absolute path and the relative f'{var}_files/'
            # used above only refer to the same directory when the CWD is
            # /home/pmarsh/percentile_boot_strapping — confirm before running elsewhere.
            ds_final.to_netcdf(
                path=f'/home/pmarsh/percentile_boot_strapping/{var}_files/{var}_{DOY}.nc',
                format='NETCDF4')
        # Combine all day-of-year files into the final climatology file.
        ds = xr.open_mfdataset(f'{var}_files/*')
        # Fixed: the output filename previously ended in a doubled '.nc.nc'.
        ds.to_netcdf(
            f'/terra/data/reanalysis/global/reanalysis/ECMWF/ERA5-HEAT/clim/native/'
            f'{var}_bs_perc_ECMWF_ECMWF_ERA5_19810101-20101231.nc',
            format='NETCDF4')
        ds.close()  # release handles on the interim files before deleting them
        fs.rm(f'{var}_files/', recursive=True)  # delete daily interim files
        del ds
    client.close()
    cluster.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment