Skip to content

Instantly share code, notes, and snippets.

@dahoal
Created July 17, 2020 04:00
Show Gist options
  • Save dahoal/6e819b2d364290511f38781954f7535f to your computer and use it in GitHub Desktop.
Save dahoal/6e819b2d364290511f38781954f7535f to your computer and use it in GitHub Desktop.
detect flash drought events and their length
import matplotlib.pyplot as plt
import xarray as xr
import tempfile
import dask
import regionmask
import os
import numpy as np
from glob import glob
import pandas as pd
# from numba import jit # Speedup for python functions, doesn't work properly yet
################# flash drought detection function
def find_fd1D_mask(array, q_up, q_low, q_end, dtime): #, verbose=False):
    """
    Detect flash-drought onsets in a 1-D series and record their length.

    Use masks to determine the indices of times that match the
    initial criteria, and only loop through a list of those positions.
    Written by Aidan Heerdegen (CMS Help Team)

    The numba ``@jit`` decorator is intentionally not applied: its import
    is commented out at the top of the file, so using it raised a
    NameError as soon as the module was loaded.

    :param array: 1-D sequence of values along time
    :param q_up: values >= q_up are "moist" steps (used as the recovery point)
    :param q_low: values < q_low are "dry" steps (drought onset)
    :param q_end: values >= q_end are "wet" steps (search starting points)
    :param dtime: maximum number of time steps searched after a wet step
    :return: array shaped like ``array``, zero everywhere except at each
             detected dry step, which holds the number of steps from that
             dry step to the next moist step (zero if none follows)
    """
    array = np.asarray(array)
    n_steps = len(array)
    # Sorted positions of the moist and wet time steps
    moist = np.flatnonzero(array >= q_up)
    wet = np.flatnonzero(array >= q_end)
    duration = np.zeros_like(array)
    # Cycle over all wet indices
    for windex in wet:
        # Check for dry spell within dtime
        for i in range(windex + 1, min(windex + dtime + 1, n_steps)):
            if array[i] >= q_end:
                # Found another wet cell before a dry cell
                break
            if array[i] < q_low:
                # Position (within moist) of the first index strictly after i.
                # moist is sorted, so a binary search replaces the linear scan.
                first_later = np.searchsorted(moist, i, side='right')
                if first_later < len(moist):
                    # Same value as the former (loc - windex - count) - 1,
                    # because i == windex + 1 + count at this point.
                    duration[i] = moist[first_later] - i
                # Drop the moist indices up to this point so they are not
                # searched again. The recovery index itself is kept: a later
                # drought may still need it when q_end < q_up.
                moist = moist[first_later:]
                break
    return duration
### function adjusted for inverted EDDI
def find_fd1D_mask_eddi(array, q_up, q_low, q_end, dtime):
    """
    Detect flash-drought onsets in a 1-D series whose sign is inverted.

    Use masks to determine the indices of times that match the
    initial criteria, and only loop through a list of those positions.
    Every comparison is the mirror image of the non-EDDI detector:
    low values are wet/moist and high values are dry.

    :param array: 1-D sequence of values along time
    :param q_up: values <= q_up are "moist" steps (used as the recovery point)
    :param q_low: values > q_low are "dry" steps (drought onset)
    :param q_end: values <= q_end are "wet" steps (search starting points)
    :param dtime: maximum number of time steps searched after a wet step
    :return: array shaped like ``array``, zero everywhere except at each
             detected dry step, which holds the number of steps from that
             dry step to the next moist step (zero if none follows)
    """
    array = np.asarray(array)
    n_steps = len(array)
    # Sorted positions of the moist and wet time steps
    moist = np.flatnonzero(array <= q_up)
    wet = np.flatnonzero(array <= q_end)
    duration = np.zeros_like(array)
    # Cycle over all wet indices
    for windex in wet:
        # Check for dry spell within dtime
        for i in range(windex + 1, min(windex + dtime + 1, n_steps)):
            if array[i] <= q_end:
                # Found another wet cell before a dry cell
                break
            if array[i] > q_low:
                # Position (within moist) of the first index strictly after i.
                # moist is sorted, so a binary search replaces the linear scan.
                first_later = np.searchsorted(moist, i, side='right')
                if first_later < len(moist):
                    # Same value as the former (loc - windex - count) - 1,
                    # because i == windex + 1 + count at this point.
                    duration[i] = moist[first_later] - i
                # Drop the moist indices up to this point so they are not
                # searched again. The recovery index itself is kept: a later
                # drought may still need it when q_end > q_up.
                moist = moist[first_later:]
                break
    return duration
# function wrapper
def detect_fd(arr, q_up, q_low, q_end, dtime=20, index=None):
    """
    Find flash droughts (FD) in an xarray object.

    Applies one of the 1-D detectors along the ``time`` dimension of
    ``arr`` through ``xr.apply_ufunc`` (vectorized, dask-parallelized).

    :param arr: xarray input with a ``time`` dimension
    :param q_up: upper quantile threshold
    :param q_low: lower quantile threshold for drought to be identified
    :param q_end: threshold for drought to be broken
    :param dtime: maximum number of time steps between q_up and q_low
    :param index: ``'EDDI'`` selects ``find_fd1D_mask_eddi``; any other
                  value selects ``find_fd1D_mask``
    :return: result of the selected detector along ``time``, declared
             with the same dtype as ``arr``
    """
    # Only the 1-D kernel differs between the two cases
    kernel = find_fd1D_mask_eddi if index == 'EDDI' else find_fd1D_mask
    return xr.apply_ufunc(
        kernel, arr, q_up, q_low, q_end, dtime,
        input_core_dims=[['time'], [], [], [], []],
        output_core_dims=[['time']],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[arr.dtype],
    )
def main(**kwargs):
    """
    Compute flash-drought durations for one index and write them to netCDF.

    Expected keyword arguments (all are converted from their given form):

    :param index: name of the drought index; names starting with ``EDDI``
                  are read from the ``EDDI`` folder/variable and use the
                  inverted detector
    :param dtime: maximum number of time steps for the drought onset (int)
    :param prctl_dr: quantile used as the dry threshold (float)
    :param prctl_norm: quantile used as the moist/end threshold (float)
    :raises KeyError: if one of the four arguments is missing

    NOTE: relies on a module-level ``client`` (dask.distributed client),
    which this file only creates inside its ``__main__`` section.
    """
    print(kwargs)
    index = str(kwargs['index'])
    dtime = int(kwargs['dtime'])
    prctl_dr = float(kwargs['prctl_dr'])
    prctl_norm = float(kwargs['prctl_norm'])

    ################# directories
    odir = '/g/data/w35/dh4185/data/AWAP/fd_count/highres/'
    ofile = '{}{}_awap_FDcount_prctl{}to{}in{}d.nc'.format(odir,index,prctl_norm,prctl_dr,dtime)
    # All EDDI variants share one input folder and one variable name
    if index[:4]=='EDDI':
        folder='EDDI'
    else:
        folder=index
    ifile = '/g/data/w35/dh4185/data/AWAP/{}/{}_awap_agg30d_1975-2018.nc'.format(folder,index)

    ################# get file and calculate percentiles
    # NOTE(review): only the year 2000 is selected here - confirm this is
    # intended and not a leftover from testing.
    arr = xr.open_dataset(ifile).sel(time=slice('2000','2000')) #lat=slice(-37,-33),lon=slice(135,145))
    # Keep time in a single chunk: the detector works along the time axis
    arr = arr.chunk({'time':None, 'lat':50, 'lon':50})
    arr = arr['{}_30d'.format(folder)]
    print(arr)

    print('Calculate quantiles...')
    q = arr.quantile([prctl_dr, prctl_norm],'time')
    q10 = q.sel(quantile=prctl_dr)
    q40 = q.sel(quantile=prctl_norm)

    # NOTE(review): the chunking above uses 'lat'/'lon' for every index,
    # but the non-SPI branch below asks for 'latitude'/'longitude' -
    # verify against the actual input files.
    if index=='SPI':
        # get land/sea mask -> use to ignore ocean values?
        ocean = regionmask.defined_regions.natural_earth.land_110.mask(arr,lon_name='lon',lat_name='lat')#.sel(lat=slice(-35,-33),lon=slice(140,142))
    else:
        ocean = regionmask.defined_regions.natural_earth.land_110.mask(arr,lon_name='longitude',lat_name='latitude')#.sel(lat=slice(-35,-33),lon=slice(140,142))
    # True wherever the region mask is not NaN
    # (presumably the land cells - confirm against regionmask)
    lsmask = ~(np.isnan(ocean))

    ################# run function
    fd = detect_fd(arr, q_up=q40, q_low=q10, q_end=q40, dtime=dtime,index=index[:4])

    # convert to dataset and save as netcdf
    # The encoding key has to be the name of the variable actually written
    # ('FDcount'), and the chunk sizes are taken from that variable.
    ds = fd.to_dataset(name='FDcount').where(lsmask)
    write_job = ds.to_netcdf(
        ofile,
        compute=False,
        encoding={'FDcount': {'chunksizes': ds['FDcount'].data.chunksize,
                              'shuffle': True}},
    )

    # Add a progress bar and do the calculations
    write_job = client.persist(write_job)
    dask.distributed.progress(write_job)
    write_job.compute()
    print('')
if __name__ == '__main__':
    # Script entry point: start a dask client, run main() with the
    # key=value pairs given on the command line, then shut the client down.
    ################# set up client
    import dask.distributed
    import sys
    # climtas was used below without ever being imported (NameError)
    import climtas.nci

    ## Get the number of CPUS in the job and start a dask.distributed cluster
    # mem = 190
    # cores = int(os.environ.get('PBS_NCPUS','4'))
    # memory_limit = '{}gb'.format(int(max(mem/cores, 4)))
    # client = dask.distributed.Client(n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())
    client = climtas.nci.GadiClient()#n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())

    try:
        #< Print client summary
        print('### Client summary')
        print(client)
        print('\n\n')

        #< Call the main function
        # Split on the first '=' only so values may themselves contain '='
        main(**dict(arg.split('=', 1) for arg in sys.argv[1:]))
    finally:
        #< Close the client, even when main() fails
        client.shutdown()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment