Last active
January 15, 2017 21:46
-
-
Save nlw0/b4eda8623d5853399a4271d67e107b7e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import estimation.RobustEstimator
import geometry.{Line, Point}
import scala.annotation.tailrec
import scala.math.abs
import scala.util.Random
/** Demonstration of RANSAC line fitting on synthetic data.
  *
  * Generates `n_inliers` noisy samples of a random line plus `n_outliers` scattered points,
  * runs a [[RobustEstimator]] built from the helpers below, and prints the true model next to
  * the estimate.
  *
  * Judging from `get_line_from_point_pair`, a `Line(x, y)` appears to store the foot of the
  * perpendicular from the origin to the line. NOTE(review): confirm against `geometry.Line`.
  */
object TestRANSAC extends App {
  val n_outliers = 100
  val n_inliers = 100

  // Draw a random line, then swap its coordinates if needed so that |x| < |y|.
  // Presumably this keeps `line(x)` (y as a function of x, used in generate_data)
  // away from near-vertical lines — TODO confirm against geometry.Line.apply.
  val aux = Line(Point.nextGaussian)
  val original_model = if (abs(aux.x) < abs(aux.y)) aux else Line(aux.y, aux.x)

  val sigma = 0.2
  val data = generate_data(n_outliers, n_inliers, original_model, sigma)

  // Points within 3 sigma of a candidate line count as inliers.
  val ransac = new RobustEstimator(
    pick_point_pair,
    get_line_from_point_pair,
    test_point_closer_than(3 * sigma)
  )

  // Number of minimal samples (point pairs) to draw so that, with probability `confidence`,
  // at least one pair is all-inlier: log(1 - p) / log(1 - w^2), with w the inlier ratio.
  // A fixed 10 draws gives only ~94% success for w = 0.5; this yields 17 draws (> 99%).
  // Assumes RobustEstimator.estimate's second argument is the number of sampled hypotheses.
  val confidence = 0.99
  val inlier_ratio = n_inliers.toDouble / (n_inliers + n_outliers)
  val iterations = {
    val p_all_inliers = inlier_ratio * inlier_ratio
    // max(1, ...) guards the degenerate ratios 0 and 1, where the formula is not finite/positive.
    math.max(1, math.ceil(math.log(1 - confidence) / math.log(1 - p_all_inliers)).toInt)
  }

  val estimated_model = ransac.estimate(data, iterations)

  println(s"Original model $original_model")
  println(s"Estimated model: $estimated_model")

  /** Draws two distinct elements of `data` uniformly at random: the minimal sample for a line.
    *
    * @throws IllegalArgumentException if `data` has fewer than two elements
    */
  def pick_point_pair(data: Seq[Point]): (Point, Point) = {
    // sample_without_replacement(2, ...) always returns exactly two elements (or throws).
    val sample = sample_without_replacement(2, data)
    (sample(0), sample(1))
  }

  /** Builds the line through two points.
    *
    * With `delta = p2 - p1` and normal `n = (delta.y, -delta.x)`, returns
    * `n * (p1 cross p2) / |delta|^2`, which is the foot of the perpendicular from the origin,
    * assuming `cross` is the 2-D scalar cross product and `sqnorm` the squared length.
    *
    * Degenerate cases (not rejected here):
    *  - coincident points give 0/0, i.e. NaN coordinates; presumably such a line then fails
    *    every distance test, since `abs(NaN) < t` is false;
    *  - a line through the origin has `p1 cross p2 == 0` and maps to `Line(0, 0)`, losing its
    *    direction.
    */
  def get_line_from_point_pair(point_pair: (Point, Point)): Line = {
    val (p1, p2) = point_pair
    val delta = p2 - p1
    val cross = p1 cross p2
    val sqnorm = delta.sqnorm
    // Same evaluation order as (delta.y * cross) / sqnorm, so rounding is unchanged.
    Line(delta.y * cross / sqnorm, -delta.x * cross / sqnorm)
  }

  /** Curried inlier test: true when `p` lies strictly closer than `threshold` to line `l`.
    *
    * `abs` is applied because `Line.distance` may be signed — TODO confirm.
    */
  def test_point_closer_than(threshold: Double)(l: Line)(p: Point): Boolean =
    abs(l distance p) < threshold

  /** Draws `N` elements of `data` uniformly at random, without replacement.
    *
    * Drawn elements are appended to `sample` in draw order. Callers normally leave `sample` at
    * its default (empty); it is the accumulator of the tail recursion. `N <= 0` returns
    * `sample` unchanged.
    *
    * @throws IllegalArgumentException if `N` exceeds `data.size`
    */
  @tailrec
  def sample_without_replacement[D](N: Int, data: Seq[D], sample: List[D] = List.empty): List[D] = {
    // Fail fast with a clear message instead of letting Random.nextInt(0) throw mid-recursion.
    require(N <= data.size, s"cannot draw $N samples from ${data.size} elements")
    if (N <= 0) sample
    else {
      val el = Random.nextInt(data.size)
      val remaining_data = data.take(el) ++ data.drop(el + 1)
      sample_without_replacement(N - 1, remaining_data, sample :+ data(el))
    }
  }

  /** Builds the synthetic RANSAC input: `n_outliers` scattered points followed by `n_inliers`
    * noisy samples of `line`.
    *
    * Outliers are `(Point.nextUniform - (0.5, 0.5)) * r`, which spans a square of side `r`
    * centred on the origin if `nextUniform` is uniform on the unit square (TODO confirm).
    * Inlier abscissas are uniform in [-r/4, r/4), i.e. only the central half of the test
    * space. Each inlier is `(x, line(x))` plus `Point.nextGaussian * noise`.
    *
    * @param n_outliers number of outlier points
    * @param n_inliers  number of inlier points
    * @param line       model the inliers are sampled from
    * @param noise      scale applied to the Gaussian perturbation of each inlier
    * @param r          size of the test space (default 10.0, the previous hard-coded value)
    */
  def generate_data(n_outliers: Int, n_inliers: Int, line: Line, noise: Double, r: Double = 10.0): List[Point] = {
    val outliers = List.fill(n_outliers) { (Point.nextUniform - Point(0.5, 0.5)) * r }
    val inliers = List.fill(n_inliers) {
      val x = (Random.nextDouble - 0.5) * r / 2
      Point(x, line(x)) + Point.nextGaussian * noise
    }
    outliers ++ inliers
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment