Created
May 17, 2019 05:13
-
-
Save ScottWales/aa04e4bf4b4c29f1c5ce048c59562ba4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask
import dask.array
import numpy as np
import xarray
def least_sq_correlation(da, test):
    """
    Calculate the least squares correlation against multiple
    locations simultaneously.

    Computes the correlation coefficient (the ``rvalue`` of
    ``scipy.stats.linregress``) between ``test`` and the timeseries at
    every location of ``da`` in a single vectorised pass.

    Args:
        da (xarray.DataArray): Input data field. Time must be the first
            axis, and that axis must be the same length as ``test``.
        test: 1D timeseries to test all locations against (numpy array,
            dask array, array-like, or a 1D xarray.DataArray).

    Returns:
        xarray.DataArray with correlation coefficients (rvalue) at each
        location. It has dims ``da.dims[1:]`` and carries the dimension
        coordinates of ``da`` for those dims. The data is a lazy dask
        array; call ``.compute()`` or ``.load()`` to evaluate it.
        Locations with a constant (zero-variance) or NaN-containing
        timeseries give NaN, as does a constant ``test``.

    Raises:
        ValueError: if ``test`` is not 1D, if ``da`` has no axes, or if
            the first axis of ``da`` differs in length from ``test``.

    Implementation follows scipy.stats.linregress, modified to work on
    multiple series. Real-valued input is assumed.
    """
    # Unwrap an xarray timeseries so the maths below is plain array maths
    # (DataArray does not support the positional broadcasting used later).
    if isinstance(test, xarray.DataArray):
        test = test.data
    y = dask.array.asarray(test)

    if y.ndim != 1:
        raise ValueError(f"'test' must be 1D, got {y.ndim} dimensions")
    if da.ndim == 0:
        raise ValueError("'da' must have a leading time axis")
    if da.shape[0] != y.shape[0]:
        raise ValueError(
            f"First axis of 'da' has length {da.shape[0]} but 'test' has "
            f"length {y.shape[0]}"
        )

    x = dask.array.asarray(da.data)

    # Promote to at least float64 (as np.cov / dask.array.cov do) so integer
    # or float32 input does not lose precision in the sums below.
    dtype = np.result_type(x.dtype, y.dtype, np.float64)
    x = x.astype(dtype)
    y = y.astype(dtype)

    # Shape 'test' as (time, 1, 1, ...) so it broadcasts along axis 0
    # against every location, without reshaping 'da' itself.
    y = y.reshape((y.shape[0],) + (1,) * (x.ndim - 1))

    # Anomalies about the time mean.
    xm = x - x.mean(axis=0)
    ym = y - y.mean()

    # Per-location variances and covariance with 'test'. This is
    # equivalent to the diagonal / last row of a joint covariance matrix,
    # but never forms the N x N location-by-location matrix that
    # dask.array.cov would build. The 1/N normalisation cancels in r.
    ssxm = (xm * xm).mean(axis=0)
    ssym = (ym * ym).mean()
    ssxym = (xm * ym).mean(axis=0)

    # Compute 'r' value
    r = ssxym / np.sqrt(ssxm * ssym)
    # Rounding can push |r| marginally past 1; clamp as linregress does.
    # NaN (zero variance or NaN input) passes through unchanged.
    r = r.clip(-1.0, 1.0)

    # Add xarray metadata from the input dataarray (dimension coords only)
    dims = da.dims[1:]
    coords = {k: v for k, v in da.coords.items() if k in dims}
    return xarray.DataArray(r, dims=dims, coords=coords)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks Scott for the help.