Created
July 17, 2020 04:00
-
-
Save dahoal/6e819b2d364290511f38781954f7535f to your computer and use it in GitHub Desktop.
detect flash drought events and their length
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import xarray as xr | |
import tempfile | |
import dask | |
import regionmask | |
import os | |
import numpy as np | |
from glob import glob | |
import pandas as pd | |
# from numba import jit # Speedup for python functions, doesn't work properly yet | |
################# flash drought detection function | |
# @jit(nogil=True)  # removed: the numba import above is commented out, so the
#                   # decorator raised NameError at import time; re-enable both together
def find_fd1D_mask(array, q_up, q_low, q_end, dtime):
    """
    Detect flash-drought events in a 1-d time series and return their length.

    Use masks to determine the indices of times that match the
    initial criteria, and only loop through a list of those positions.
    Written by Aidan Heerdegen (CMS Help Team).

    :param array: 1-d numpy array of drought-index values
    :param q_up: upper threshold; values >= q_up count as "moist"
    :param q_low: lower threshold; a value < q_low marks drought onset
    :param q_end: recovery threshold; values >= q_end end/preclude an event
    :param dtime: maximum number of time steps allowed between q_up and q_low
    :return: array of same shape/dtype; nonzero entries give the event
             duration at the drought-onset time step
    """
    indices = np.arange(len(array))
    # Indices of "moist" (>= q_up) and "wet" (>= q_end) time steps.
    moist = indices[array >= q_up]
    wet = indices[array >= q_end]
    duration = np.zeros_like(array)
    # Cycle over all wet indices
    for windex in wet:
        # Check for a rapid drop below q_low within dtime steps of this wet step
        count = 0  # steps spent between q_low and q_up (neither wet nor dry)
        for i in range(windex + 1, min(windex + dtime + 1, len(array))):
            if array[i] >= q_end:
                # Found another wet cell before a dry cell -> no flash drought
                break
            if array[i] < q_low:
                # Drought onset: duration runs until the next moist step
                for j, loc in enumerate(moist):
                    if loc > i:
                        duration[i] = (loc - windex - count) - 1
                        # Drop the consumed moist index so it isn't scanned again.
                        # Guarded inside the match: the original deleted with a
                        # stale/undefined j when no moist index followed i,
                        # crashing on an empty moist array.
                        moist = np.delete(moist, j)
                        break
                break
            # Count those instances where value is between q_low and q_up
            count += 1
    return duration
### function adjusted for inverted EDDI | |
def find_fd1D_mask_eddi(array, q_up, q_low, q_end, dtime):
    """
    Flash-drought detection for the (inverted) EDDI index.

    Same mask-based approach as find_fd1D_mask, but with all comparisons
    flipped: for EDDI, low values mean wet conditions and high values dry.

    :param array: 1-d array of EDDI values
    :param q_up: "moist" threshold (values <= q_up count as moist)
    :param q_low: drought threshold (values > q_low mark drought onset)
    :param q_end: recovery threshold (values <= q_end end/preclude an event)
    :param dtime: maximum number of steps between crossing q_up and q_low
    :return: array of event durations, nonzero at drought-onset time steps
    """
    npts = len(array)
    positions = np.arange(npts)
    moist = positions[array <= q_up]   # candidate recovery points
    wet = positions[array <= q_end]    # possible event starting points
    duration = np.zeros_like(array)
    for start in wet:
        between = 0  # steps spent between q_end and q_low (neither wet nor dry)
        for t in range(start + 1, min(start + dtime + 1, npts)):
            if array[t] <= q_end:
                # hit another wet step first -> not a flash drought
                break
            if array[t] > q_low:
                # drought onset found; duration lasts until the next moist step
                for k, m in enumerate(moist):
                    if m > t:
                        duration[t] = (m - start - between) - 1
                        break
                # drop the consumed moist index so it is not scanned again
                moist = np.delete(moist, k)
                break
            between += 1
    return duration
# function wrapper | |
def detect_fd(arr, q_up, q_low, q_end, dtime=20, index=None):
    """
    Find flash droughts (FD) in an xarray DataArray.

    Wraps the 1-d detection functions over the 'time' dimension for
    n-dimensional (and dask-chunked) input via xarray.apply_ufunc.

    :param arr: xarray DataArray with a 'time' dimension
    :param q_up: upper quantile threshold
    :param q_low: lower quantile threshold for drought to be identified
    :param q_end: threshold for drought to be broken
    :param dtime: maximum number of time steps between q_up and q_low
    :param index: 'EDDI' selects the inverted-comparison EDDI variant
    :return: xarray of flash-drought durations (nonzero at onset time steps)
    """
    # The two branches were identical apart from the core function:
    # choose it once and make a single apply_ufunc call.
    func = find_fd1D_mask_eddi if index == 'EDDI' else find_fd1D_mask
    return xr.apply_ufunc(func, arr, q_up, q_low, q_end, dtime,
                          input_core_dims=[['time'], [], [], [], []],
                          output_core_dims=[['time']],
                          vectorize=True,
                          dask="parallelized",
                          output_dtypes=[arr.dtype])
def main(**kwargs):
    """
    Detect flash droughts in an AWAP-based drought index and write the
    event counts to netCDF.

    Expected kwargs (all arrive as strings from the command line):
      index      -- name of the drought index (e.g. 'SPI', 'EDDI')
      dtime      -- max number of time steps for the drop (int)
      prctl_dr   -- drought percentile threshold (float, 0-1)
      prctl_norm -- "normal"/recovery percentile threshold (float, 0-1)

    Relies on the module-level `client` created in the __main__ block.
    """
    print(kwargs)
    index = str(kwargs['index'])
    dtime = int(kwargs['dtime'])
    prctl_dr = float(kwargs['prctl_dr'])      # was `kwrgs`: NameError
    prctl_norm = float(kwargs['prctl_norm'])  # was `kwrgs`: NameError

    ################# directories
    odir = '/g/data/w35/dh4185/data/AWAP/fd_count/highres/'
    ofile = '{}{}_awap_FDcount_prctl{}to{}in{}d.nc'.format(odir, index, prctl_norm, prctl_dr, dtime)
    # All EDDI variants live in the one 'EDDI' folder
    folder = 'EDDI' if index[:4] == 'EDDI' else index
    ifile = '/g/data/w35/dh4185/data/AWAP/{}/{}_awap_agg30d_1975-2018.nc'.format(folder, index)

    ################# get file and calculate percentiles
    arr = xr.open_dataset(ifile).sel(time=slice('2000', '2000'))  # lat=slice(-37,-33),lon=slice(135,145))
    arr = arr.chunk({'time': None, 'lat': 50, 'lon': 50})
    varname = '{}_30d'.format('EDDI' if index[:4] == 'EDDI' else index)
    arr = arr[varname]
    print(arr)

    print('Calculate quantiles...')
    q = arr.quantile([prctl_dr, prctl_norm], 'time')
    q10 = q.sel(quantile=prctl_dr)
    q40 = q.sel(quantile=prctl_norm)

    # get land/sea mask -> use to ignore ocean values
    # NOTE(review): coordinate names appear to differ per input product; confirm
    if index == 'SPI':
        ocean = regionmask.defined_regions.natural_earth.land_110.mask(arr, lon_name='lon', lat_name='lat')
    else:
        ocean = regionmask.defined_regions.natural_earth.land_110.mask(arr, lon_name='longitude', lat_name='latitude')
    lsmask = ~(np.isnan(ocean))

    ################# run function
    fd = detect_fd(arr, q_up=q40, q_low=q10, q_end=q40, dtime=dtime, index=index[:4])

    # Convert to dataset and build a delayed netCDF write.
    # The variable name and the encoding key must match: the original mixed
    # 'FDcount' and 'FD_count', which raised KeyError before writing anything.
    ds = fd.to_dataset(name='FDcount').where(lsmask)
    write = ds.to_netcdf(ofile, compute=False,
                         encoding={'FDcount': {'chunksizes': ds['FDcount'].data.chunksize,
                                               'shuffle': True}})
    # Add a progress bar and do the calculations
    write = client.persist(write)
    dask.distributed.progress(write)
    write.compute()
    print('')
if __name__ == '__main__':
    ################# set up client
    import dask.distributed
    import sys

    ## Get the number of CPUs in the job and start a dask.distributed cluster.
    # The original called climtas.nci.GadiClient() without ever importing
    # climtas (NameError at runtime); fall back to the plain distributed
    # client sketched in the original comments, which only needs modules
    # this file already imports.
    mem = 190
    cores = int(os.environ.get('PBS_NCPUS', '4'))
    memory_limit = '{}gb'.format(int(max(mem / cores, 4)))
    client = dask.distributed.Client(n_workers=cores, threads_per_worker=1,
                                     memory_limit=memory_limit,
                                     local_directory=tempfile.mkdtemp())

    #< Print client summary
    print('### Client summary')
    print(client)
    print('\n\n')

    #< Call the main function: CLI args are key=value pairs, e.g. index=SPI dtime=20
    main(**dict(arg.split('=') for arg in sys.argv[1:]))

    #< Close the client
    client.shutdown()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment