Skip to content

Instantly share code, notes, and snippets.

@larroy
Last active January 8, 2017 06:56
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save larroy/f37f579613388eaccffb to your computer and use it in GitHub Desktop.
Save larroy/f37f579613388eaccffb to your computer and use it in GitHub Desktop.
Binary search in Scala
import scala.annotation.tailrec
object BinarySearch {

  /**
   * Binary search for `key` within the inclusive index range `[min, max]` of `xs`.
   *
   * The elements of `xs` in that range must be sorted in ascending order of `keyExtract(x)`
   * under `ordering`. When several elements match, the index of any one of them may be returned.
   * When `min <= max`, both must be valid indices of `xs`; otherwise the lookup `xs(mid)` may
   * throw whatever `xs.apply` throws for an out-of-range index.
   *
   * @param key        key to find
   * @param xs         sequence to search
   * @param min        minimum index (inclusive)
   * @param max        maximum index (inclusive); use `xs.length - 1` to search the whole sequence
   * @param keyExtract function applied to elements of `xs` before comparing them to `key`.
   *                   The default is the identity `(x: T) => x`, so omitting it only type-checks
   *                   when the elements are themselves keys.
   * @param ordering   ordering used to compare keys
   * @tparam T type of elements in the sequence
   * @tparam U key type
   * @return `Some(index)` of a matching element, or `None` if none is found
   *         (including when `max < min`)
   */
  @tailrec
  def binarySearch[T, U](key: U, xs: IndexedSeq[T], min: Int, max: Int, keyExtract: (T) => U = (x: T) => x)(
      implicit ordering: Ordering[U]): Option[Int] = {
    if (max < min)
      None
    else {
      // Unsigned shift keeps the midpoint correct for non-negative indices even if `min + max` overflows.
      val mid = (min + max) >>> 1
      val extracted: U = keyExtract(xs(mid))
      if (ordering.lt(key, extracted))
        binarySearch(key, xs, min, mid - 1, keyExtract)(ordering)
      else if (ordering.gt(key, extracted))
        binarySearch(key, xs, mid + 1, max, keyExtract)(ordering)
      else
        Some(mid)
    }
  }

  /**
   * Finds the first element of `xs` whose extracted key does not compare less than `key`.
   *
   * `xs` must be sorted in ascending order of `keyExtract(x)` under `ordering`.
   *
   * @param key        key to compare against
   * @param xs         sorted sequence to search
   * @param keyExtract function applied to elements before comparison; defaults to the identity
   * @param ordering   ordering used to compare keys
   * @return `Some(index)` of that element, or `None` if every element compares less than `key`
   */
  def lowerBound[T, U](key: U, xs: IndexedSeq[T], keyExtract: (T) => U = (x: T) => x)(
      implicit ordering: Ordering[U]): Option[Int] =
    partitionPoint(xs)(x => !ordering.lt(keyExtract(x), key))

  /**
   * Finds the first element of `xs` whose extracted key compares greater than `key`.
   *
   * `xs` must be sorted in ascending order of `keyExtract(x)` under `ordering`.
   *
   * @param key        key to compare against
   * @param xs         sorted sequence to search
   * @param keyExtract function applied to elements before comparison; defaults to the identity
   * @param ordering   ordering used to compare keys
   * @return `Some(index)` of that element, or `None` if no element compares greater than `key`
   */
  def upperBound[T, U](key: U, xs: IndexedSeq[T], keyExtract: (T) => U = (x: T) => x)(
      implicit ordering: Ordering[U]): Option[Int] =
    partitionPoint(xs)(x => ordering.lt(key, keyExtract(x)))

  /**
   * Returns the index of the first element of `xs` satisfying `pred`, assuming `xs` is
   * partitioned so that every element failing `pred` precedes every element satisfying it.
   * Returns `None` when no element satisfies `pred`.
   */
  private def partitionPoint[T](xs: IndexedSeq[T])(pred: T => Boolean): Option[Int] = {
    // Invariant: the answer lies in [first, first + len]; each step halves `len`.
    @tailrec
    def loop(first: Int, len: Int): Int =
      if (len <= 0) first
      else {
        val half = len >>> 1
        val middle = first + half
        if (pred(xs(middle))) loop(first, half)
        else loop(middle + 1, len - half - 1)
      }

    val first = loop(0, xs.length)
    if (first < xs.length) Some(first) else None
  }
}
/**
* @author piotr 05.09.14
*/
import org.specs2.mutable._
import scala.collection.mutable
import BinarySearch._
class BinarySearchSpec extends Specification {
  "BinarySearchSpec" should {
    "search correctly" in {
      val xs = mutable.ArrayBuffer(1, 2, 2, 3, 4, 6, 9, 12)
      // element -> index: 1->0, 2->1, 2->2, 3->3, 4->4, 6->5, 9->6, 12->7
      binarySearch[Int, Int](0, xs, 0, xs.size - 1) must beNone
      binarySearch[Int, Int](1, xs, 0, xs.size - 1) must beSome(0)
      // Duplicate key: either matching index is acceptable. The alternatives are grouped in
      // parentheses because `must` and `or` have equal precedence, so the unparenthesized form
      // parses as `(x must beSome(1)) or beSome(2)`. In a mutable spec that first expectation
      // throws on failure, so the alternative would never be tried.
      binarySearch[Int, Int](2, xs, 0, xs.size - 1) must (beSome(1) or beSome(2))
      binarySearch[Int, Int](3, xs, 0, xs.size - 1) must beSome(3)
      binarySearch[Int, Int](4, xs, 0, xs.size - 1) must beSome(4)
      binarySearch[Int, Int](5, xs, 0, xs.size - 1) must beNone
    }
    "binary search returns None for an empty range" in {
      binarySearch[Int, Int](1, Vector.empty[Int], 0, -1) must beNone
    }
    "lower bound finds the first element which doesn't compare less than key" in {
      val xs = mutable.ArrayBuffer(1, 2, 2, 3, 4, 6, 9, 12)
      lowerBound[Int, Int](2, xs) must beSome(1)
      lowerBound[Int, Int](8, xs) must beSome(6)
      lowerBound[Int, Int](1, xs) must beSome(0)
      lowerBound[Int, Int](10, xs) must beSome(7)
      lowerBound[Int, Int](13, xs) must beNone
      lowerBound[Int, Int](0, Vector(0)) must beSome(0)
      lowerBound[Int, Int](0, Vector()) must beNone
      lowerBound[Int, Int](1, Vector(0)) must beNone
    }
    "upper bound finds the first element greater than key" in {
      val xs = mutable.ArrayBuffer(1, 2, 2, 3, 4)
      // element -> index: 1->0, 2->1, 2->2, 3->3, 4->4
      upperBound[Int, Int](2, xs) must beSome(3)
      upperBound[Int, Int](8, xs) must beNone
      upperBound[Int, Int](1, xs) must beSome(1)
      upperBound[Int, Int](4, xs) must beNone
      upperBound[Int, Int](0, Vector(0)) must beNone
      upperBound[Int, Int](0, Vector()) must beNone
    }
    "keyExtract projects elements before comparing them to the key" in {
      val ps = Vector("a" -> 1, "b" -> 3, "c" -> 3, "d" -> 7)
      lowerBound[(String, Int), Int](3, ps, _._2) must beSome(1)
      upperBound[(String, Int), Int](3, ps, _._2) must beSome(3)
      binarySearch[(String, Int), Int](7, ps, 0, ps.size - 1, _._2) must beSome(3)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment