Skip to content

Instantly share code, notes, and snippets.

@dojeda
Created June 8, 2020 09:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dojeda/ad577aeab9e0111ce08aa663392c5359 to your computer and use it in GitHub Desktop.
Save dojeda/ad577aeab9e0111ce08aa663392c5359 to your computer and use it in GitHub Desktop.
Comparison of distance functions on covariance matrices
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from pyriemann.utils.distance import distance_riemann
@partial(np.vectorize, excluded=['ref', 2], otypes=[float])
def distance_riemann_v(c_xx, c_yy, ref):
    """Vectorized Riemannian distance from diagonal 2x2 covariances to `ref`.

    For every pair of variances ``(c_xx, c_yy)`` the diagonal matrix
    ``[[c_xx, 0], [0, c_yy]]`` is built and its Riemannian distance to the
    reference matrix `ref` is computed with ``distance_riemann``. The
    covariance (off-diagonal) term of the candidate matrix is always zero.

    `ref` is excluded from vectorization both as a keyword and as the third
    positional argument, so it is always handed over as a whole matrix
    instead of being broadcast element by element against the variances.
    The output dtype is fixed to float, which lets empty inputs through and
    avoids the extra probing call ``np.vectorize`` would otherwise make.

    Parameters
    ----------
    c_xx: array_like
        Variances of the x channel. Scalars or arrays of any shape that
        broadcasts against `c_yy` (e.g. a meshgrid) are accepted.
    c_yy: array_like
        Variances of the y channel. Scalars or arrays of any shape that
        broadcasts against `c_xx` are accepted.
    ref: np.array
        Reference covariance matrix. This must be a matrix of size (2, 2).

    Returns
    -------
    np.ndarray of float
        Riemannian distance between ``[[c_xx, 0], [0, c_yy]]`` and `ref`
        for each element of the broadcast inputs.
    """
    # Candidate covariance matrix with the covariance term forced to zero.
    c = np.diag([c_xx, c_yy])
    return distance_riemann(c, ref)
@partial(np.vectorize, excluded=['ref', 2], otypes=[float])
def distance_euclid_v(c_xx, c_yy, ref):
    """Vectorized Euclidean distance between 2D covariance matrices.

    Calculates the Euclidean distance between the variance terms
    ``(c_xx, c_yy)`` and the variance terms of the reference matrix `ref`,
    i.e. ``sqrt((c_xx - ref[0, 0])**2 + (c_yy - ref[1, 1])**2)``. Note that
    this function ignores the covariance term: only the diagonal of `ref`
    is read.

    `ref` is excluded from vectorization both as a keyword and as the third
    positional argument, so it is always handed over as a whole matrix
    instead of being broadcast element by element against the variances.
    The output dtype is fixed to float, which lets empty inputs through and
    avoids the extra probing call ``np.vectorize`` would otherwise make.

    Parameters
    ----------
    c_xx: array_like
        Variances of the x channel. Scalars or arrays of any shape that
        broadcasts against `c_yy` (e.g. a meshgrid) are accepted.
    c_yy: array_like
        Variances of the y channel. Scalars or arrays of any shape that
        broadcasts against `c_xx` are accepted.
    ref: np.array
        Reference covariance matrix. This must be a matrix of size (2, 2).

    Returns
    -------
    np.ndarray of float
        Euclidean distance between the diagonals of
        ``[[c_xx, 0], [0, c_yy]]`` and `ref` for each element of the
        broadcast inputs.
    """
    return np.linalg.norm([
        c_xx - ref[0, 0],
        c_yy - ref[1, 1],
    ])
def main():
    """Plot Riemannian and Euclidean distance maps side by side.

    Both distances are evaluated on a 50 x 50 grid of (Cxx, Cyy) variances
    against a fixed 2x2 reference matrix. Each map is drawn as a filled
    contour plot with its own colorbar, and the rejection distance is
    overlaid as a single green contour line. The figure is then shown.
    """
    reference = np.array([
        [+3.1, -0.2],
        [-0.2, +5.4],
    ])
    rejection_distance = 1.5

    # The grid starts slightly above zero rather than at zero
    # (presumably to keep the candidate matrices non-singular).
    lower, upper = 0.01, 30
    variances = np.linspace(lower, upper, 50)
    X, Y = np.meshgrid(variances, variances)

    # One entry per subplot: (title, colorbar label, distance map).
    panels = (
        (
            'Riemann',
            'Riemannian distance to reference',
            distance_riemann_v(X, Y, ref=reference),
        ),
        (
            'Euclid',
            'Euclidean distance to reference',
            distance_euclid_v(X, Y, ref=reference),
        ),
    )

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    for ax, (title, bar_label, distances) in zip(axs, panels):
        filled = ax.contourf(
            X, Y, distances,
            levels=20, vmin=0, vmax=distances.max(), cmap='RdYlBu_r',
        )
        # Single contour line marking the rejection distance.
        ax.contour(X, Y, distances, levels=[rejection_distance], colors='g')
        colorbar = fig.colorbar(filled, ax=ax)
        colorbar.ax.set_ylabel(bar_label)
        ax.set_xlabel('Cxx')
        ax.set_ylabel('Cyy')
        ax.set_title(title)

    plt.show()
# Only build and show the figure when run as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment