Created
February 18, 2013 18:16
-
-
Save satojkovic/4979391 to your computer and use it in GitHub Desktop.
Earth Mover's Distance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#-*- encoding: utf-8 -*- | |
# | |
# Earth Mover's Distance | |
# | |
# Reference: http://aidiary.hatenablog.com/entry/20120804/1344058475 | |
# | |
import numpy as np | |
import rpy2.robjects as robjects | |
# Load the R package lpSolve through rpy2 and keep a handle to its
# lp.transport function, which earth_movers_distance() calls as the solver.
robjects.r['library']('lpSolve')
transport = robjects.r['lp.transport']
def euclid_dist(f_p, f_q):
    """
    Return the Euclidean distance between two feature vectors.

    Parameters:
        f_p, f_q: sequences (lists, tuples or NumPy arrays) of equal length.

    Returns:
        The distance as a float, or -1 (after printing an error message)
        when the two lengths differ.
    """
    if len(f_p) != len(f_q):
        # A single parenthesised argument prints identically on Python 2 and 3.
        print("Error: calc euclid_dist %d <=> %d" % (len(f_p), len(f_q)))
        return -1
    # Convert to float so plain sequences are accepted and unsigned integer
    # inputs (e.g. uint8 colour values) cannot wrap around on subtraction.
    diff = np.asarray(f_p, dtype=float) - np.asarray(f_q, dtype=float)
    return np.sqrt(np.sum(diff ** 2))
def earth_movers_distance(dist, w_p, w_q):
    """
    Compute the Earth Mover's Distance with R's lpSolve::lp.transport (via rpy2).

    Parameters:
        dist: ground distances between every (p, q) pair, given either as a
            flat row-major sequence of len(w_p) * len(w_q) values or as an
            equivalent len(w_p) x len(w_q) matrix.
        w_p: weights of the first signature (one per cost-matrix row).
        w_q: weights of the second signature (one per cost-matrix column).

    Returns:
        The total work (sum of flow * distance) divided by the total flow.

    Raises:
        ValueError: if dist does not hold exactly len(w_p) * len(w_q) values.
    """
    n_p, n_q = len(w_p), len(w_q)

    # Normalise the costs once, up front: this accepts lists as well as
    # arrays and rejects a wrongly sized input before the solver is called.
    cost = np.asarray(dist, dtype=float).reshape(n_p, n_q)

    # Hand the costs to R as an n_p x n_q matrix filled row by row.
    costs = robjects.r['matrix'](robjects.FloatVector(cost.ravel().tolist()),
                                 nrow=n_p, ncol=n_q, byrow=True)

    # Rows are constrained with "<" against w_p, columns with ">" against w_q.
    row_signs = ["<"] * n_p
    row_rhs = robjects.FloatVector(w_p)
    col_signs = [">"] * n_q
    col_rhs = robjects.FloatVector(w_q)

    result = transport(costs, "min", row_signs, row_rhs, col_signs, col_rhs)
    flow = np.array(result.rx2('solution'))

    # EMD = total transport work normalised by the total amount moved.
    work = np.sum(flow * cost)
    return work / np.sum(flow)
def set_test_data():
    """
    Build the fixed sample data used by the demo.

    Returns:
        A tuple (f_p, f_q, w_p, w_q) of NumPy arrays: four 3-component
        feature vectors with their weights, followed by three 3-component
        feature vectors with their weights.
    """
    # Feature vectors and weights of the first signature.
    p_features = [
        [100, 40, 22],
        [211, 20, 2],
        [32, 190, 150],
        [2, 100, 100],
    ]
    p_weights = [4, 3, 2, 1]

    # Feature vectors and weights of the second signature.
    q_features = [
        [0, 0, 0],
        [50, 100, 80],
        [255, 255, 255],
    ]
    q_weights = [5, 3, 2]

    return (
        np.array(p_features),
        np.array(q_features),
        np.array(p_weights),
        np.array(q_weights),
    )
def main():
    """
    Run the demo: compute and print the Earth Mover's Distance between the
    two sample signatures returned by set_test_data().
    """
    # test data
    f_p, f_q, w_p, w_q = set_test_data()

    # Flat, row-major distance vector: entry i * len(f_q) + j holds the
    # distance between f_p[i] and f_q[j].
    dist = np.array([euclid_dist(p, q) for p in f_p for q in f_q],
                    dtype=float)

    # EMD
    emd = earth_movers_distance(dist, w_p, w_q)
    # A single parenthesised argument prints identically on Python 2 and 3.
    print("Earth Mover's Distance : %f" % emd)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment