@clane9
Last active November 8, 2022 19:41
Correlation ratio cost function
```python
import numpy as np


def corr_ratio(x: np.ndarray, y: np.ndarray, bins: int = 256) -> float:
    """
    FLIRT correlation ratio cost function between `x` and `y`. Measures the variance
    of `y` within each iso-set of `x` (the elements whose `x` values fall in the same
    histogram bin), weighted by iso-set size and normalized by the total variance of
    `y`. The cost lies in [0, 1]; lower values mean `y` is better predicted by `x`.

    See [Jenkinson, NeuroImage 2002](https://doi.org/10.1006/nimg.2002.1132),
    Table 1 for the definition. Also see [here](https://www.fmrib.ox.ac.uk/datasets/techrep/tr02mj1/tr02mj1/node4.html).

    Note: `x` and `y` should have many more elements than there are bins
    (e.g. roughly 10x or more), so that the iso-sets of `x` can be estimated well.
    """
    assert x.shape == y.shape, "x and y expected to have same shape"
    _, edges = np.histogram(x, bins=bins)
    cost = 0.0
    for ii in range(bins):
        left, right = edges[ii : ii + 2]
        if ii < bins - 1:
            mask = (x >= left) & (x < right)
        else:
            # np.histogram's last bin is closed on the right; include max(x).
            mask = (x >= left) & (x <= right)
        iso_count = mask.sum()
        # Iso-sets with 0 or 1 elements have zero variance and add nothing.
        if iso_count > 1:
            cost = cost + iso_count * np.var(y[mask])
    # Normalize by N * Var(y), where N counts every element (singletons included).
    cost = cost / (np.var(y) * x.size)
    return float(cost)
```
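The loop above builds one boolean mask per bin, so it makes `bins` full passes over the data. As a sketch only (the name `corr_ratio_vectorized` is hypothetical and not part of this gist), the same quantity, sum_k n_k Var(y_k) / (N Var(y)), can be computed in a single pass with per-bin sums from `np.bincount`. This assumes the iso-sets are the `np.histogram` bins, with the last bin closed on the right, and uses n_k Var(y_k) = sum(y_k^2) - sum(y_k)^2 / n_k.

```python
import numpy as np


def corr_ratio_vectorized(x: np.ndarray, y: np.ndarray, bins: int = 256) -> float:
    """Hypothetical vectorized variant of `corr_ratio` using per-bin sums."""
    assert x.shape == y.shape, "x and y expected to have same shape"
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    # Variance is shift-invariant; centering y reduces cancellation error below.
    y = y - y.mean()
    _, edges = np.histogram(x, bins=bins)
    # Bin index per element: edges[i] <= x < edges[i + 1], last bin closed (as np.histogram).
    idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, bins - 1)
    n = np.bincount(idx, minlength=bins)
    s1 = np.bincount(idx, weights=y, minlength=bins)
    s2 = np.bincount(idx, weights=y * y, minlength=bins)
    nonempty = n > 0
    # n_k * Var(y_k) for each non-empty iso-set, clipped against tiny negative round-off.
    within = np.maximum(s2[nonempty] - s1[nonempty] ** 2 / n[nonempty], 0.0)
    return float(within.sum() / (y.size * np.var(y)))
```

For example, with `x` drawn uniformly from [0, 1], `corr_ratio_vectorized(x, x ** 2)` should be close to 0 because `y` is almost constant within each bin, while independent noise for `y` should give a value close to 1. As with the loop version, a constant `y` has zero variance and the cost is undefined.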