Skip to content

Instantly share code, notes, and snippets.

@ahwillia
Last active September 4, 2023 17:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahwillia/0b9c761ba64811eaefab11a88caeaad0 to your computer and use it in GitHub Desktop.
Save ahwillia/0b9c761ba64811eaefab11a88caeaad0 to your computer and use it in GitHub Desktop.
L2 Metric Repair
import numba
import numpy as np
from scipy.spatial.distance import pdist, squareform
from math import comb
@numba.jit(nopython=True)
def index(n, i, j):
    """
    Map the pair (i, j) of an (n x n) distance matrix to its position in
    the linear (row-major, strict upper-triangle) vector of distances.

    The result is the same for (i, j) and (j, i). The formula is meant
    for off-diagonal pairs (i != j).
    """
    # Order the pair so that `row` is the smaller of the two indices.
    if j > i:
        row, col = i, j
    else:
        row, col = j, i
    # Rows 0..row-1 hold (n-1) + (n-2) + ... + (n-row) entries in total.
    row_start = row * ((n - 1) + (n - row)) // 2
    return (col - row) + row_start - 1
@numba.jit(nopython=True)
def alg31(n, d, e, z):
    """
    Run one sweep of Algorithm 3.1 from Brickell et al. (2008) over every
    triangle (i, k, j) with i < j and k distinct from both.

    `d` and `e` are addressed through `index(n, ., .)`; `z` holds one
    entry per visited triangle, in visiting order. Both `e` and `z` are
    updated in place. The return value is the sum of the absolute values
    of all updates applied during the sweep.

    Brickell, J., Dhillon, I. S., Sra, S., & Tropp, J. A. (2008).
    The metric nearness problem. SIAM Journal on Matrix Analysis and
    Applications, 30(1), 375-396.
    """
    tri = 0  # position in z of the triangle currently being processed
    total_change = 0

    # Iterate over elements (i, j) of the distance matrix.
    for i in range(n):
        for j in range(i + 1, n):
            ij = index(n, i, j)

            # Visit every third vertex k that differs from i and j.
            for k in range(n):
                if (k == i) or (k == j):
                    continue

                ik = index(n, i, k)
                kj = index(n, k, j)

                # Triangle slack in the input distances and the
                # unconstrained step derived from it.
                slack = d[ik] + d[kj] - d[ij]
                unclipped = (e[ij] - e[ik] - e[kj] - slack) / 3
                # Clip so that z[tri] + step never drops below zero.
                step = max(unclipped, -z[tri])

                # Keep a running total of the update magnitudes.
                total_change += abs(step)

                # Apply the update to the three edges and to z.
                e[ij] -= step
                e[ik] += step
                e[kj] += step
                z[tri] += step
                tri += 1

    return total_change
@numba.jit(nopython=True)
def max_violation(D):
    """
    Return the smallest value of D[i, k] + D[k, j] - D[i, j] over all
    i < j and all k in an (n x n) distance matrix. A negative result
    means at least one triangle inequality is violated.

    Note that k is not restricted to differ from i and j, so the terms
    D[i, i] and D[j, j] also take part in the minimum. If there is no
    pair i < j (n < 2) the result is +inf.
    """
    size = D.shape[0]
    worst = np.inf
    for i in range(size):
        for j in range(i + 1, size):
            direct = D[i, j]
            for k in range(size):
                # Slack of the path i -> k -> j relative to the direct edge.
                worst = min(worst, D[i, k] + D[k, j] - direct)
    return worst
def metric_repair(D, num_iters=10, verbose=True):
    """
    Adjust a symmetric dissimilarity matrix towards one that satisfies
    the triangle inequality by running repeated sweeps of `alg31`.

    Parameters
    ----------
    D : array_like, shape (n, n)
        Pairwise dissimilarities. Only the strict upper triangle is
        read; the matrix is treated as symmetric. Non-floating inputs
        (e.g. integer arrays) are promoted to float64 so that fractional
        corrections are not truncated.
    num_iters : int
        Number of full sweeps over all triangles.
    verbose : bool
        If true, print the total update magnitude after every sweep.

    Returns
    -------
    D_clean : ndarray, shape (n, n)
        Symmetric matrix with zero diagonal holding the input distances
        plus the accumulated corrections.

    Raises
    ------
    ValueError
        If `D` is not a square two-dimensional array.
    """
    D = np.asarray(D)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        # The compiled kernel indexes `d` assuming n * (n - 1) / 2 entries;
        # reject anything else before it gets there.
        raise ValueError(
            f"D must be a square (n x n) matrix, got shape {D.shape}"
        )
    if not np.issubdtype(D.dtype, np.floating):
        # zeros_like(d) below inherits d's dtype; with an integer dtype the
        # fractional updates written into `e` would be silently truncated.
        D = D.astype(np.float64)

    n = D.shape[0]
    idx = np.triu_indices_from(D, 1)
    d = D[idx]  # condensed copy of the strict upper triangle
    e = np.zeros_like(d)  # additive corrections, one per pair
    # One entry per (i < j, k) triple visited by alg31:
    # C(n, 2) * (n - 2) == 3 * C(n, 3).
    z = np.zeros(3 * comb(n, 3))

    for _ in range(num_iters):
        tt = alg31(n, d, e, z)
        if verbose:
            print("Param update:", tt)

    # Rebuild the full symmetric matrix from the corrected upper triangle.
    D_clean = np.zeros_like(D)
    D_clean[idx] = d + e
    D_clean += D_clean.T
    return D_clean
if __name__ == "__main__":
    # Demo 1: Euclidean distances between random points.
    print("Test on metric data (no repair needed...)")
    X = np.random.randn(6, 3)
    D = squareform(pdist(X, 'euclidean'))
    D_clean = metric_repair(D, num_iters=3, verbose=True)
    print("FINAL VIOLATION:", max_violation(D_clean))

    # Demo 2: squared Euclidean distances between the same points.
    print("\n\n\nTest on non-metric data...)")
    D = squareform(pdist(X, 'sqeuclidean'))
    D_clean = metric_repair(D, num_iters=20, verbose=True)
    # The violation is computed once here; a second call whose result was
    # discarded has been dropped.
    print("FINAL VIOLATION:", max_violation(D_clean))
@ahwillia
Copy link
Author

ahwillia commented Sep 4, 2023

Code Output:

Test on metric data (no repair needed...)
Param update: 0.0
Param update: 0.0
Param update: 0.0
FINAL VIOLATION: 0.0



Test on non-metric data...)
Param update: 8.113355037136897
Param update: 2.5888001414176234
Param update: 0.9454095000947543
Param update: 0.08793189750970375
Param update: 0.013513393376006747
Param update: 0.0029870137438086823
Param update: 0.0011071825408025353
Param update: 0.0003461925209929384
Param update: 0.00012207343859651132
Param update: 4.093589726121098e-05
Param update: 1.399067995592566e-05
Param update: 4.751838014220115e-06
Param update: 1.6164266368560532e-06
Param update: 5.497768184037355e-07
Param update: 1.869639359626992e-07
Param update: 6.358938292594374e-08
Param update: 2.1626080666988418e-08
Param update: 7.35509771606265e-09
Param update: 2.501444199924189e-09
Param update: 8.507408481387795e-10
FINAL VIOLATION: -1.73033143369139e-10

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment