calculate index with dask
#===================================== job script ================================================
#!/bin/bash
#PBS -N spi
#PBS -P w35
#PBS -q normal
#PBS -l walltime=05:00:00
#PBS -l mem=190GB
#PBS -l ncpus=10
#PBS -l storage=gdata/w35+gdata/hh5+gdata/rr8
#PBS -l jobfs=100GB
module load hdf4
module use /g/data3/hh5/public/modules
module load conda/analysis3
path=/home/565/dh4185/scripts
python -W ignore $path/dev.spi_awra.py ndays=30
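# Optional key=value arguments understood by the script (parsed by main() below):
#   climstart=YYYY climend=YYYY  -> calibration period; defaults to the full record if omitted.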
#===================================== script ================================================
import matplotlib.pyplot as plt
import xarray as xr
import tempfile
import dask
from dask.diagnostics import ProgressBar
import climtas.nci
import os
import numpy as np
from glob import glob
import warnings
from numpy import apply_along_axis
from scipy.stats import rankdata
import pandas as pd

def getSizeGB(arr):
    return '{:.3f} GB'.format(arr.nbytes / 1024 ** 3)

def rank_by_monthday(da):
    """
    Return the ranking of each grid point in 'da', calculated by grouping each
    horizontal grid cell by month and day and returning the ranking within that
    group at each time point (developed by Scott Wales)
    """
    # Drop time chunking
    rechunked = da.chunk({'time': None})
    # Check size of one chunk
    chunksize = np.prod(rechunked.data.chunksize) * rechunked.data.itemsize
    if chunksize > 50 * 1024**2:
        warnings.warn("Chunk size is over 50 MB, try smaller horizontal chunks")
    # Each chunk contains the full timeseries for a specific horizontal domain
    # Use xarray.map_blocks to analyse that domain
    monthday = rechunked.time.dt.month * 100 + rechunked.time.dt.day
    rechunked.coords['monthday'] = monthday
    def rank_along_axis(x, **kwargs):
        return apply_along_axis(rankdata, 0, x, **kwargs)
    def block_ranker(block):
        if block.size == 0:
            return block
        return block.groupby('monthday').apply(rank_along_axis, method='dense')
    return rechunked.map_blocks(block_ranker)
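
# Illustrative sketch only (hypothetical toy data, kept as comments so the script stays runnable):
# each 1 January value is ranked against the other 1 January values at the same grid cell,
# and likewise for every other month/day combination.
# >>> times = pd.date_range('2000-01-01', '2003-12-31', freq='D')
# >>> toy = xr.DataArray(np.random.rand(len(times), 2, 2),
# ...                    coords={'time': times}, dims=('time', 'lat', 'lon')).chunk({'lat': 1, 'lon': 1})
# >>> rank_by_monthday(toy)   # dask-backed DataArray of per-monthday ranks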

def SPI(arr, ndays, climstart, climend):
    """
    For a specified date and nday length, calculates the standardised precipitation
    index (SPI, McKee 1993), based on daily precipitation data.
    Months are defined as 30-day intervals. Leap days are simply removed to
    simplify the ranking of the data.
    Framework based on Farahmand and AghaKouchak 2015
    Original code based on EDDI Fortran code from: Mike Hobbins
    Date: March 10, 2016
    Rewritten in Python by: David Hoffmann
    Date: November 21, 2018
    New rank_by_monthday by: Scott Wales
    Rewritten for dask by: Christian Stassen
    Date: June, 2020
    :param arr: xarray DataArray of daily precipitation
    :param ndays: scaling parameter, e.g. aggregating over 30 days to achieve a 1-month time scale -> SPI1
    :param climstart: first year of the calibration period
    :param climend: last year of the calibration period
    """
    # Rolling sum over ndays (leap days are removed afterwards)
    print('Calculating rolling sum...')
    with xr.set_options(keep_attrs=True):
        # arr = arr.rolling(time=ndays, center=False).sum().dropna(dim='time',how='any')
        arr = arr.rolling(time=ndays, center=False).sum(skipna=False)#.isel(time=slice(ndays-1,None))
    # Remove leap days
    print('Remove leap days...')
    arr = arr.sel(time=~((arr.time.dt.month == 2) & (arr.time.dt.day == 29)))
    arr = arr.chunk({'time':None, 'lat':20, 'lon':20})
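    # Explanatory note (my reading of the zarr round-trip below): rechunking from per-year input
    # files (chunked in time) to full-timeseries columns (chunked in space) builds an expensive
    # task graph, so the rechunked array is written to a temporary zarr store on node-local jobfs
    # disk (the PBS_JOBFS path requested via '-l jobfs' in the job script) and read back, which
    # materialises the new chunking cheaply.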
    ### Spill to a temporary zarr store and re-open
    print(arr)
    print('Save temp file to zarr...')
    arr.to_dataset(name='precip').to_zarr(os.environ['PBS_JOBFS'] + '/data.zarr', mode='w')
    print('Opening temp file from zarr...')
    arr = xr.open_zarr(os.environ['PBS_JOBFS'] + '/data.zarr').precip
    print(arr)
    ###
    print('Ranking by day of year...')
    # arr = arr.chunk({'time':None, 'lat':10, 'lon':10})
    rank = rank_by_monthday(arr)
    rank = rank.drop('monthday')  # drop monthday coordinate
    # Define sample size wrt length of climatology + year in question
    N = climend - climstart + 1
    den = N + 0.33
    # Empirical Tukey plotting position (Wilks 2011)
    print('Calculating P...')
    P = (rank - 0.33)/den
    print('Calculating W...')
    W1 = P.where(P>0.5, xr.ufuncs.sqrt(-2. * xr.ufuncs.log(P)))    # where P<=0.5: W = sqrt(-2*ln(P)); elsewhere keep P for now
    W = W1.where(P<=0.5, xr.ufuncs.sqrt(-2. * xr.ufuncs.log(1-P))) # where P>0.5: W = sqrt(-2*ln(1-P))
    # Inverse normal approximation (Vicente-Serrano et al 2010)
    print('Calculating inverse normal approximation...')
    # normalisation parameters
    C0 = 2.515517
    C1 = 0.802853
    C2 = 0.010328
    d1 = 1.432788
    d2 = 0.189269
    d3 = 0.001308
    pr_est1 = P.where(P>0.5, -1. * (W - (C0 + C1 * W + C2 * W**2.) / (1. + d1 * W + d2 * W**2. + d3 * W**3.)))
    pr_est = pr_est1.where(P<=0.5, W - (C0 + C1 * W + C2 * W**2.) / (1. + d1 * W + d2 * W**2. + d3 * W**3.))
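    # The two .where() pairs above implement the piecewise rational approximation of the
    # inverse standard normal CDF cited above (Vicente-Serrano et al 2010):
    #   W   = sqrt(-2*ln(P))   for P <= 0.5,   W = sqrt(-2*ln(1-P))   for P > 0.5
    #   SPI = -(W - (C0 + C1*W + C2*W**2) / (1 + d1*W + d2*W**2 + d3*W**3))   for P <= 0.5
    #   SPI = +(W - (C0 + C1*W + C2*W**2) / (1 + d1*W + d2*W**2 + d3*W**3))   for P >  0.5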
    pr_est = pr_est.to_dataset(name='SPI_{}d'.format(ndays))
    #pr_est = xr.Dataset({'SPI_{}d'.format(ndays): (('time', 'lat','lon'), pr_est)},{'time': pr['time'], 'lat': pr['lat'], 'lon': pr['lon']})
    return pr_est

def main(**kwargs):
    print(kwargs)
    # scaling parameter, e.g. aggregating over 30 days to achieve a 1-month time scale -> SPI1
    ndays = int(kwargs['ndays'])
    ## directories
    odir = '/g/data/w35/dh4185/data/AWAP/SPI/'
    # odir = '/Users/david/Documents/MonashUniverityPhD/Data/'
    ofile = odir+'SPI_awap_agg{}d_{}-{}.nc'.format(ndays,'{}','{}') # Keep two placeholders empty for now
    # ofile = odir+'SPI_awap_agg{}d_{}-{}encode2.nc'.format(ndays,'{}','{}') # Keep two placeholders empty for now
    ifile = ['/g/data/w35/dh4185/data/AWAP/precip/precip_total_0.05_{}.nc'.format(i) for i in range(1975,2018)]
    # ifile = odir+'pr_dummy.nc'
    #< read data
    print('Read data...')
    # ds_in = xr.open_dataset(ifile)
    ds_in = xr.open_mfdataset(sorted(ifile), combine='nested', concat_dim='time')#.sel(lat=slice(-40,-20))
    # ds_in = xr.open_mfdataset(sorted(glob(ifile)),combine='nested', concat_dim='time')#.sel(lat=slice(-35,-34.5),lon=slice(144,144.5))#, combine='nested', concat_dim='time')#,chunks={'time':None, 'lat':1, 'lon':1})#.sel(lat=slice(-35,-34.5),lon=slice(144,144.5))#.sel(time=slice('1975', '2018'))#.sel(lat=slice(-35,-34.5),lon=slice(144,144.5))
    da = ds_in.chunk({'lat': 40, 'lon': 40}).precip
    da = da.sortby(da.lat)
    print('Opening dataset of {}'.format(getSizeGB(da)))
    print(da)
    #< Get calibration period (default to total length of time series)
    if 'climstart' in kwargs:
        climstart = int(kwargs['climstart'])
    else:
        climstart = int(pd.to_datetime(str(da.time[0].values)).year)
    if 'climend' in kwargs:
        climend = int(kwargs['climend'])
    else:
        climend = int(pd.to_datetime(str(da.time[-1].values)).year)
    #< Calculate SPI
    spi = SPI(da, ndays, climstart, climend)
    print(spi)
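    # The netCDF write below is built lazily (compute=False); persisting it on the distributed
    # client starts the work in the background, dask.distributed.progress shows a progress bar,
    # and the final compute() blocks until the file has been written.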
    #< Write SPI to netCDF (delayed), then compute with a progress bar
    varname = 'SPI_{}d'.format(ndays)
    # saver = spi.to_netcdf(ofile.format(climstart,climend), compute=False, encoding={varname: {'chunksizes': spi[varname].data.chunksize, 'zlib': True, 'complevel': 4, 'shuffle': True}})
    saver = spi.to_netcdf(ofile.format(climstart,climend), compute=False, encoding={varname: {'chunksizes': spi[varname].data.chunksize, 'shuffle': True}})
    future = client.persist(saver)
    dask.distributed.progress(future)
    future.compute()
    print('')

if __name__ == '__main__':
    import dask.distributed
    import sys
    ## Get the number of CPUs in the job and start a dask.distributed cluster
    # mem = 190
    # cores = int(os.environ.get('PBS_NCPUS','4'))
    # memory_limit = '{}gb'.format(int(max(mem/cores, 4)))
    # client = dask.distributed.Client(n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())
    client = climtas.nci.GadiClient()#n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())
    #< Print client summary
    print('### Client summary')
    print(client)
    print('\n\n')
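    # Sketch of the expected invocation, matching the job script above:
    #   python dev.spi_awra.py ndays=30              -> kwargs == {'ndays': '30'}
    #   optionally append climstart=1981 climend=2010 to fix the calibration period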
    #< Call the main function
    main(**dict(arg.split('=') for arg in sys.argv[1:]))
    #< Close the client
    client.shutdown()