@ScottWales
Created May 17, 2019 05:13
import dask.array
import numpy as np
import xarray


def least_sq_correlation(da, test):
    """
    Calculate the least squares correlation against multiple
    locations simultaneously

    Args:
        da (xarray.DataArray): Input data field
        test: 1D timeseries to test all locations against

    Returns:
        xarray.DataArray with correlation coefficients (rvalue) at
        each location

    Time should be the first axis of 'da', and this axis should be
    the same size as the 'test' timeseries

    Implementation is copied from scipy.stats.linregress, modified
    to work on multiple series
    """
    assert test.ndim == 1
    assert da.shape[0] == test.shape[0]

    shape = da.shape

    # Block other dims together, giving a (time, location) array
    arr = da.data.reshape((shape[0], -1))

    # Get covariance matrix; each location is a row, 'test' is appended last
    cov = dask.array.cov(arr.T, test, bias=1)

    # Get variances and covariance with 'test'
    ssxm = dask.array.diagonal(cov)[0:-1]  # variance at each location
    ssym = cov[-1, -1]                     # variance of 'test'
    ssxym = cov[-1, 0:-1]                  # covariance of each location with 'test'

    # Compute 'r' value
    r_num = ssxym
    r_den = np.sqrt(ssxm * ssym)
    r = r_num / r_den

    # Return to original shape
    r = r.reshape(shape[1:])

    # Add xarray metadata from the input dataarray
    da_out = xarray.DataArray(
        r,
        dims=da.dims[1:],
        coords={k: v for k, v in da.coords.items() if k in da.dims[1:]},
    )
    return da_out
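
The covariance arithmetic above can be checked on a small in-memory array. What follows is a minimal NumPy-only sketch, not part of the original gist: `correlation_reference` is a hypothetical helper that mirrors the same steps without dask or xarray, and the random field and timeseries are made-up test data. It compares the result against `np.corrcoef` computed one location at a time.

```python
import numpy as np


def correlation_reference(field, test):
    """Pearson r of `test` against every location of `field` (time on axis 0).

    Hypothetical NumPy-only mirror of the steps used in the gist.
    """
    # Block the non-time dimensions together: (time, location)
    flat = field.reshape(field.shape[0], -1)
    # np.cov treats rows as variables, so pass locations as rows;
    # `test` is appended as the last row/column of the result
    cov = np.cov(flat.T, test, bias=True)
    ssxm = np.diagonal(cov)[:-1]   # variance at each location
    ssym = cov[-1, -1]             # variance of `test`
    ssxym = cov[-1, :-1]           # covariance of each location with `test`
    r = ssxym / np.sqrt(ssxm * ssym)
    return r.reshape(field.shape[1:])


# Made-up data: 50 time steps on a 3 x 4 grid
rng = np.random.default_rng(0)
field = rng.normal(size=(50, 3, 4))
test = rng.normal(size=50)
field[:, 0, 0] = 2.0 * test + 1.0   # perfectly correlated location
field[:, 0, 1] = -test              # perfectly anti-correlated location

r = correlation_reference(field, test)

# Reference: correlate each location separately
expected = np.array(
    [np.corrcoef(field[:, i, j], test)[0, 1] for i in range(3) for j in range(4)]
).reshape(3, 4)

print(np.allclose(r, expected))
```

One design point to keep in mind: this approach builds the full covariance matrix between all locations, which has (N + 1) x (N + 1) entries for N locations, even though only its diagonal and last row are used. That may become expensive on large grids.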
@shweta121sharma

Thanks Scott for the help.
