Skip to content

Instantly share code, notes, and snippets.

@gnomeria
Created October 7, 2018 05:33
Show Gist options
  • Save gnomeria/2e5c1b2a0e117d7bb7b89f76a3f60857 to your computer and use it in GitHub Desktop.
Save gnomeria/2e5c1b2a0e117d7bb7b89f76a3f60857 to your computer and use it in GitHub Desktop.
Rough kotlin Interval Search Tree implementation
/**
* Rough Kotlin reimplementation of Robert Sedgewick and Kevin Wayne's Interval Search Tree,
* ported from Java. Each interval maps to a list of values; inserting an existing interval appends to its list.
*
* https://algs4.cs.princeton.edu/93intersection/IntervalST.java.html
*/
/******************************************************************************
* Compilation: javac IntervalST.java
* Execution: java IntervalST
* Dependencies: Interval1D.java
*
* Interval search tree implemented using a randomized BST.
*
* Duplicate policy: if an interval is inserted that already
* exists, the new value is appended to that interval's list of values
*
******************************************************************************/
import java.util.*
import kotlin.math.roundToInt
import kotlin.reflect.KMutableProperty0
/**
 * A single node of [IntervalSearchTree].
 *
 * @property interval the closed integer interval that acts as this node's key.
 * @property value the values stored under [interval]; inserting the same interval again appends here.
 */
class Node<V : Any>(val interval: ClosedRange<Int>, val value: MutableList<V>) {
    /** Left child (intervals ordered before [interval]), or `null`. */
    var left: Node<V>? = null

    /** Right child (intervals ordered after [interval]), or `null`. */
    var right: Node<V>? = null

    /** Number of nodes in the subtree rooted at this node (a fresh node counts itself). */
    var N = 1

    /** Largest right endpoint of any interval in the subtree rooted at this node. */
    var max = interval.endInclusive
}

/**
 * Interval search tree backed by a randomized binary search tree.
 *
 * Nodes are ordered by interval start, then by interval end. Every node caches the size of its
 * subtree ([Node.N]) and the largest right endpoint in its subtree ([Node.max]); the intersection
 * queries use that cached maximum to skip subtrees that cannot contain a match.
 *
 * Inserting an interval that is already present appends the value to that interval's list.
 */
class IntervalSearchTree<V : Any> {
    private var root: Node<V>? = null

    /** Returns `true` when exactly this [interval] (same start and end) has been inserted. */
    operator fun contains(interval: ClosedRange<Int>): Boolean = get(interval) != null

    /**
     * Returns the (live, mutable) list of values stored under exactly this [interval],
     * or `null` when the interval has never been inserted.
     */
    fun get(interval: ClosedRange<Int>): MutableList<V>? = get(root, interval)

    /** Standard BST descent from [x] using the same ordering that insertion uses. */
    private fun get(x: Node<V>?, interval: ClosedRange<Int>): MutableList<V>? {
        var node = x
        while (node != null) {
            val cmp = interval.compareTo(node.interval)
            node = when {
                cmp < 0 -> node.left
                cmp > 0 -> node.right
                else -> return node.value
            }
        }
        return null
    }

    /**
     * Associates [value] with [interval]. If the interval is already present the value is
     * appended to its list; otherwise a new node is inserted.
     */
    fun put(interval: ClosedRange<Int>, value: V) {
        val existing = get(interval)
        if (existing != null) {
            existing.add(value)
        } else {
            root = randomizedInsert(root, interval, value)
        }
    }

    /**
     * Inserts a new node below [x]. With probability 1 / size(x) the new node becomes the root
     * of this subtree (via [rootInsert]); this keeps the tree balanced in expectation.
     */
    private fun randomizedInsert(x: Node<V>?, interval: ClosedRange<Int>, value: V): Node<V> {
        if (x == null) return Node(interval, mutableListOf(value))
        if (Math.random() * size(x) < 1.0) return rootInsert(x, interval, value)
        if (interval.compareTo(x.interval) < 0) {
            x.left = randomizedInsert(x.left, interval, value)
        } else {
            x.right = randomizedInsert(x.right, interval, value)
        }
        fix(x)
        return x
    }

    /** Inserts a new node and rotates it up so that it ends as the root of the subtree at [x]. */
    private fun rootInsert(x: Node<V>?, interval: ClosedRange<Int>, value: V): Node<V> {
        if (x == null) return Node(interval, mutableListOf(value))
        return if (interval.compareTo(x.interval) < 0) {
            x.left = rootInsert(x.left, interval, value)
            rotateRight(x)
        } else {
            x.right = rootInsert(x.right, interval, value)
            rotateLeft(x)
        }
    }

    /** Right rotation: lifts the left child of [h] above it and repairs both nodes' cached data. */
    private fun rotateRight(h: Node<V>): Node<V> {
        val pivot = h.left ?: return h
        h.left = pivot.right
        pivot.right = h
        fix(h) // h is now the lower node, so it must be repaired first
        fix(pivot)
        return pivot
    }

    /** Left rotation: lifts the right child of [h] above it and repairs both nodes' cached data. */
    private fun rotateLeft(h: Node<V>): Node<V> {
        val pivot = h.right ?: return h
        h.right = pivot.left
        pivot.left = h
        fix(h) // h is now the lower node, so it must be repaired first
        fix(pivot)
        return pivot
    }

    // SEARCHING

    /** Returns some stored interval that intersects [interval], or `null` if there is none. */
    fun search(interval: ClosedRange<Int>): ClosedRange<Int>? = search(root, interval)

    /**
     * Returns some interval in the subtree rooted at [x] that intersects [interval],
     * or `null` if there is none. The tree is not modified.
     */
    fun search(x: Node<V>?, interval: ClosedRange<Int>): ClosedRange<Int>? {
        var node = x
        while (node != null) {
            if (interval.intersect(node.interval)) return node.interval
            val left = node.left
            // If nothing on the left reaches interval.start, no match can be there: go right.
            node = if (left == null || left.max < interval.start) node.right else left
        }
        return null
    }

    /** Returns every stored interval that intersects [interval]. */
    fun searchAll(interval: ClosedRange<Int>): Iterable<ClosedRange<Int>> {
        val list = LinkedList<ClosedRange<Int>>()
        searchAll(root, interval, list)
        return list
    }

    /**
     * Appends to [list] every interval in the subtree rooted at [x] that intersects [interval].
     *
     * @return `true` when at least one intersecting interval was found.
     */
    fun searchAll(x: Node<V>?, interval: ClosedRange<Int>, list: LinkedList<ClosedRange<Int>>): Boolean =
        collectIntersecting(x, interval) { list.add(it.interval) }

    /** Returns the values of every stored interval that intersects [interval]. */
    fun getAllValues(interval: ClosedRange<Int>): Iterable<V> {
        val res = mutableListOf<V>()
        collectIntersecting(root, interval) { res.addAll(it.value) }
        return res
    }

    /**
     * Calls [visit] for every node below [x] whose interval intersects [interval]
     * (node first, then left subtree, then right subtree).
     *
     * @return `true` when at least one node was visited.
     */
    private fun collectIntersecting(x: Node<V>?, interval: ClosedRange<Int>, visit: (Node<V>) -> Unit): Boolean {
        if (x == null) return false
        val foundHere = interval.intersect(x.interval)
        if (foundHere) visit(x)

        val left = x.left
        // The left subtree can only hold a match if one of its intervals reaches interval.start.
        val leftMayMatch = left != null && left.max >= interval.start
        val foundLeft = leftMayMatch && collectIntersecting(left, interval, visit)
        // If the left subtree was searched and yielded nothing, the right subtree cannot match
        // either (all its starts are at least as large), so it is skipped.
        val foundRight = (foundLeft || !leftMayMatch) && collectIntersecting(x.right, interval, visit)
        return foundHere || foundLeft || foundRight
    }

    /** Number of distinct intervals stored in the tree. */
    @Suppress("MemberVisibilityCanBePrivate")
    fun size(): Int = size(root)

    private fun size(x: Node<V>?): Int = x?.N ?: 0

    /** Number of nodes on the longest root-to-leaf path; 0 for an empty tree. */
    fun height(): Int = height(root)

    private fun height(x: Node<V>?): Int =
        if (x == null) 0 else 1 + maxOf(height(x.left), height(x.right))

    /** Recomputes the cached subtree size and maximum right endpoint of [x] from its children. */
    private fun fix(x: Node<V>?) {
        if (x == null) return
        x.N = 1 + size(x.left) + size(x.right)
        x.max = maxOf(
            x.interval.endInclusive,
            x.left?.max ?: Int.MIN_VALUE,
            x.right?.max ?: Int.MIN_VALUE,
        )
    }
}

/** Returns `true` when the two closed intervals share at least one point. */
private fun ClosedRange<Int>.intersect(that: ClosedRange<Int>): Boolean =
    this.start <= that.endInclusive && that.start <= this.endInclusive

/**
 * Total order on intervals: by start, ties broken by end.
 *
 * @return a negative number, zero, or a positive number when this interval is ordered before,
 * equal to, or after [that].
 */
private fun ClosedRange<Int>.compareTo(that: ClosedRange<Int>): Int = when {
    this.start != that.start -> this.start.compareTo(that.start)
    else -> this.endInclusive.compareTo(that.endInclusive)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment