Skip to content

Instantly share code, notes, and snippets.

@kuchaale
Last active April 24, 2019 12:50
Show Gist options
  • Save kuchaale/7c279afb46e52ed87f349b8e4177126f to your computer and use it in GitHub Desktop.
Save kuchaale/7c279afb46e52ed87f349b8e4177126f to your computer and use it in GitHub Desktop.
Calculates trends across variables
import xarray as xr
import dask.array as da
from dask.delayed import delayed
import numpy as np
from scipy import stats
# regression function definition
def regression(y):
    """Run the linear regression along the ``time`` dimension of *y*.

    The position of the ``time`` axis is looked up through
    ``y.get_axis_num`` (so *y* is expected to expose that method, as an
    xarray.DataArray does), and ``_calc_slope`` is then applied to every
    1-D slice along that axis with ``dask.array.apply_along_axis``.

    Returns whatever ``dask.array.apply_along_axis`` returns for that call.
    """
    time_axis = y.get_axis_num('time')
    return da.apply_along_axis(_calc_slope, time_axis, y)
def _calc_slope(y):
    """Fit a straight line to the 1-D series *y* and return the fit result.

    The independent variable is the sample index ``0 .. len(y) - 1``; the
    value returned is the unmodified result of ``scipy.stats.linregress``.
    """
    sample_index = np.arange(len(y))
    fit = stats.linregress(sample_index, y)
    return fit
# first and last year of each of the two analysed periods
syear1, eyear1 = 1980, 1997
syear2, eyear2 = 1998, 2015

# names of the dataset variables to analyse
var_ls = ['O3', 'T', 'H']

# open the input file as an xarray.Dataset, chunked along 'lat' (size 75)
ifile = '/mnt/1data/trendy/prog/tmp/00AA.l187.nc'  # input file name
ds = xr.open_dataset(ifile, chunks={'lat': 75})

# time slices (built from the year strings) and the matching sub-datasets
per1 = slice(str(syear1), str(eyear1))
per2 = slice(str(syear2), str(eyear2))
data_per1 = ds.sel(time=per1)
data_per2 = ds.sel(time=per2)

# regression analysis: one persisted delayed call per variable and period
delayed_objs = [
    delayed(regression)(delayed(data_per1[var])).persist()
    for var in var_ls
]
delayed_objs2 = [
    delayed(regression)(delayed(data_per2[var])).persist()
    for var in var_ls
]

# computing the delayed objects yields what regression() returned for each
results_per1 = da.compute(*delayed_objs)
results_per2 = da.compute(*delayed_objs2)

# names of the statistical variables along the 'stats' dimension
variables = ['slope', 'intercept', 'r_value', 'p_value', 'std_err']

# coordinates of the output dataset; periods are labelled "<start>-<stop>"
period_labels = ['{sl.start}-{sl.stop}'.format(sl=per) for per in (per1, per2)]
coords = {
    'period': period_labels,
    'stats': variables,
    'lev': ds.lev,
    'lat': ds.lat,
    'lon': ds.lon,
}

# output xarray.Dataset: for every analysed variable, the results of the two
# periods are stacked along a new leading 'period' axis
out_dims = ['period', 'stats', 'lev', 'lat', 'lon']
trend_vars = {
    '{}_trend_stats'.format(name): (out_dims, da.stack([res1, res2]))
    for name, res1, res2 in zip(var_ls, results_per1, results_per2)
}
ds_out = xr.Dataset(trend_vars, coords=coords)

# write the output dataset to a NetCDF file
out_file = 'test_t4p_delayed3.nc'  # output file name
ds_out.to_netcdf(out_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment