Skip to content

Instantly share code, notes, and snippets.

@GibsonRuitiari
Created December 4, 2023 11:01
Show Gist options
  • Save GibsonRuitiari/ade41d75def0374f4c261a0aa3072c69 to your computer and use it in GitHub Desktop.
Save GibsonRuitiari/ade41d75def0374f4c261a0aa3072c69 to your computer and use it in GitHub Desktop.
This is a Kotlin implementation of a PATRICIA trie. It accompanies a write-up that explains the inner workings of the PATRICIA trie as described by Knuth; please see that write-up for further details.
import java.util.Objects
import java.util.concurrent.locks.ReentrantLock
import kotlin.math.max
/**
 * Demo entry point: builds a [Trie] over the words "A".."G", inserts them one at a time while
 * printing each new node under a numbered banner, then prints a few lookups and the final size.
 */
fun main(args: Array<String>) {
    val trie = Trie("A B C D E F G")
    val margin = " ".repeat(10)

    for (step in 1..7) {
        // Every banner after the first is separated from the previously printed node by a blank line.
        if (step > 1) println()
        val marker = if (step == 1) "START" else "↓"
        println("$margin${step}.$marker$margin")
        println(trie.addNodeToTrie(step))
    }

    Thread.sleep(1000)

    println(trie.searchForNode(6))
    for (lookup in listOf(4, 3)) {
        println(trie.searchForNode(lookup))
        println()
    }
    println(trie.searchForNode(2))
    println("size of trie ${trie.size}")
}
/**
 * Converts this string into one 8-bit pattern per character, most significant bit first.
 *
 * For example `"A"` (code 65) becomes `[[0, 1, 0, 0, 0, 0, 0, 1]]`. Every pattern is exactly
 * 8 bits wide; [bitDifferenceBetweenTwoBitPatterns] relies on that when it turns a
 * (character, bit) pair into an absolute bit position.
 *
 * @return an array holding one 8-element bit array per character; no entry is `null`.
 * @throws IllegalArgumentException if the string is empty, or if it contains a character whose
 *   code does not fit in 8 bits (greater than `0xFF`).
 */
fun String.toBitPattern(): Array<Array<Int>?> {
    // Only the empty string is rejected; blank strings such as " " are valid input.
    require(isNotEmpty()) { "cannot build a bit pattern for an empty string" }
    return Array<Array<Int>?>(length) { charIndex ->
        val code = this[charIndex].code
        // A wider character would yield a pattern longer than 8 bits and silently corrupt bit positions.
        require(code <= 0xFF) { "character '${this[charIndex]}' (code $code) does not fit in 8 bits" }
        Array(8) { bit -> (code shr (7 - bit)) and 1 }
    }
}
/**
 * Returns the 1-based position of the first bit at which two per-character bit patterns differ.
 *
 * Both arguments are expected in the shape produced by [toBitPattern] (one bit array per character,
 * 8 bits each). The position is `charIndex * 8 + bitIndex + 1`, i.e. an index into the flattened
 * pattern.
 *
 * @return the first differing bit position (always >= 1), or `-1` when no difference exists within
 *   the characters both patterns share — i.e. the words are equal or one is a prefix of the other —
 *   or when a shared character's pattern is `null`.
 */
fun bitDifferenceBetweenTwoBitPatterns(
    bitPatternWord1: Array<Array<Int>?>,
    bitPatternWord2: Array<Array<Int>?>,
): Int {
    val sharedLength = minOf(bitPatternWord1.size, bitPatternWord2.size)
    for (charIndex in 0 until sharedLength) {
        val aBits = bitPatternWord1[charIndex] ?: return -1
        val bBits = bitPatternWord2[charIndex] ?: return -1
        if (aBits.contentEquals(bBits)) continue
        // Only compare bits present in both arrays so mismatched widths cannot index out of bounds.
        for (bitIndex in 0 until minOf(aBits.size, bBits.size)) {
            if ((aBits[bitIndex] xor bBits[bitIndex]) == 1) {
                return charIndex * 8 + bitIndex + 1
            }
        }
    }
    return -1
}
/**
 * A PATRICIA trie whose keys are the 1-based positions of the words in [text].
 *
 * [text] is split on single spaces and key `k` refers to the `k`-th word. Each inserted word is
 * stored in a [Node] together with its bit pattern (see [toBitPattern]); a new node's
 * [Node.skipBit] is the 1-based position of the first bit at which its word differs from the word
 * found by [searchForNode]. The first node inserted becomes the root and its left link points back
 * to itself.
 *
 * This class performs no synchronization.
 */
class Trie(text: String) {
    private var _size: Int = 0
    private var root: Node? = null

    /** Returns the root node, or `null` if nothing has been inserted yet. */
    fun getRootNode(): Node? = root

    // Word `k` of the text lives at index `k - 1`; consecutive spaces produce empty entries.
    private val splitText = text.split(" ")

    /**
     * Outcome of [searchForNode].
     *
     * @property node the node at which the search stopped.
     * @property bit a bit value recorded by the search (see [searchForNode] for each exit path).
     * @property isFound whether [node] holds the searched key.
     */
    data class NodeSearchResult(val node: Node?, val bit: Int, val isFound: Boolean)

    /** Number of nodes inserted so far. */
    val size: Int
        get() = _size

    /** Returns the word for [key], or `null` if [key] is out of range or names an empty token. */
    private fun wordFor(key: Int): String? = splitText.getOrNull(key - 1)?.takeIf { it.isNotEmpty() }

    /**
     * Inserts the word identified by [key] and returns its node.
     *
     * @param key 1-based position of the word in the text.
     * @return the newly created node; the existing node if [key] is already present; or `null` if
     *   [key] is out of range, names an empty token, or no bit distinguishes its word from the word
     *   the search ended on (a duplicate word, or one word being a prefix of the other).
     */
    fun addNodeToTrie(key: Int): Node? {
        val text = wordFor(key) ?: return null
        val keyBitPattern = text.toBitPattern()
        if (root == null) {
            val firstNode = Node(key = key, skipBit = 0, value = keyBitPattern)
            // The lone root links back to itself on the left.
            firstNode.leftNode = firstNode
            root = firstNode
            _size += 1
            return firstNode
        }
        val searchResult = searchForNode(key)
        val isFound = searchResult?.isFound
        val searchResultNode = searchResult?.node ?: return null
        if (isFound == true) return searchResultNode

        val flattenedKeyBitPattern = keyBitPattern.filterNotNull().flatMap { it.toList() }
        val jBit = bitDifferenceBetweenTwoBitPatterns(searchResultNode.value, keyBitPattern)
        // -1 means no bit distinguishes the two words; such a key has no position in the trie and
        // take(jBit - 1) below would throw, so the insertion is refused.
        if (jBit < 1) return null
        // The key's bit at the distinguishing position decides which link of the new node is its
        // self-link. It is read before any link is rewired so a failure cannot leave the trie half-updated.
        val keyBit = flattenedKeyBitPattern.getOrNull(jBit - 1) ?: return null

        // keyBits[:jBit-1]
        val newBitPatternKey = flattenedKeyBitPattern.take(jBit - 1)
        val (node, bit) = searchForLink(newBitPatternKey)
        val newNode = Node(key, skipBit = jBit, value = keyBitPattern)
        val previousNode: Node?
        if (bit == 0) {
            previousNode = node?.leftNode
            node?.leftNode = newNode
        } else {
            previousNode = node?.rightNode
            node?.rightNode = newNode
        }
        // The new node's other link takes over either the displaced link or the node the search found.
        val otherLink = if (searchResultNode == node) previousNode else searchResultNode
        if (keyBit == 0) {
            newNode.leftNode = newNode
            newNode.rightNode = otherLink
        } else {
            newNode.rightNode = newNode
            newNode.leftNode = otherLink
        }
        _size += 1
        // NOTE(review): this assumes the root is the node with key 1 — confirm the behavior when the
        // first inserted key is not 1.
        if (node?.key == 1 || searchResultNode.key == 1) root = node
        return newNode
    }

    /**
     * Walks from the root following the bits of [newBitPatternKey] and returns the node whose link
     * should receive a new node, paired with a selector for that link (0 selects the left link,
     * anything else the right one).
     *
     * NOTE(review): on the fall-through path the selector is the parent's skipBit rather than a key
     * bit, so any non-zero skip bit selects the right link — confirm this is intended.
     */
    private fun searchForLink(newBitPatternKey: List<Int>): Pair<Node?, Int> {
        var parent = root
        var child = root?.leftNode
        var skipBit: Int?
        var jBit: Int?
        val lengthOfKeyBitPattern = newBitPatternKey.size
        while (true) {
            if (child?.skipBit == 0) break // only one node exists so just get out
            skipBit = child?.skipBit ?: break
            jBit = newBitPatternKey.getOrNull(skipBit - 1)
            if (skipBit > lengthOfKeyBitPattern) break
            val s = parent
            if (jBit == null) break
            if (jBit == 0) {
                parent = child
                child = child.leftNode
            } else {
                parent = child
                child = child.rightNode
            }
            // A link pointing back to the node just left, or to the one before it, ends the walk.
            if (child?.key == parent.key || s?.key == child?.key) {
                return child to jBit
            }
        }
        val bit = parent?.skipBit ?: 0
        return parent to bit
    }

    /**
     * Searches for the node holding [key].
     *
     * @return `null` if the trie is empty or [key] does not name a non-empty word; otherwise the
     *   node where the search stopped, with [NodeSearchResult.isFound] telling whether it holds [key].
     */
    fun searchForNode(key: Int): NodeSearchResult? {
        if (root == null) return null
        // key not part of the given text
        val text = wordFor(key) ?: return null
        val keyBitPattern = text.toBitPattern().filterNotNull().flatMap { it.toList() }
        var parent = root
        var child = root?.leftNode
        var skipBit = child?.skipBit!!
        var jBit: Int?
        if (child?.key == key) {
            return NodeSearchResult(child, bit = keyBitPattern.getOrNull(child.skipBit - 1) ?: 0, isFound = true)
        } else if (root?.key == key) {
            return NodeSearchResult(node = root, bit = keyBitPattern.getOrNull(root!!.skipBit - 1) ?: 0, isFound = true)
        }
        val lengthOfKeyBitPattern = keyBitPattern.size
        while (skipBit <= lengthOfKeyBitPattern) {
            if (skipBit == 0) break
            skipBit = child?.skipBit ?: break
            jBit = keyBitPattern.getOrNull(skipBit - 1)
            val s = parent
            if (jBit == null) break
            if (jBit == 0) {
                parent = child
                child = child?.leftNode
            } else {
                parent = child
                child = child?.rightNode
            }
            // A back link (to the node just left, or the one before it) ends the search.
            if (child?.key == parent?.key || s?.key == child?.key) {
                return NodeSearchResult(node = child, bit = jBit, isFound = child?.key == key)
            }
        }
        // Reaching here means skipBit exceeded the key length or no key bit was available, so the
        // parent is the final candidate.
        return NodeSearchResult(node = parent, bit = parent?.skipBit ?: 0, isFound = parent?.key == key)
    }
}
/**
 * A node of a compressed trie (PATRICIA trie).
 *
 * A node holds a [key], a [skipBit], a [value] and two links, [leftNode] and [rightNode].
 *
 * - [key] is the 1-based position of the node's word in the text. For example, in
 *   "this house was built by Thomas" the key of "this" is 1, of "house" 2, of "was" 3, and so on.
 * - [skipBit] is the 1-based position of the first bit at which this node's word differs from the
 *   word it was compared with on insertion; the first (root) node uses 0.
 * - [value] holds one bit pattern per character of the word, expected to be 8 bits per character
 *   (ASCII). For key 1 ("this") the value holds 4 patterns, one per character:
 *   ```
 *   t - 01110100
 *   h - 01101000
 *   i - 01101001
 *   s - 01110011
 *   ```
 * - [leftNode] and [rightNode] are the left and right links. They may point back to this node or
 *   to an ancestor, so links can form cycles.
 *
 * Equality compares [key], [skipBit], the contents of [value] and the *keys* of the linked nodes;
 * comparing the linked nodes themselves would recurse forever around those cycles. [hashCode] uses
 * [key] only, which is consistent with this equality.
 */
data class Node(
    val key: Int,
    val skipBit: Int,
    val value: Array<Array<Int>?>,
    var leftNode: Node? = null,
    var rightNode: Node? = null,
) {
    override fun toString(): String = buildString {
        appendLine("{key:$key,")
        appendLine("indexBit:$skipBit")
        appendLine("value:${value.joinToString(separator = " ") { bits -> "${bits?.joinToString("")}" }}")
        leftNode?.let { appendLine("left = ${it.key}") }
        rightNode?.let { appendLine("right = ${it.key}") }
        appendLine("}")
    }

    override fun hashCode(): Int = Objects.hash(key)

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is Node) return false
        return key == other.key &&
            skipBit == other.skipBit &&
            // Deep comparison: the inner bit arrays must be compared by content, not by reference.
            value.contentDeepEquals(other.value) &&
            // Links are compared by key to avoid unbounded recursion through cyclic links.
            leftNode?.key == other.leftNode?.key &&
            rightNode?.key == other.rightNode?.key
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment