Skip to content

Instantly share code, notes, and snippets.

@nmakes
Last active November 4, 2021 01:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nmakes/06e495ee353f55f7e63df0660344b88c to your computer and use it in GitHub Desktop.
Save nmakes/06e495ee353f55f7e63df0660344b88c to your computer and use it in GitHub Desktop.
A generalized RANSAC algorithm to fit any arbitrary function (see Example)
"""
Implements the RANSAC Algorithm (https://en.wikipedia.org/wiki/Random_sample_consensus)
Written By: Naveen Venkat (naveenvenkat.com)
School of Computer Science, Carnegie Mellon University
"""
import matplotlib.pyplot as plt
import numpy as np
class RANSAC:
    """Implements the RANSAC [1] algorithm.

    Params
    ------
    * **data(np.ndarray):** Numpy array of shape (N, [D1, ..., Dm]) containing N m-dimensional data samples.
    * **fit(function):** A function that takes a subset of `data` (num_points, [D1, ..., Dm]) and returns parameters `params` of the fitted model.
    * **count_inliers(function):** A function that takes `data` and the fitted estimator `params` and returns the number of inliers.
    * **num_points(int):** Number of random points to construct an estimator at each iteration of RANSAC.
    * **max_iterations(int|None):** Total number of iterations to run the RANSAC algorithm. `None` selects the default. Default = 1000.

    Example
    --------
    Suppose we want to fit a line to a noisy dataset. Let's first define the dataset.

    >>> total_number_of_points = 200
    >>> points_x = np.linspace(-3, 7, total_number_of_points)
    >>> points_y = np.array([20 * x + 3 + np.random.rand() * 60 for x in points_x])
    >>> points_y[50:75] = points_y[50:75] + 120
    >>> points_y[125:150] = points_y[125:150] - 120
    >>> data = np.stack([points_x, points_y], axis=-1)

    Now we decide upon a few parameters - a threshold (distance) to identify an inlier, number of points (2 are required to uniquely fit a line),
    and maximum number of iterations.

    >>> threshold = 10
    >>> num_points = 2
    >>> max_iterations = 50000

    The fit, count_inlier and predict functions are as follows:

    >>> fit = lambda p: (((p[1, 1] - p[0, 1]) / (p[1, 0] - p[0, 0])), p[1, 0], p[1, 1])
    >>> count_inlier = lambda p, params: (np.abs(p[:, 1] - (params[0] * (p[:, 0] - params[1]) + params[2])) < threshold).sum()
    >>> predict = lambda x, params: params[0] * (x - params[1]) + params[2]

    Now we call ransac and obtain the best parameters.

    >>> ransac = RANSAC(data=data,
    ...                 fit=fit,
    ...                 count_inliers=count_inlier,
    ...                 num_points=num_points,
    ...                 max_iterations=max_iterations)
    >>> best_params = ransac.find_best_estimator()
    >>> y = predict(points_x, best_params)

    References
    ----------
    [1] Random Sample Consensus (RANSAC) Wiki, https://en.wikipedia.org/wiki/Random_sample_consensus
    """

    # Iteration count used when `max_iterations` is given as None.
    _DEFAULT_MAX_ITERATIONS = 1000

    def __init__(self, data, fit, count_inliers, num_points, max_iterations=1000):
        """Default Constructor

        Args:
            data(np.ndarray): Numpy array of shape (N, [D1, ..., Dm]) containing N m-dimensional data samples.
            fit(function): A function that takes a subset of `data` (num_points, [D1, ..., Dm]) and returns parameters `params` of the fitted model.
            count_inliers(function): A function that takes `data` and the fitted estimator `params` and returns the number of inliers.
            num_points(int): Number of random points to construct an estimator at each iteration of RANSAC.
            max_iterations(int|None): Total number of iterations to run the RANSAC algorithm. `None` selects the default. Default = 1000.
        """
        self.data = data
        self.fit = fit
        self.count_inliers = count_inliers
        self.num_points = num_points
        # The documented type is int|None; resolve None here so that
        # `range(self.max_iterations)` never receives None.
        if max_iterations is None:
            max_iterations = self._DEFAULT_MAX_ITERATIONS
        self.max_iterations = max_iterations

    def random_sample(self):
        """Fits an estimator to one random sample of the data and scores it.

        `num_points` distinct indices are drawn (without replacement) via
        `np.random.choice`, which raises `ValueError` when `num_points`
        exceeds `len(data)`.

        Returns:
            dict: Dictionary containing:
                estimator: The estimator parameters returned by the `fit` function,
                num_inliers: The number of inliers evaluated by `count_inliers` corresponding to the function that was fit,
                sampled_indices: The indices of the points that were used to fit the estimator
        """
        sampled_indices = np.random.choice(len(self.data), size=self.num_points, replace=False)
        if isinstance(self.data, np.ndarray):
            # Fancy indexing copies the selected rows in one step.
            data_subset = self.data[sampled_indices]
        else:
            # Generic sequences are indexed element by element.
            data_subset = np.array([self.data[i] for i in sampled_indices])
        fit_estimator = self.fit(data_subset)
        # Inliers are counted against the full data set, not just the sample.
        num_inliers = self.count_inliers(self.data, fit_estimator)
        return {'estimator': fit_estimator,
                'num_inliers': num_inliers,
                'sampled_indices': sampled_indices}

    def find_best_estimator(self, return_metadata=False):
        """Runs the RANSAC algorithm to find the best estimator

        Args:
            return_metadata(bool): If True, returns a tuple containing the estimator and the metadata information - number of inliers and sampled
                indices. If False, only returns the estimator

        Returns:
            A tuple (`estimator`, `num_inliers`, `sampled_indices`) if return_metadata=True, otherwise only `estimator`
            (whatever object the `fit` function returned for the winning sample).

        Raises:
            ValueError: If `max_iterations` is smaller than 1, or if no candidate reported a comparable inlier count.
        """
        iterations = self.max_iterations
        if iterations is None:
            # The attribute may have been reset to None after construction.
            iterations = self._DEFAULT_MAX_ITERATIONS
        if iterations < 1:
            raise ValueError(
                'max_iterations must be at least 1, got {}'.format(iterations))

        best_inliers = -1
        best_estimator = None
        for _ in range(iterations):
            candidate = self.random_sample()
            # `>=` means that on a tie the most recent candidate wins.
            if candidate['num_inliers'] >= best_inliers:
                best_inliers = candidate['num_inliers']
                best_estimator = candidate

        if best_estimator is None:
            # Reached only when every inlier count failed the comparison
            # above (for example NaN or a value below -1).
            raise ValueError('no candidate estimator produced a valid inlier count')

        if return_metadata:
            return (best_estimator['estimator'],
                    best_estimator['num_inliers'],
                    best_estimator['sampled_indices'])
        return best_estimator['estimator']
def plot(points_x, points_y, generated_y):
    """Plots the raw data and the fitted curve in a new figure and shows it.

    Args:
        points_x: X coordinates shared by both curves.
        points_y: Observed Y values, drawn with the label 'true data'.
        generated_y: Model-generated Y values, drawn with the label 'fitted data'.
    """
    plt.figure()
    plt.plot(points_x, points_y, label='true data')
    plt.plot(points_x, generated_y, label='fitted data')
    # The labels above are only rendered once a legend is requested.
    plt.legend()
    plt.show()
def test_ransac():
    """Demo run: fit a line with RANSAC to noisy data containing two outlier bands.

    Builds 200 points around y = 20x + 3 with uniform noise in [0, 60),
    shifts two 25-point slices by +120 and -120, runs RANSAC with
    two-point line fits, prints the inlier count of the winning model
    and plots the data together with the fitted line.
    """
    n_samples = 200
    xs = np.linspace(-3, 7, n_samples)
    # One random draw per point, in order of increasing x.
    ys = np.array([20 * x + 3 + np.random.rand() * 60 for x in xs])
    ys[50:75] += 120
    ys[125:150] -= 120
    samples = np.stack([xs, ys], axis=-1)

    inlier_tolerance = 10

    def line_through(pair):
        # Line stored as (slope, x_ref, y_ref), anchored on the second point.
        slope = (pair[1, 1] - pair[0, 1]) / (pair[1, 0] - pair[0, 0])
        return (slope, pair[1, 0], pair[1, 1])

    def evaluate(x, line):
        # Point-slope form: y = slope * (x - x_ref) + y_ref.
        return line[0] * (x - line[1]) + line[2]

    def tally_inliers(pts, line):
        # A point is an inlier when its vertical residual is under the tolerance.
        residual = np.abs(pts[:, 1] - evaluate(pts[:, 0], line))
        return (residual < inlier_tolerance).sum()

    solver = RANSAC(
        data=samples,
        fit=line_through,
        count_inliers=tally_inliers,
        num_points=2,
        max_iterations=50000,
    )
    line, inlier_count, _ = solver.find_best_estimator(return_metadata=True)
    print('Number of inliers in the best model: {} ({}%)'.format(inlier_count, 100*float(inlier_count)/n_samples))
    plot(xs, ys, evaluate(xs, line))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_ransac()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment