Skip to content

Instantly share code, notes, and snippets.

@ccarouge
Created May 20, 2019 04:40
Show Gist options
  • Save ccarouge/d0218b0ae3d218150f5fc8d99acf5fd3 to your computer and use it in GitHub Desktop.
Save ccarouge/d0218b0ae3d218150f5fc8d99acf5fd3 to your computer and use it in GitHub Desktop.
multi_correlation.py
def least_sq_correlation(da, test):
    """
    Calculate the least squares correlation against multiple
    locations simultaneously.

    Args:
        da (xarray.DataArray): Input data field. Time must be the first
            axis of 'da', and that axis must be the same length as 'test'.
        test: 1D timeseries to test all locations against.

    Returns:
        xarray.DataArray with dimensions ``('stats',) + da.dims[1:]``,
        backed by a dask array, holding at each location:
          - ``stats='rvalue'``: the Pearson correlation coefficient between
            the location's timeseries and 'test' (clipped to [-1, 1]);
          - ``stats='slope'``: ``cov(da, test) / var(da)``, i.e. the slope of
            the least-squares fit of 'test' against the location's series
            (same orientation as ``scipy.stats.linregress(x=da, y=test)``).
        Only dimension coordinates of 'da' (other than time) are carried over.

    Raises:
        ValueError: if 'test' is not 1D, or its length differs from the
            length of the first axis of 'da'.

    Implementation follows scipy.stats.linregress, modified to work on
    multiple series at once.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if test.ndim != 1:
        raise ValueError(f"'test' must be 1D, got {test.ndim} dimensions")
    if da.shape[0] != test.shape[0]:
        raise ValueError(
            f"First axis of 'da' (length {da.shape[0]}) must match the "
            f"length of 'test' ({test.shape[0]})"
        )

    shape = da.shape
    out_dims = da.dims[1:]

    # Block the non-time dims together: one column per location.
    x = dask.array.asarray(da.data).reshape((shape[0], -1))
    y = dask.array.asarray(test)
    # Promote to at least float64, as dask.array.cov / np.cov do.
    dtype = np.result_type(x.dtype, y.dtype, np.float64)
    x = x.astype(dtype)
    y = y.astype(dtype)

    # Only the per-location variances and their covariances with 'test' are
    # needed, so compute them directly from anomalies. Building the full
    # (N+1) x (N+1) covariance matrix would grow quadratically with the number
    # of locations N. Dividing by T via mean() matches bias=1.
    xm = x - x.mean(axis=0)
    ym = y - y.mean()
    ssxm = (xm * xm).mean(axis=0)
    ssym = (ym * ym).mean()
    ssxym = (xm * ym[:, None]).mean(axis=0)

    # Compute 'r' value and slope
    r = ssxym / np.sqrt(ssxm * ssym)
    # As in scipy.stats.linregress, stop rounding error pushing |r| past 1
    # (NaN from zero-variance series is passed through unchanged).
    r = r.clip(-1.0, 1.0)
    # NOTE(review): this is cov/var(da), i.e. 'test' regressed on each
    # location. If the regression of the field onto 'test' was intended,
    # it would be ssxym / ssym -- confirm with callers.
    slope = ssxym / ssxm

    # Return to original (non-time) shape
    r = r.reshape(shape[1:])
    slope = slope.reshape(shape[1:])

    # Add xarray metadata from the input dataarray: keep the dimension
    # coordinates of every dim except time.
    coords = {k: v for k, v in da.coords.items() if k in out_dims}
    da_out = xarray.DataArray(r, dims=out_dims, coords=coords)
    d_slope = xarray.DataArray(slope, dims=out_dims, coords=coords)

    # Concatenate r and slope in same DataArray. Add a stats coordinate and
    # concat along it.
    da_out.coords['stats'] = 'rvalue'
    d_slope.coords['stats'] = 'slope'
    return xarray.concat([da_out, d_slope], 'stats')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment