calculate index with dask
#===================================== job script ================================================
#!/bin/bash
#PBS -N spi
#PBS -P w35
#PBS -q normal
#PBS -l walltime=05:00:00
#PBS -l mem=190GB
#PBS -l ncpus=10
#PBS -l storage=gdata/w35+gdata/hh5+gdata/rr8
#PBS -l jobfs=100GB

module load hdf4
module use /g/data3/hh5/public/modules
module load conda/analysis3

path=/home/565/dh4185/scripts
python -W ignore $path/dev.spi_awra.py ndays=30
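# Any additional key=value pairs on the command line (e.g. climstart=1981 climend=2010)
# are parsed by the Python script below into keyword arguments of main(), so a
# hypothetical run against a fixed climatology could look like:
# python -W ignore $path/dev.spi_awra.py ndays=90 climstart=1981 climend=2010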
#===================================== script ================================================
import matplotlib.pyplot as plt
import xarray as xr
import tempfile
import dask
from dask.diagnostics import ProgressBar
import climtas.nci
import os
import numpy as np
from glob import glob
import warnings
from numpy import apply_along_axis
from scipy.stats import rankdata
import pandas as pd


def getSizeGB(arr):
    return '{:.3f} GB'.format(arr.nbytes / 1024 ** 3)
def rank_by_monthday(da):
    """
    Return the ranking of each grid point in 'da', calculated by grouping each
    horizontal grid cell by month and day and returning the ranking within that
    group at each time point (developed by Scott Wales)
    """
    # Drop time chunking so each chunk holds the full timeseries of one horizontal tile
    rechunked = da.chunk({'time': None})

    # Check the size of one chunk of the rechunked array
    chunksize = np.prod(rechunked.data.chunksize) * rechunked.data.itemsize
    if chunksize > 50 * 1024**2:
        warnings.warn("Chunk size is over 50 MB, try smaller horizontal chunks")

    # Each chunk contains the full timeseries for a specific horizontal domain.
    # Use xarray.map_blocks to analyse that domain.
    monthday = rechunked.time.dt.month * 100 + rechunked.time.dt.day
    rechunked.coords['monthday'] = monthday

    def rank_along_axis(x, **kwargs):
        # Rank values along the time axis (axis 0) within each chunk
        return apply_along_axis(rankdata, 0, x, **kwargs)

    def block_ranker(block):
        if block.size == 0:
            return block
        return block.groupby('monthday').apply(rank_along_axis, method='dense')

    return rechunked.map_blocks(block_ranker)
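# A minimal, hypothetical usage sketch for rank_by_monthday (the names 'time' and
# 'pr' are illustrative): the input must be dask-backed, with small horizontal
# chunks so that each chunk can hold the full timeseries.
#
#   time = pd.date_range('2000-01-01', '2003-12-31', freq='D')
#   pr = xr.DataArray(np.random.rand(len(time), 4, 4),
#                     dims=('time', 'lat', 'lon'), coords={'time': time})
#   ranks = rank_by_monthday(pr.chunk({'time': -1, 'lat': 2, 'lon': 2})).compute()
#
# For every grid cell, 'ranks' then holds the dense rank of each value among all
# values that share the same month and day.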
def SPI(arr, ndays, climstart, climend):
    """
    For a specified date and nday length, calculates the standardised precipitation
    index (SPI, McKee 1993), based on daily precipitation data.
    Months are defined as 30-day intervals. Leap days are simply removed to
    simplify the ranking of the data.
    Framework based on Farahmand and AghaKouchak 2015
    Original code based on EDDI Fortran code from: Mike Hobbins
    Date: March 10, 2016
    Rewritten in Python by: David Hoffmann
    Date: November 21, 2018
    New rank_by_monthday by: Scott Wales
    Rewritten for dask by: Christian Stassen
    Date: June, 2020

    :param arr: xarray DataArray of daily precipitation
    :param ndays: scaling parameter, e.g. aggregating over 30 days to achieve a 1-month time scale -> SPI1
    :param climstart: first year of the calibration period
    :param climend: last year of the calibration period
    """
    # Rolling sum over the ndays accumulation window (leap days still included at this point)
    print('Calculating rolling sum...')
    with xr.set_options(keep_attrs=True):
        # arr = arr.rolling(time=ndays, center=False).sum().dropna(dim='time',how='any')
        arr = arr.rolling(time=ndays, center=False).sum(skipna=False)#.isel(time=slice(ndays-1,None))

    # Remove leap days (29 February)
    print('Remove leap days...')
    arr = arr.sel(time=~((arr.time.dt.month == 2) & (arr.time.dt.day == 29)))

    # Rechunk so each chunk holds the full timeseries of a small horizontal tile
    arr = arr.chunk({'time': None, 'lat': 20, 'lon': 20})
    # Spill the rolled, rechunked array to a temporary zarr store on node-local
    # jobfs storage and reopen it, which shortens the dask task graph before the
    # ranking step
    print(arr)
    print('Save temp file to zarr...')
    arr.to_dataset(name='precip').to_zarr(os.environ['PBS_JOBFS'] + '/data.zarr', mode='w')
    print('Opening temp file from zarr...')
    arr = xr.open_zarr(os.environ['PBS_JOBFS'] + '/data.zarr').precip
    print(arr)
    print('Ranking by day of year...')
    # arr = arr.chunk({'time':None, 'lat':10, 'lon':10})
    rank = rank_by_monthday(arr)
    rank = rank.drop('monthday')  # drop the helper monthday coordinate

    # Sample size: number of years in the calibration period (climatology plus the year in question)
    N = climend - climstart + 1
    den = N + 0.33

    # Empirical Tukey plotting position (Wilks 2011)
    print('Calculating P...')
    P = (rank - 0.33) / den
    print('Calculating W...')
    # W = sqrt(-2*ln(p)) with p = P where P <= 0.5 and p = 1-P where P > 0.5
    W1 = P.where(P > 0.5, xr.ufuncs.sqrt(-2. * xr.ufuncs.log(P)))       # sqrt(-2 ln P) where P <= 0.5, else keeps P as a placeholder
    W = W1.where(P <= 0.5, xr.ufuncs.sqrt(-2. * xr.ufuncs.log(1 - P)))  # keep W1 where P <= 0.5, else sqrt(-2 ln(1-P))
    # Inverse normal approximation (Vicente-Serrano et al 2010)
    print('Calculating inverse normal approximation...')
    # Rational-approximation constants
    C0 = 2.515517
    C1 = 0.802853
    C2 = 0.010328
    d1 = 1.432788
    d2 = 0.189269
    d3 = 0.001308
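    # The constants above are those of the standard rational approximation to the
    # inverse normal CDF (Abramowitz and Stegun 1965, eq. 26.2.23). The two
    # piecewise expressions below evaluate
    #
    #   SPI = sign(P - 0.5) * ( W - (C0 + C1*W + C2*W**2) / (1 + d1*W + d2*W**2 + d3*W**3) )
    #
    # with W as computed above.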
    # Negative branch where P <= 0.5 (drier than median), positive branch where P > 0.5
    pr_est1 = P.where(P > 0.5, -1. * (W - (C0 + C1 * W + C2 * W**2.) / (1. + d1 * W + d2 * W**2. + d3 * W**3.)))
    pr_est = pr_est1.where(P <= 0.5, W - (C0 + C1 * W + C2 * W**2.) / (1. + d1 * W + d2 * W**2. + d3 * W**3.))

    pr_est = pr_est.to_dataset(name='SPI_{}d'.format(ndays))
    # pr_est = xr.Dataset({'SPI_{}d'.format(ndays): (('time', 'lat','lon'), pr_est)},{'time': pr['time'], 'lat': pr['lat'], 'lon': pr['lon']})
    return pr_est
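# A hypothetical call, assuming 'pr' is a daily, dask-backed precipitation
# DataArray (see main() below for how the AWAP files are actually opened):
#
#   spi30 = SPI(pr, ndays=30, climstart=1975, climend=2017)
#
# The result is a Dataset with a single variable named 'SPI_30d'. Note that SPI
# writes a temporary zarr store to $PBS_JOBFS, so it expects to run inside a PBS job.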
def main(**kwargs):
    print(kwargs)

    # Scaling parameter, e.g. aggregating over 30 days to achieve a 1-month time scale -> SPI1
    ndays = int(kwargs['ndays'])

    ## directories
    odir = '/g/data/w35/dh4185/data/AWAP/SPI/'
    # odir = '/Users/david/Documents/MonashUniverityPhD/Data/'
    ofile = odir+'SPI_awap_agg{}d_{}-{}.nc'.format(ndays,'{}','{}')  # Keep two placeholders empty for now
    # ofile = odir+'SPI_awap_agg{}d_{}-{}encode2.nc'.format(ndays,'{}','{}')  # Keep two placeholders empty for now
    ifile = ['/g/data/w35/dh4185/data/AWAP/precip/precip_total_0.05_{}.nc'.format(i) for i in range(1975,2018)]
    # ifile = odir+'pr_dummy.nc'

    #< Read data
    print('Read data...')
    # ds_in = xr.open_dataset(ifile)
    ds_in = xr.open_mfdataset(sorted(ifile), combine='nested', concat_dim='time')#.sel(lat=slice(-40,-20))
    # ds_in = xr.open_mfdataset(sorted(glob(ifile)), combine='nested', concat_dim='time')#.sel(lat=slice(-35,-34.5),lon=slice(144,144.5))
    da = ds_in.chunk({'lat': 40, 'lon': 40}).precip
    da = da.sortby(da.lat)
    print('Opening dataset of {}'.format(getSizeGB(da)))
    print(da)
    #< Get calibration period (defaults to the full length of the time series)
    if 'climstart' in kwargs:
        climstart = int(kwargs['climstart'])
    else:
        climstart = int(pd.to_datetime(str(da.time[0].values)).year)
    if 'climend' in kwargs:
        climend = int(kwargs['climend'])
    else:
        climend = int(pd.to_datetime(str(da.time[-1].values)).year)

    #< Calculate SPI
    spi = SPI(da, ndays, climstart, climend)
    print(spi)
    #< Write to netCDF: build a delayed write, run it on the cluster and show a progress bar.
    #  The encoding key must match the variable name, which depends on ndays.
    varname = 'SPI_{}d'.format(ndays)
    # saver = spi.to_netcdf(ofile.format(climstart, climend), compute=False, encoding={varname: {'chunksizes': spi[varname].data.chunksize, 'zlib': True, 'complevel': 4, 'shuffle': True}})
    saver = spi.to_netcdf(ofile.format(climstart, climend), compute=False,
                          encoding={varname: {'chunksizes': spi[varname].data.chunksize, 'shuffle': True}})
    future = client.persist(saver)
    dask.distributed.progress(future)
    future.compute()
    print('')
if __name__ == '__main__':
    import dask.distributed
    import sys

    ## Get the number of CPUs in the job and start a dask.distributed cluster
    # mem = 190
    # cores = int(os.environ.get('PBS_NCPUS','4'))
    # memory_limit = '{}gb'.format(int(max(mem/cores, 4)))
    # client = dask.distributed.Client(n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())
    client = climtas.nci.GadiClient()

    #< Print client summary
    print('### Client summary')
    print(client)
    print('\n\n')

    #< Call the main function: parse key=value command-line arguments into kwargs
    main(**dict(arg.split('=') for arg in sys.argv[1:]))

    #< Close the client
    client.shutdown()