Last active
November 4, 2021 01:03
-
-
Save nmakes/06e495ee353f55f7e63df0660344b88c to your computer and use it in GitHub Desktop.
A generalized RANSAC algorithm that fits an arbitrary model function (see Example)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implements the RANSAC Algorithm (https://en.wikipedia.org/wiki/Random_sample_consensus) | |
Written By: Naveen Venkat (naveenvenkat.com) | |
School of Computer Science, Carnegie Mellon University | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
class RANSAC:
    """Implements the RANSAC [1] algorithm.

    Repeatedly fits a model to a random subset of ``num_points`` samples and
    keeps the model that the user-supplied ``count_inliers`` scores highest.

    Params
    ------
    * **data(np.ndarray):** Numpy array of shape (N, [D1, ..., Dm]) containing N m-dimensional data samples.
    * **fit(function):** A function that takes a subset of `data` (num_points, [D1, ..., Dm]) and returns parameters `params` of the fitted model.
    * **count_inliers(function):** A function that takes `data` and the fitted estimator `params` and returns the number of inliers.
    * **num_points(int):** Number of random points to construct an estimator at each iteration of RANSAC.
    * **max_iterations(int):** Total number of iterations to run the RANSAC algorithm. Must be at least 1. Default = 1000.
    * **rng(None|int|Generator|RandomState):** Source of randomness for sampling. ``None`` (default) uses the global
      ``np.random`` state; an int seed or a NumPy generator object makes runs reproducible.

    Example
    --------
    Suppose we want to fit a line to a noisy dataset. Let's first define the dataset.

    >>> total_number_of_points = 200
    >>> points_x = np.linspace(-3, 7, total_number_of_points)
    >>> points_y = np.array([20 * x + 3 + np.random.rand() * 60 for x in points_x])
    >>> points_y[50:75] = points_y[50:75] + 120
    >>> points_y[125:150] = points_y[125:150] - 120
    >>> data = np.stack([points_x, points_y], axis=-1)

    Now we decide upon a few parameters - a threshold (distance) to identify an inlier, number of points
    (2 are required to uniquely fit a line), and maximum number of iterations.

    >>> threshold = 10
    >>> num_points = 2
    >>> max_iterations = 50000

    The fit, count_inlier and predict functions are as follows:

    >>> fit = lambda p: (((p[1, 1] - p[0, 1]) / (p[1, 0] - p[0, 0])), p[1, 0], p[1, 1])
    >>> count_inlier = lambda p, params: (np.abs(p[:, 1] - (params[0] * (p[:, 0] - params[1]) + params[2])) < threshold).sum()
    >>> predict = lambda x, params: params[0] * (x - params[1]) + params[2]

    Now we call ransac and obtain the best parameters (pass e.g. ``rng=0`` for a reproducible run).

    >>> ransac = RANSAC(data=data,
    ...                 fit=fit,
    ...                 count_inliers=count_inlier,
    ...                 num_points=num_points,
    ...                 max_iterations=max_iterations)
    >>> best_params = ransac.find_best_estimator()
    >>> y = predict(points_x, best_params)

    References
    ----------
    [1] Random Sample Consensus (RANSAC) Wiki, https://en.wikipedia.org/wiki/Random_sample_consensus
    """

    # Sampling source; None means "use the global np.random module". Kept at
    # class level so subclasses that bypass __init__ still sample correctly.
    rng = None

    def __init__(self, data, fit, count_inliers, num_points, max_iterations=1000, rng=None):
        """Default Constructor

        Args:
            data(np.ndarray): Numpy array of shape (N, [D1, ..., Dm]) containing N m-dimensional data samples.
            fit(function): A function that takes a subset of `data` (num_points, [D1, ..., Dm]) and returns
                parameters `params` of the fitted model.
            count_inliers(function): A function that takes `data` and the fitted estimator `params` and returns
                the number of inliers.
            num_points(int): Number of random points to construct an estimator at each iteration of RANSAC.
            max_iterations(int): Total number of iterations to run the RANSAC algorithm. Must be at least 1
                for `find_best_estimator` to succeed. Default = 1000.
            rng(None|int|object): None (default) samples with the global ``np.random`` state. An object that
                already has a ``choice`` method (``np.random.Generator`` / ``np.random.RandomState``) is used
                as-is; any other value (e.g. an int seed) is passed to ``np.random.default_rng``.
        """
        self.data = data
        self.fit = fit
        self.count_inliers = count_inliers
        self.num_points = num_points
        self.max_iterations = max_iterations
        if rng is None:
            self.rng = None
        elif hasattr(rng, 'choice'):
            self.rng = rng
        else:
            self.rng = np.random.default_rng(rng)

    def random_sample(self):
        """Returns a dictionary containing the estimator corresponding to a random sample from the data

        Returns:
            dict: Dictionary containing:
                estimator: The estimator parameters returned by the `fit` function,
                num_inliers: The number of inliers evaluated by `count_inliers` corresponding to the function that was fit,
                sampled_indices: The indices of the points that were used to fit the estimator

        Raises:
            ValueError: (from NumPy) if `num_points` exceeds the number of samples, since sampling is without
                replacement.
        """
        sampler = np.random if self.rng is None else self.rng
        sampled_indices = sampler.choice(len(self.data), size=self.num_points, replace=False)
        # Fancy indexing gathers the sampled rows in one step; np.asarray is a
        # no-op for the documented ndarray input.
        data_subset = np.asarray(self.data)[sampled_indices]
        fit_estimator = self.fit(data_subset)
        # Inliers are always scored against the full dataset, not the subset.
        num_inliers = self.count_inliers(self.data, fit_estimator)
        return {'estimator': fit_estimator,
                'num_inliers': num_inliers,
                'sampled_indices': sampled_indices}

    def find_best_estimator(self, return_metadata=False):
        """Runs the RANSAC algorithm to find the best estimator

        Args:
            return_metadata(bool): If True, returns a tuple containing the estimator and the metadata
                information - number of inliers and sampled indices. If False, only returns the estimator.

        Returns:
            tuple|object: A tuple containing (`estimator`, `num_inliers`, `sampled_indices`) if
            return_metadata=True, otherwise `estimator` (whatever `fit` returned for the best sample).

        Raises:
            ValueError: If `max_iterations` is less than 1, so no estimator was ever sampled.
        """
        best = None
        for _ in range(self.max_iterations):
            candidate = self.random_sample()
            # ``>=`` keeps the most recent of equally-scored candidates.
            if best is None or candidate['num_inliers'] >= best['num_inliers']:
                best = candidate
        if best is None:
            raise ValueError('max_iterations must be at least 1, got {!r}'.format(self.max_iterations))
        if return_metadata:
            return best['estimator'], best['num_inliers'], best['sampled_indices']
        return best['estimator']
def plot(points_x, points_y, generated_y):
    """Plot the observed samples and the fitted model's predictions on one figure.

    Args:
        points_x: x-coordinates shared by both series.
        points_y: observed y values (noisy samples, possibly with outliers).
        generated_y: y values produced by the fitted model at ``points_x``.
    """
    plt.figure()
    # Noisy samples with outliers are drawn as markers; joining them with a
    # polyline would produce a misleading zig-zag between unrelated points.
    plt.plot(points_x, points_y, '.', label='true data')
    plt.plot(points_x, generated_y, label='fitted data')
    # The labels above are only rendered once a legend is requested.
    plt.legend()
    plt.show()
def test_ransac():
    """Demo: fit a line to noisy data containing two outlier bands, then report and plot it.

    Samples y = 20x + 3 plus uniform noise in [0, 60) at 200 evenly spaced x
    values, shifts samples 50-74 up by 120 and 125-149 down by 120, runs
    RANSAC with a two-point line model, prints the inlier count, and plots it.
    """
    n_total = 200
    inlier_tol = 10
    sample_size = 2
    iterations = 50000

    xs = np.linspace(-3, 7, n_total)
    # One np.random.rand() draw per x, in order.
    ys = np.array([20 * x + 3 + np.random.rand() * 60 for x in xs])
    ys[50:75] += 120
    ys[125:150] -= 120
    samples = np.column_stack((xs, ys))

    def fit(pts):
        # Line through the two sampled points as (slope, anchor_x, anchor_y).
        slope = (pts[1, 1] - pts[0, 1]) / (pts[1, 0] - pts[0, 0])
        return slope, pts[1, 0], pts[1, 1]

    def predict(x, params):
        slope, anchor_x, anchor_y = params
        return slope * (x - anchor_x) + anchor_y

    def count_inlier(pts, params):
        residual = np.abs(pts[:, 1] - predict(pts[:, 0], params))
        return (residual < inlier_tol).sum()

    solver = RANSAC(data=samples,
                    fit=fit,
                    count_inliers=count_inlier,
                    num_points=sample_size,
                    max_iterations=iterations)
    best_params, n_inliers, _ = solver.find_best_estimator(return_metadata=True)
    share = 100 * float(n_inliers) / n_total
    print('Number of inliers in the best model: {} ({}%)'.format(n_inliers, share))
    plot(xs, ys, predict(xs, best_params))
if __name__ == '__main__':
    # Runs the line-fitting demo only when executed as a script, not on import.
    test_ransac()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment