Skip to content

Instantly share code, notes, and snippets.

@lebedov
Created June 16, 2015 18:35
Show Gist options
  • Save lebedov/9fa8b5a02a0e764cd40c to your computer and use it in GitHub Desktop.
Save lebedov/9fa8b5a02a0e764cd40c to your computer and use it in GitHub Desktop.
Find permutation of matrix that maximizes its trace using the Munkres algorithm.
#!/usr/bin/env python
"""
Find permutation of matrix that maximizes its trace using the Munkres algorithm.
Reference
---------
https://stat.ethz.ch/pipermail/r-help/2010-April/236664.html
"""
import itertools
import sys
import munkres
import numpy as np
def permute_cols(a, inds):
    """
    Permute the columns of matrix `a`.

    Parameters
    ----------
    a : array_like
        2-D matrix whose columns are rearranged.  It does not have to be
        square.
    inds : iterable of (from, to) pairs
        Each pair places column `from` of `a` into column `to` of the
        result.

    Returns
    -------
    ndarray
        `np.dot(a, p)`, where `p` is zero everywhere except for
        `p[from, to] = 1` for every pair in `inds`.  A column of the result
        that no pair maps to is therefore all zeros.
    """
    a = np.asanyarray(a)
    n_cols = a.shape[1]

    # The permutation matrix acts on columns, so it must be
    # (n_cols x n_cols) regardless of how many rows `a` has.
    p = np.zeros((n_cols, n_cols), dtype=a.dtype)

    # Unpack each pair explicitly so that lists such as [from, to] are not
    # mistaken for fancy row indices.
    for src, dst in inds:
        p[src, dst] = 1
    return np.dot(a, p)
def maximize_trace(a):
    """
    Permute the columns of a square matrix so that its trace is maximal.

    A permutation matrix `p` is chosen to minimize the squared Frobenius
    norm of `np.dot(a, p) - np.eye(n)`.  That quantity equals
    `(a**2).sum() - 2*np.dot(a, p).trace() + n`, so minimizing it is the
    same as maximizing the trace.  The minimization is posed as an
    assignment problem and solved with the Munkres algorithm.

    Parameters
    ----------
    a : array_like
        Square 2-D matrix.

    Returns
    -------
    ndarray
        `np.dot(a, p)`, i.e. `a` with its columns permuted.

    Raises
    ------
    AssertionError
        If `a` is not a square 2-D matrix.
    """
    a = np.asarray(a)
    assert a.ndim == 2 and a.shape[0] == a.shape[1], \
        'a must be a square 2-D matrix'
    n = a.shape[0]
    b = np.eye(n, dtype=int)

    # d[j, i] = ||e_j - a[i, :]||**2: the cost of moving column j of `a`
    # into column i.  Each row of `a` yields one column of `d`.  The dtype
    # comes from the arithmetic itself rather than from `a`, so narrow
    # integer inputs cannot overflow or truncate the costs.
    d = np.array([((b - row) ** 2).sum(axis=1) for row in a]).T

    # Munkres is documented for lists of lists; handing it a plain copy
    # also keeps it from modifying the array in place.
    m = munkres.Munkres()
    inds = m.compute(d.tolist())
    return permute_cols(a, inds)
if __name__ == '__main__':
    # Demo: build a random 6x6 matrix of integers drawn from [0, 10) and
    # show the matrix and its trace before and after permutation.
    # Single-argument print(...) calls behave the same on Python 2 and 3.
    n = 6
    a = np.random.randint(0, 10, n**2).reshape(n, n)
    print('original: ')
    print(a)
    print('trace: %d' % a.trace())
    ap = maximize_trace(a)
    print('permuted: ')
    print(ap)
    print('trace: %d' % ap.trace())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment