
@alexklibisz
Last active June 26, 2021 16:46
PDCI Scala Implementation

This is a very rough first-pass Scala implementation of the Prioritized DCI (PDCI) algorithm for nearest neighbor search, described in Li and Malik's "Fast k-Nearest Neighbour Search via Prioritized DCI" (ICML 2017).

It's quite an interesting algorithm, but I found it difficult to implement efficiently on the JVM. I'm pretty sure this will still compile and run, but I last touched this code in March 2019, so who knows :).
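
For a bit of context on the data structure at the heart of the algorithm, here is a minimal, self-contained sketch of a single "simple index" (illustrative only; SimpleIndexSketch, randomUnitVector, and the toy data are mine, not part of the gist): project every data vector onto a random unit vector, keep the projected values in a sorted map, and find candidates for a query with floor/ceiling lookups around the query's own projection. PDCI builds m * L of these and visits them in a prioritized, round-robin order.

import java.util.TreeMap
import scala.util.Random

object SimpleIndexSketch {

  def dot(a: Array[Float], b: Array[Float]): Float =
    a.zip(b).map { case (x, y) => x * y }.sum

  // Random direction drawn from a Gaussian, normalized to unit length.
  def randomUnitVector(dim: Int, rng: Random): Array[Float] = {
    val v = Array.fill(dim)(rng.nextGaussian().toFloat)
    val norm = math.sqrt(dot(v, v).toDouble).toFloat
    v.map(_ / norm)
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(0)
    val dim = 8
    val data = Array.fill(100)(Array.fill(dim)(rng.nextFloat()))
    val u = randomUnitVector(dim, rng)

    // One "simple index": projected value -> index of the data vector.
    val index = new TreeMap[Float, Int]()
    data.zipWithIndex.foreach { case (d, i) => index.put(dot(d, u), i) }

    // Candidates for a query are the points whose projections fall closest to
    // the query's projection, found with floor/ceiling lookups.
    val q = Array.fill(dim)(rng.nextFloat())
    val qPrj = dot(q, u)
    println(s"below: ${Option(index.floorEntry(qPrj))}, above: ${Option(index.ceilingEntry(qPrj))}")
  }
}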

The build.sbt includes some unnecessary dependencies, as this was pulled out of a private repo containing other experiments which eventually became Elastiknn.

name := "elastiknn-prototypes"
version := "0.1"
scalaVersion := "2.12.8"
resolvers ++= Seq(
  "Sonatype Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
  "Sonatype Releases" at "https://oss.sonatype.org/content/repositories/releases"
)
val sparkVersion = "2.4.0"
libraryDependencies ++= Seq(
  "org.scalanlp" %% "breeze" % "1.0-RC2",
  "org.scalanlp" %% "breeze-natives" % "1.0-RC2",
  "org.bytedeco.javacpp-presets" % "hdf5-platform" % "1.10.4-1.4.4",
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion,
  "org.apache.spark" %% "spark-mllib" % sparkVersion
)
fork in run := true
javaOptions in run += "-Xmx9G"

package com.klibisz.pdci

import java.util

import breeze.linalg.{*, DenseMatrix, DenseVector}
import breeze.stats.distributions.RandBasis
import com.klibisz.elastiknn.dataset.{AnnBenchmarks, Dataset}

import scala.collection.mutable
import scala.math.{acos, pow}

object PDCIImplicits {
  // Lossy Double -> Float conversion so the randomly-generated projection
  // vectors can be stored as Floats alongside the Float data matrix.
  implicit def convertDoubleToFloat(m: DenseMatrix[Double]): DenseMatrix[Float] =
    breeze.linalg.convert(m, Float)
}

import com.klibisz.pdci.PDCIImplicits._
final case class SearchTreeEntry(
  dPrj: Float, // Value of the data vector (d) projected on a unit vector.
  dIdx: Int    // Index of the data vector (d).
)
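// Sorted "simple index" over projected values, backed by a java.util.TreeMap so
// the projection nearest to a query's projection can be found with floor/ceiling
// lookups in O(log n). Note the projected value is the key, so two data vectors
// with exactly equal projections would collide.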
sealed class SearchTree {
  private val treeMap: java.util.TreeMap[Float, SearchTreeEntry] =
    new util.TreeMap[Float, SearchTreeEntry]()
  def add(e: SearchTreeEntry): SearchTreeEntry = treeMap.put(e.dPrj, e)
  def remove(e: SearchTreeEntry): SearchTreeEntry = remove(e.dPrj)
  def remove(prj: Float): SearchTreeEntry = treeMap.remove(prj)
  def isEmpty: Boolean = treeMap.isEmpty
  def nonEmpty: Boolean = !isEmpty
  def nearest(prj: Float): SearchTreeEntry = {
    val hi = treeMap.ceilingEntry(prj)
    val lo = treeMap.floorEntry(prj)
    if (hi == null) lo.getValue
    else if (lo == null) hi.getValue
    else if (hi.getKey - prj < prj - lo.getKey) hi.getValue
    else lo.getValue
  }
}
final case class NextVisitEntry(
  dst: Float,   // |dProj - qProj|, the distance between projections. Used as the priority.
  qProj: Float, // Value of the query vector (q) projected on the projection vector (p).
  dProj: Float, // Value of the data vector (d) projected on the projection vector (p).
  pIdx: Int,    // Index of the projection vector (p).
  dIdx: Int     // Index of the data vector (d).
)
final case class CandidateEntry(
  dst: Float, // Evaluated euclidean distance to the query vector.
  dIdx: Int   // Index of the data vector (d).
)
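// m: number of simple indices (random projections) per compound index.
// L: number of compound indices.
// D: n x d data matrix, one data vector per row.
// U: (m * L) x d matrix of random unit projection vectors, one per row.
// T: one SearchTree per projection vector, keyed by the projected values.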
case class PDCI(m: Int,
                L: Int,
                D: DenseMatrix[Float] = DenseMatrix.zeros[Float](0, 0),
                U: DenseMatrix[Float] = DenseMatrix.zeros[Float](0, 0),
                T: Vector[SearchTree] = Vector.empty) {

  private val twoOverPi: Float = (2f / math.Pi).toFloat

  // Row in U (and index in T) for simple index j of compound index l.
  private def offset(l: Int, j: Int): Int = l * m + j

  /** Each row is a random unit vector. */
  private def randomUnitVectors(
      rows: Int,
      cols: Int,
      rand: breeze.stats.distributions.Rand[Double]): DenseMatrix[Float] =
    DenseMatrix
      .rand[Double](rows, cols, rand)
      .apply(*, ::)
      .map(breeze.linalg.normalize(_))
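
  // Builds the index: generates m * L random unit projection vectors, projects
  // every data vector onto each one, and inserts the (projection, data index)
  // pairs into the corresponding search trees.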
  def fit(D: DenseMatrix[Float])(
      implicit rb: RandBasis = RandBasis.withSeed(0)): PDCI = {
    val (n, d) = (D.rows, D.cols)
    val U = randomUnitVectors(m * L, d, rb.uniform)
    val T = (0 until (m * L)).toVector.map(_ => new SearchTree)
    for {
      j <- 0 until m
      l <- 0 until L
      uIdx = offset(l, j)
      dIdx <- 0 until n
      prj = D(dIdx, ::).dot(U(uIdx, ::))
    } T(uIdx).add(SearchTreeEntry(prj, dIdx))
    PDCI(m, L, D, U, T)
  }
  def query(q: DenseVector[Float],
            k: Int,
            errorProb: Float = 0.1f): Vector[Int] = {
    // Useful for instantiating structures below.
    val lvec = (0 until L).toVector
    // Counter map: (compound index (l))(data index) -> # of times visited across its simple indices.
    val C = lvec.map(_ => mutable.Map.empty[Int, Int].withDefault(_ => 0))
    // Solutions/candidates. One set of candidates per compound index. The candidate is the data index.
    // Don't think I need this with the data-specific stopping criteria.
    // val S = lvec.map(_ => mutable.ListBuffer.empty[Int])
    // Priority queues of projections to visit next. One per compound index; closest projection first.
    val P = lvec.map(
      _ =>
        mutable.PriorityQueue[NextVisitEntry]()(
          Ordering.fromLessThan[NextVisitEntry](_.dst > _.dst)))
    // Track the entries that get visited and removed from the trees so they can later be added back.
    val V = new mutable.ListBuffer[NextVisitEntry]()
    // Track the k closest candidates across all compound indexes.
    // The head of the priority queue is the farthest of the candidates kept so far.
    val kNearestHeap = mutable.PriorityQueue[CandidateEntry]()(
      Ordering.fromLessThan[CandidateEntry](_.dst < _.dst))
    val kNearestSet = mutable.Set.empty[Int]
    // Track the farthest candidate for each compound index.
    val farthest = mutable.Map
      .empty[Int, CandidateEntry]
      .withDefault(_ => CandidateEntry(Float.MinValue, -1))
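    // Prime each compound index's priority queue: for each of its m simple
    // indices, find the data projection closest to the query's projection,
    // queue it for visiting, and remove it from the tree so it isn't revisited.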
    for {
      l <- 0 until L
      j <- 0 until m
      uIdx = offset(l, j)
      qPrj = U(uIdx, ::).t.dot(q)
      SearchTreeEntry(dPrj, dIdx) = T(uIdx).nearest(qPrj)
      nxt = NextVisitEntry((dPrj - qPrj).abs, qPrj, dPrj, uIdx, dIdx)
    } {
      P(l) += nxt
      T(uIdx).remove(dPrj)
      V += nxt
    }
    var numDistEvals: Int = 0
    var curErrorProb: Float = 1.0f
    var compoundIdx: Int = 0
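    // Visit projections round-robin across the compound indexes until the
    // estimated error probability drops below errorProb, with a hard-coded cap
    // of 4000 distance evaluations.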
    while (curErrorProb > errorProb && numDistEvals < 4000) {
      if (P(compoundIdx).nonEmpty) {
        // Take the next entry that needs to be visited.
        val NextVisitEntry(dst, qPrj, dPrj, pIdx, dIdx) =
          P(compoundIdx).dequeue()
        // Update the counter to denote that it's being visited.
        val numVisits = C(compoundIdx)(dIdx) + 1
        C(compoundIdx).put(dIdx, numVisits)
        if (numVisits > m) {
          println(numVisits)
        }
        // If this vector has been visited on all m simple indices, then it becomes a candidate.
        // In this case, compute its distance and add it to the kNearest priority queue.
        if (numVisits == m && !kNearestSet.contains(dIdx)) {
          val cand =
            CandidateEntry(breeze.linalg.norm(q - D(dIdx, ::).t).toFloat, dIdx)
          numDistEvals += 1
          var recompute = false
          if (cand.dst > farthest(compoundIdx).dst) {
            farthest(compoundIdx) = cand
            recompute = true
          }
          if (kNearestHeap.size <= k) {
            kNearestHeap.enqueue(cand)
            kNearestSet.add(cand.dIdx)
          } else if (cand.dst < kNearestHeap.head.dst) {
            kNearestSet.remove(kNearestHeap.dequeue().dIdx)
            kNearestHeap.enqueue(cand)
            kNearestSet.add(cand.dIdx)
            recompute = true
          }
          // Re-estimate the current error probability from the ratio ck / cm, where ck is
          // the distance to the worst retained candidate and cm is the farthest candidate
          // distance evaluated on compound index i.
          if (kNearestHeap.size > k && recompute) {
            curErrorProb = 1.0f
            val ck = kNearestHeap.head.dst
            (0 until L).foreach { i =>
              val cm = farthest(i).dst
              if (cm > ck)
                curErrorProb *= 1 - pow(twoOverPi * acos(ck / cm), m).toFloat
            }
            println(curErrorProb, numDistEvals)
          }
        }
        // Find the next closest entry along this projection and add it to the priority queue.
        if (T(pIdx).nonEmpty) {
          val searchTreeEntry = T(pIdx).nearest(qPrj)
          val nxt = NextVisitEntry((searchTreeEntry.dPrj - qPrj).abs,
                                   qProj = qPrj,
                                   dProj = searchTreeEntry.dPrj,
                                   pIdx = pIdx,
                                   dIdx = searchTreeEntry.dIdx)
          P(compoundIdx) += nxt
          T(pIdx).remove(nxt.dProj)
          V += nxt
        }
      }
      // Increment the compound index being looked at next.
      compoundIdx = (compoundIdx + 1) % L
    }
    // Add the visited entries back into the trees.
    V.foreach(pqItem =>
      T(pqItem.pIdx).add(SearchTreeEntry(pqItem.dProj, pqItem.dIdx)))
    // Hard-coded indices (apparently ground-truth neighbors for the example test
    // query) used for the recall sanity check printed below.
    val check = Seq(53843, 38620, 16186, 27059, 47003, 14563, 44566, 15260,
      40368, 36395, 30502, 14770, 17228, 35919, 27166, 21518, 52010, 38763,
      14505, 48108, 9444, 55668, 9724, 57204, 17946, 41958, 40710, 29762, 26957,
      24700, 54364, 35937, 41236, 23149, 31073, 51420, 1673, 50255, 4130, 30988,
      32980, 25972, 38504, 25705, 29983, 59665, 43917, 15856, 39003, 6475, 3298,
      35150, 47117, 53679, 31614, 22546, 50147, 49516, 6525, 47015, 53589,
      13382, 37550, 55750, 51500, 58507, 26967, 53783, 42146, 51916, 32432,
      16935, 48955, 53855, 17017, 3324, 15570, 46160, 41928, 773, 53483, 21463,
      58009, 23969, 31882, 9793, 46053, 24572, 27005, 52606, 53333, 23909,
      18429, 1789, 20558, 11421, 25788, 53418, 8931, 49024).take(k).toSet
    // Candidate data indices, nearest first.
    val found = kNearestHeap.dequeueAll.reverse.map(_.dIdx)
    val isect = check.intersect(found.toSet)
    println(
      s"Found ${isect.size} of $k, curErrorProb = $curErrorProb, numDistEvals = $numDistEvals")
    found.toVector
  }
}
object PDCIExample extends App {
  // val st = new SearchTree()
  // st.add(SearchTreeEntry(1.0f, 0))
  // st.add(SearchTreeEntry(2.0f, 0))
  // st.add(SearchTreeEntry(3.0f, 0))
  // println(st.nearest(SearchTreeEntry(3.1f, 0)))
  val data: Dataset = AnnBenchmarks.euclidean.mnist784d60k
  val pdci = PDCI(m = 10, L = 3).fit(data.train)
  pdci.query(data.test(0, ::).t, k = 10, errorProb = 0.2f)
  // pdci.query(data.test(0, ::).t, k0 = 10000, k1 = 500000)
  println("Done")
}