-
-
Save walsvid/6f7d657a5d2be85c9b24a48b39ab6290 to your computer and use it in GitHub Desktop.
Vanilla Chamfer distance computation in NumPy and scikit-learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.neighbors import NearestNeighbors | |
def chamfer_distance(x, y, metric='l2', direction='bi'):
    r"""Chamfer distance between two point clouds.

    Parameters
    ----------
    x: numpy array [n_points_x, n_dims]
        first point cloud
    y: numpy array [n_points_y, n_dims]
        second point cloud
    metric: string or callable, default 'l2'
        metric to use for distance computation. Any metric accepted by
        scikit-learn's ``NearestNeighbors`` can be used (scikit-learn or
        scipy.spatial.distance metric names, or a callable). Metrics that a
        KD-tree supports are searched with one; any other metric falls back
        to ``algorithm='auto'``.
    direction: str
        direction of Chamfer distance.
            'y_to_x': computes average minimal distance from every point in y to x
            'x_to_y': computes average minimal distance from every point in x to y
            'bi': computes both and returns their sum

    Returns
    -------
    chamfer_dist: float
        computed Chamfer distance. Distances are the plain `metric` values
        (not squared); for direction='bi' the result is
        mean_{x_i \in x}{\min_{y_j \in y}{d(x_i, y_j)}} + mean_{y_j \in y}{\min_{x_i \in x}{d(x_i, y_j)}}

    Raises
    ------
    ValueError
        if `direction` is not 'y_to_x', 'x_to_y' or 'bi'.
    """
    if direction == 'y_to_x':
        chamfer_dist = _mean_nn_distance(y, x, metric)
    elif direction == 'x_to_y':
        chamfer_dist = _mean_nn_distance(x, y, metric)
    elif direction == 'bi':
        chamfer_dist = _mean_nn_distance(y, x, metric) + _mean_nn_distance(x, y, metric)
    else:
        # List the exact accepted spellings so the caller can fix the argument.
        raise ValueError("Invalid direction type. Supported types: 'y_to_x', 'x_to_y', 'bi'")
    return chamfer_dist


def _mean_nn_distance(queries, reference, metric):
    """Return the mean distance from each point in `queries` to its nearest point in `reference`."""
    try:
        # Prefer a KD-tree for point clouds. The default leaf_size is used:
        # leaf_size only affects build/query speed and memory, not the exact
        # nearest-neighbour distances returned.
        nn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree', metric=metric).fit(reference)
    except ValueError:
        # scikit-learn rejects metrics a KD-tree cannot handle (e.g. 'cosine'
        # or a callable) with ValueError; let it pick a compatible algorithm.
        nn = NearestNeighbors(n_neighbors=1, algorithm='auto', metric=metric).fit(reference)
    distances, _ = nn.kneighbors(queries)
    return np.mean(distances)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment