Skip to content

Instantly share code, notes, and snippets.

@peheje
Last active August 27, 2020 19:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peheje/98070f0b065c1ed10917b40dab30bd29 to your computer and use it in GitHub Desktop.
Save peheje/98070f0b065c1ed10917b40dab30bd29 to your computer and use it in GitHub Desktop.
import java.lang.Exception
import java.nio.ByteBuffer
import java.security.MessageDigest
import java.util.*
import kotlin.math.abs
/**
 * Demo driver: fills a [BloomFilter] with one batch of random strings, then measures how it
 * classifies that batch (all present) and a second batch that was never added.
 *
 * Prints the resulting confusion-matrix report and throws if any added entry is reported as
 * absent.
 */
fun main() {
    val filterSize = 1_000_000
    val numberOfEntries = 100_000

    val filter = BloomFilter(filterSize, numberOfHashes = 4)

    // Only the first batch is ever added to the filter; the second one serves as the negatives.
    val entriesInFilter = Array(numberOfEntries) { randomString() }
    val entriesNotInFilter = Array(numberOfEntries) { randomString() }
    entriesInFilter.forEach { filter.add(it) }

    val confusionMatrix = ConfusionMatrix(entriesInFilter, entriesNotInFilter) { filter.maybeExists(it) }
    confusionMatrix.printReport()

    // An added entry reported as absent means the filter implementation is broken.
    val hasFalseNegatives = confusionMatrix.falseNegativeRate > 0.0
    if (hasFalseNegatives) {
        throw Exception("This should not happen, if it does the implementation of the bloom filter is wrong.")
    }
}
/**
 * A Bloom filter for strings backed by a [BitSet] of [size] bits.
 *
 * Every entry is hashed once per salt (the salts are the decimal strings "0", "1", ... up to
 * `numberOfHashes - 1`) and each hash selects one bit. [maybeExists] may answer `true` for an
 * entry that was never added (false positive), but it answers `true` for every entry that was
 * added.
 *
 * Instances share one [MessageDigest] and one [BitSet] without any synchronization, so they
 * should not be used from several threads at once.
 *
 * @param size number of bits in the filter; must be positive.
 * @param numberOfHashes number of salted hashes computed per entry; must be positive.
 * @throws IllegalArgumentException if [size] or [numberOfHashes] is not positive.
 */
class BloomFilter(private val size: Int, numberOfHashes: Int) {
    init {
        // Validate before any property initializer runs, so bad arguments fail with a clear message
        // instead of a NegativeArraySizeException or a division by zero later on.
        require(size > 0) { "size must be positive, was $size" }
        require(numberOfHashes > 0) { "numberOfHashes must be positive, was $numberOfHashes" }
    }

    private val flags = BitSet(size)
    private val salts = List(numberOfHashes) { it.toString() }
    private val sha = MessageDigest.getInstance("SHA-1")

    /** Sets the bit selected by each salted hash of [entry]. */
    fun add(entry: String) {
        for (salt in salts) {
            flags.set(hashedIndex(entry, salt))
        }
    }

    /**
     * Returns `false` as soon as one of the bits selected for [entry] is unset, meaning the entry
     * was definitely never added; returns `true` otherwise.
     */
    fun maybeExists(entry: String): Boolean =
        salts.all { salt -> flags[hashedIndex(entry, salt)] }

    /**
     * Maps `entry + salt` to a bit index in `0 until size` using the first four bytes of its
     * SHA-1 digest.
     */
    private fun hashedIndex(entry: String, salt: String): Int {
        val hash = sha.digest((entry + salt).toByteArray())
        // floorMod always yields a value in 0 until size. abs(x) % size does not: abs(Int.MIN_VALUE)
        // is still negative, which would produce a negative index and make the BitSet throw.
        return Math.floorMod(ByteBuffer.wrap(hash).int, size)
    }
}
/**
 * Evaluates a binary classifier against samples whose true class is known and exposes the
 * resulting rates.
 *
 * [predict] is invoked once per sample during construction: first for every element of
 * `positives`, then for every element of `negatives`.
 *
 * @param positives samples for which [predict] is expected to return `true`; must not be empty.
 * @param negatives samples for which [predict] is expected to return `false`; must not be empty.
 * @property predict the classifier under evaluation.
 * @throws IllegalArgumentException if `positives` or `negatives` is empty.
 */
class ConfusionMatrix<T>(positives: Array<T>, negatives: Array<T>, val predict: (sample: T) -> Boolean) {
    /** Fraction of all samples that were classified correctly. */
    val accuracyRate: Double

    /** Fraction of all samples that were classified incorrectly (`1 - accuracyRate`). */
    val misclassificationRate: Double

    /** Fraction of the positives predicted `true`. */
    val truePositiveRate: Double

    /** Fraction of the negatives predicted `false`. */
    val trueNegativeRate: Double

    /** Fraction of the negatives predicted `true`. */
    val falsePositiveRate: Double

    /** Fraction of the positives predicted `false`. */
    val falseNegativeRate: Double

    init {
        // Empty inputs would make the per-class rates below divide by zero.
        require(positives.isNotEmpty()) { "positives must not be empty" }
        require(negatives.isNotEmpty()) { "negatives must not be empty" }

        val positivesCount = positives.size
        val negativesCount = negatives.size

        // Each sample is classified exactly once; the complementary counts follow by subtraction.
        val truePositiveCount = positives.count { predict(it) }
        val falsePositiveCount = negatives.count { predict(it) }
        val falseNegativeCount = positivesCount - truePositiveCount
        val trueNegativeCount = negativesCount - falsePositiveCount

        accuracyRate = (truePositiveCount + trueNegativeCount).toDouble() / (negativesCount + positivesCount)
        misclassificationRate = 1.0 - accuracyRate
        truePositiveRate = truePositiveCount.toDouble() / positivesCount
        trueNegativeRate = trueNegativeCount.toDouble() / negativesCount
        falsePositiveRate = falsePositiveCount.toDouble() / negativesCount
        falseNegativeRate = falseNegativeCount.toDouble() / positivesCount
    }

    /** Prints all six rates, one labelled row each, via [Printer]. */
    fun printReport() {
        val dataRows = mapOf(
            "Accuracy" to accuracyRate,
            "Misclassification rate" to misclassificationRate,
            "True positive rate" to truePositiveRate,
            "True negative rate" to trueNegativeRate,
            "False positive rate" to falsePositiveRate,
            "False negative rate" to falseNegativeRate
        )
        Printer(dataRows).print()
    }
}
/**
 * Prints labelled ratios to standard output as an aligned two-column table of percentages.
 *
 * Each row has the form `label:`, padding so that all values start in the same column, then the
 * value multiplied by 100 and formatted with `%6.2f` followed by a percent sign. Number formatting
 * uses the JVM's default locale.
 *
 * @param dataRows label-to-ratio pairs, printed in the map's iteration order.
 */
class Printer(private val dataRows: Map<String, Double>) {
    private val spacing = 2
    private val longestLabelLength = getLongestString(dataRows.keys) + spacing

    /** Length of the longest label, or 50 when there are no labels at all. */
    private fun getLongestString(labels: Set<String>): Int =
        labels.maxOfOrNull { it.length } ?: 50

    /** Writes one line per entry of the map to standard output. */
    fun print() {
        for ((label, value) in dataRows) {
            println(formatLabel(label) + formatValue(value))
        }
    }

    /** Returns `label:` right-padded with spaces to the shared column width. */
    private fun formatLabel(label: String): String =
        // The extra 1 accounts for the colon appended to the label.
        "$label:".padEnd(longestLabelLength + 1)

    /** Formats a ratio as a percentage with two decimals in a field at least six characters wide. */
    private fun formatValue(value: Double): String =
        String.format("%6.2f", value * 100) + "%"
}
/** Returns the string form of a freshly generated random [UUID]. */
private fun randomString(): String = UUID.randomUUID().toString()
@peheje
Copy link
Author

peheje commented Aug 22, 2020

You can run it on try.kotlinlang.org, but reduce the sizes first by setting

val filterSize = 1000
val numberOfEntries = 100

Use even lower values if you still get timeouts.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment