Skip to content

Instantly share code, notes, and snippets.

@satojkovic
Created February 18, 2013 18:16
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save satojkovic/4979391 to your computer and use it in GitHub Desktop.
Save satojkovic/4979391 to your computer and use it in GitHub Desktop.
Earth Mover's Distance
#-*- encoding: utf-8 -*-
#
# Earth Mover's Distance
#
# Reference: http://aidiary.hatenablog.com/entry/20120804/1344058475
#
import numpy as np
import rpy2.robjects as robjects
# Load the R package lpSolve through rpy2 and bind its lp.transport
# function (the transportation-problem solver used by
# earth_movers_distance) to a Python name.
# NOTE: this runs at import time; presumably it requires a working R
# installation with the lpSolve package available -- confirm for your setup.
robjects.r['library']('lpSolve')
transport = robjects.r['lp.transport']
def euclid_dist(f_p, f_q):
    """
    Return the Euclidean distance between two feature vectors.

    f_p, f_q -- sequences (lists, tuples or numpy arrays) of equal length.

    Returns the distance as a float.  If the lengths differ, an error
    message is printed and -1 is returned (kept for backward compatibility
    with callers that check for the sentinel).
    """
    if len(f_p) != len(f_q):
        # Parenthesised single-argument print works on Python 2 and 3 alike.
        print("Error: calc euclid_dist %d <=> %d" % (len(f_p), len(f_q)))
        return -1
    # Convert to float before subtracting: unsigned integer inputs (e.g.
    # uint8 colour features) would otherwise wrap around, and plain lists
    # do not support element-wise subtraction at all.
    diff = np.asarray(f_p, dtype=float) - np.asarray(f_q, dtype=float)
    return np.sqrt(np.sum(diff ** 2))
def earth_movers_distance(dist, w_p, w_q):
    """
    Earth mover's distance computed with R's lpSolve::lp.transport via rpy2.

    dist -- flat, row-major sequence of len(w_p) * len(w_q) ground
            distances; entry i * len(w_q) + j is the cost of moving mass
            from source i to sink j.
    w_p  -- weights of the source signature (row constraints).
    w_q  -- weights of the sink signature (column constraints).

    Returns the total work (sum of flow * distance) divided by the total
    flow of the optimal transport plan.

    Raises ValueError (from reshape) if dist does not contain exactly
    len(w_p) * len(w_q) entries.
    """
    n_p = len(w_p)
    n_q = len(w_q)

    # Build the cost matrix up front so that a mis-sized or non-ndarray
    # `dist` is rejected before the R solver is invoked.
    cost_matrix = np.asarray(dist, dtype=float).reshape(n_p, n_q)

    # distance vector to R distance matrix (filled row by row)
    costs = robjects.r['matrix'](
        robjects.FloatVector(cost_matrix.ravel().tolist()),
        nrow=n_p, ncol=n_q, byrow=True)

    # lp.transport documents its sign arguments as character vectors, so
    # pass explicit StrVectors rather than relying on list conversion.
    # Each source ships at most its weight; each sink receives at least
    # its weight.
    row_signs = robjects.StrVector(["<"] * n_p)
    row_rhs = robjects.FloatVector(w_p)
    col_signs = robjects.StrVector([">"] * n_q)
    col_rhs = robjects.FloatVector(w_q)

    t = transport(costs, "min", row_signs, row_rhs, col_signs, col_rhs)

    # 'solution' holds the optimal flow for every (source, sink) pair.
    flow = np.array(t.rx2('solution'))

    work = np.sum(flow * cost_matrix)
    emd = work / np.sum(flow)
    return emd
def set_test_data():
    """
    Build the fixed sample data used by main().

    Returns a 4-tuple (f_p, f_q, w_p, w_q) of numpy arrays: the feature
    rows of the two signatures followed by their per-row weights.
    """
    # features: one 3-component row per cluster
    p_features = [
        [100, 40, 22],
        [211, 20, 2],
        [32, 190, 150],
        [2, 100, 100],
    ]
    q_features = [
        [0, 0, 0],
        [50, 100, 80],
        [255, 255, 255],
    ]

    # weights: one entry per feature row
    p_weights = [4, 3, 2, 1]
    q_weights = [5, 3, 2]

    return (
        np.array(p_features),
        np.array(q_features),
        np.array(p_weights),
        np.array(q_weights),
    )
def main():
    """Compute and print the earth mover's distance for the sample data."""
    # test data
    f_p, f_q, w_p, w_q = set_test_data()

    # distance vector: flat, row-major, so the entry for the pair
    # (f_p[i], f_q[j]) lands at index i * len(f_q) + j
    dist = np.array(
        [euclid_dist(feat_p, feat_q) for feat_p in f_p for feat_q in f_q],
        dtype=float)

    # EMD
    emd = earth_movers_distance(dist, w_p, w_q)
    # Parenthesised single-argument print works on Python 2 and 3 alike.
    print("Earth Mover's Distance : %f" % emd)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment