Skip to content

Instantly share code, notes, and snippets.

@ahwillia
Last active June 26, 2021 22:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahwillia/efd61b63acdf6ead38a361a18fad27c8 to your computer and use it in GitHub Desktop.
Save ahwillia/efd61b63acdf6ead38a361a18fad27c8 to your computer and use it in GitHub Desktop.
Greedy heuristic for finding K column permutations that align a set of K matrices to a common barycenter
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.utils import check_random_state
import scipy.sparse
def perm_alignment(X, Y):
    """
    Find the column permutation of X that best matches Y.

    Solves ``min_P ||X @ P - Y||_F`` over n x n permutation matrices P.
    Because a permutation preserves norms,
    ``||X P - Y||^2 = ||X||^2 + ||Y||^2 - 2 * trace(P.T @ X.T @ Y)``,
    so the minimizer is the maximum-weight assignment on the n x n
    inner-product matrix ``X.T @ Y``.

    Parameters
    ----------
    X : ndarray, m x n matrix
    Y : ndarray, m x n matrix

    Returns
    -------
    P : scipy.sparse.csr_matrix, n x n permutation matrix
        ``P[r, c] == 1`` means column r of X becomes column c of ``X @ P``.
    cost : float
        ``-trace(P.T @ X.T @ Y)`` at the optimum. This equals the squared
        residual only up to the constant ``||X||^2 + ||Y||^2`` and a
        factor of 2; it is not the residual norm itself.
    """
    # Pairwise inner products between columns of X (rows) and of Y (cols).
    scores = X.T @ Y
    rows, cols = linear_sum_assignment(scores, maximize=True)
    matched = scores[rows, cols]

    size = rows.size
    entries = np.ones(size)
    perm = scipy.sparse.csr_matrix((entries, (rows, cols)), shape=(size, size))
    return perm, -matched.sum()
def _alignment_cost(Xs, perms, Xbar):
    """Return the un-normalized cost sum_i ||Xs[i] @ perms[i] - Xbar||_F^2."""
    return sum(
        np.linalg.norm(X @ P - Xbar) ** 2 for X, P in zip(Xs, perms)
    )


def multi_perm_alignment(
    Xs, tol=1e-6, miniter=20, maxiter=100, random_state=None
):
    """
    Greedy solver for the multi-set permutation alignment problem.

    Alternates two steps to (locally) minimize
    ``sum_i ||Xs[i] @ perms[i] - Xbar||_F^2``:
    every permutation is re-fit to the current barycenter ``Xbar`` with
    ``perm_alignment``, then ``Xbar`` is reset to the mean of the aligned
    matrices. Neither step can increase the objective, so the recorded cost
    is non-increasing up to floating-point rounding.

    Parameters
    ----------
    Xs : list of ndarray, m x n matrices (all must share one shape)
    tol : float, convergence tolerance; iteration stops once the normalized
        cost decreases by less than ``tol`` between consecutive iterations.
    miniter : int, number of iterations completed before convergence is
        first tested (values below 1 behave like 1).
    maxiter : int, maximum number of iterations.
    random_state : None, int or RandomState (anything accepted by
        sklearn's ``check_random_state``); selects which matrix of ``Xs``
        initializes the barycenter.

    Returns
    -------
    perms : list of scipy.sparse matrices, n x n permutation matrices;
        ``Xs[i] @ perms[i]`` is the i-th aligned matrix. These are identity
        matrices if ``maxiter == 0``.
    cost_hist : list of float, normalized cost (the sum of squared residuals
        divided by ``sqrt(m * n) * len(Xs)``) before the first iteration and
        after each iteration.

    Raises
    ------
    ValueError
        If ``Xs`` is empty or its matrices do not all share one shape.
    """
    # len() rather than truthiness so a stacked 3-D ndarray is also accepted.
    if len(Xs) == 0:
        raise ValueError("Xs must contain at least one matrix.")
    nx = len(Xs)
    m, n = Xs[0].shape
    for i, X in enumerate(Xs):
        if X.shape != (m, n):
            raise ValueError(
                f"Xs[{i}] has shape {X.shape}, expected {(m, n)}."
            )

    rs = check_random_state(random_state)

    # Start from identity permutations and one randomly chosen input matrix.
    perms = [scipy.sparse.identity(n) for _ in range(nx)]
    Xbar = np.copy(Xs[rs.randint(nx)])

    scale = np.sqrt(m * n) * nx
    # The initial entry means cost_hist[-2] exists after the first iteration.
    cost_hist = [_alignment_cost(Xs, perms, Xbar) / scale]

    for itercount in range(maxiter):
        # Re-fit every permutation against the same (previous) barycenter.
        perms = [perm_alignment(X, Xbar)[0] for X in Xs]

        # Barycenter = mean of the aligned matrices.
        Xbar = np.zeros((m, n))
        for X, P in zip(Xs, perms):
            Xbar += X @ P
        Xbar /= nx

        cost_hist.append(_alignment_cost(Xs, perms, Xbar) / scale)

        # Test only after `miniter` iterations have actually completed.
        if itercount + 1 >= miniter and cost_hist[-2] - cost_hist[-1] < tol:
            break

    return perms, cost_hist
# DEMO: recover the column shuffles of many noisy copies of one matrix.
if __name__ == "__main__":
    n_rows, n_cols, n_mats = 15, 10, 50
    sigma = 1.0
    rng = np.random.RandomState(1234)

    # Ground truth: one shared template whose columns are independently
    # shuffled for each observation, plus isotropic Gaussian noise.
    template = rng.randn(n_rows, n_cols)
    shuffles = [rng.permutation(n_cols) for _ in range(n_mats)]
    observed = [
        template[:, order] + sigma * rng.randn(n_rows, n_cols)
        for order in shuffles
    ]

    fitted_perms, history = multi_perm_alignment(observed)

    # matplotlib is imported here so the module itself does not require it.
    import matplotlib.pyplot as plt

    plt.plot(history)
    plt.xlabel("iterations")
    plt.ylabel("cost")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment