Last active
April 18, 2023 03:07
-
-
Save mjhong0708/7554bccbfa343fcd70eb71c22c92f8c0 to your computer and use it in GitHub Desktop.
Farthest point sampling (FPS) with numpy using an arbitrary metric
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba import njit | |
from typing import Any, Callable, Sequence, TypeVar | |
# Generic element type of the point collection passed to farthest_point_sampling.
# The TypeVar's name string must match the variable it is bound to.
Point_T = TypeVar("Point_T")
# Signature expected of a distance function: takes two points and returns a
# scalar distance between them.
Metric = Callable[[np.ndarray, np.ndarray], float]
@njit
def euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    The element-wise difference ``x - y`` is squared and summed over all of
    its elements, then square-rooted. ``x`` and ``y`` must have shapes that
    are compatible for ``x - y``. The function is compiled with numba's
    ``njit``.
    """
    diff = x - y
    return np.sqrt((diff * diff).sum())
def farthest_point_sampling(
    points: Sequence[Point_T],
    num_samples: int,
    metric: Metric = euclidean_distance,
) -> Sequence[Point_T]:
    """Select ``num_samples`` points from ``points`` by farthest point sampling.

    Sampling starts from ``points[0]``. Each subsequent sample is the point
    whose distance (under ``metric``) to its nearest already-selected sample
    is largest. Ties are broken in favour of the lowest index, per
    ``np.argmax``.

    Args:
        points: Indexable collection of points. Each element is passed to
            ``metric`` as-is.
        num_samples: Number of points to select; must satisfy
            ``0 <= num_samples <= len(points)``.
        metric: Callable ``metric(a, b)`` returning a scalar distance.
            Defaults to ``euclidean_distance``.

    Returns:
        List of the selected points, in selection order. Each input index is
        selected at most once.

    Raises:
        ValueError: If ``num_samples`` is negative or exceeds ``len(points)``.
    """
    if num_samples < 0:
        raise ValueError("Number of samples must be non-negative")
    if num_samples > len(points):
        raise ValueError("Number of samples must be less than or equal to the number of points")
    if num_samples == 0:
        # Previously this still returned points[0] (and crashed on empty input).
        return []

    # Initialize the result list with the first point
    p0 = points[0]
    sampled_points = [p0]
    # min_distances[i] is the distance from points[i] to its nearest selected
    # sample. It is stored as float so selected entries can be masked with -inf.
    min_distances = np.array([metric(p0, p) for p in points], dtype=float)
    # Mask selected points so they are never picked again. This matters when
    # all remaining distances are 0 (duplicate points) or metric(p, p) != 0.
    min_distances[0] = -np.inf

    for _ in range(num_samples - 1):
        farthest_index = int(np.argmax(min_distances))
        farthest_point = points[farthest_index]
        sampled_points.append(farthest_point)
        # Update each point's nearest-sample distance with the new sample.
        for i, point in enumerate(points):
            min_distances[i] = min(min_distances[i], metric(farthest_point, point))
        min_distances[farthest_index] = -np.inf

    return sampled_points
if __name__ == "__main__":
    # Demo: ten points on the diagonal y = x; select 5 of them with FPS.
    # Guarded so importing this module does not run the demo (or trigger
    # numba JIT compilation and printing as a side effect).
    points = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]])
    num_samples = 5
    sampled_points = farthest_point_sampling(points, num_samples)
    print("Sampled points:", sampled_points)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Usage