|
package com.klibisz.pdci |
|
|
|
import java.util |
|
|
|
import breeze.linalg.{*, DenseMatrix, DenseVector} |
|
import breeze.stats.distributions.RandBasis |
|
import com.klibisz.elastiknn.dataset.{AnnBenchmarks, Dataset} |
|
|
|
import scala.collection.mutable |
|
import scala.math.{Pi, acos, pow, max} |
|
|
|
/** Implicit conversions used by the PDCI implementation. */
object PDCIImplicits {

  import scala.language.implicitConversions

  /** Converts a matrix of doubles into a matrix of floats via `breeze.linalg.convert`.
    *
    * @param m matrix with `Double` elements
    * @return the result of `breeze.linalg.convert(m, Float)`
    */
  implicit def convertDoubleToFloat(m: DenseMatrix[Double]): DenseMatrix[Float] = {
    val asFloats: DenseMatrix[Float] = breeze.linalg.convert(m, Float)
    asFloats
  }

}
|
|
|
import com.klibisz.pdci.PDCIImplicits._ |
|
|
|
/** A single data vector as stored in a [[SearchTree]].
  *
  * @param dPrj value of the data vector (d) projected on a unit vector
  * @param dIdx index of the data vector (d)
  */
final case class SearchTreeEntry(dPrj: Float, dIdx: Int)
|
|
|
/** Ordered collection of [[SearchTreeEntry]] values keyed by their projection (`dPrj`).
  *
  * Backed by a `java.util.TreeMap`, so there is at most one entry per distinct projection
  * value: adding an entry whose `dPrj` is already present replaces the stored entry.
  */
sealed class SearchTree {

  private val treeMap: java.util.TreeMap[Float, SearchTreeEntry] =
    new util.TreeMap[Float, SearchTreeEntry]()

  /** Stores `e` under its projection.
    *
    * @return the entry previously stored under the same projection, or `null` if there was none
    *         (the `TreeMap.put` contract)
    */
  def add(e: SearchTreeEntry): SearchTreeEntry = treeMap.put(e.dPrj, e)

  /** Removes whatever entry is stored under `e.dPrj`; see `remove(prj)`. */
  def remove(e: SearchTreeEntry): SearchTreeEntry = remove(e.dPrj)

  /** Removes the entry stored under the projection `prj`.
    *
    * @return the removed entry, or `null` if nothing was stored under `prj`
    */
  def remove(prj: Float): SearchTreeEntry = treeMap.remove(prj)

  def isEmpty: Boolean = treeMap.isEmpty

  def nonEmpty: Boolean = !isEmpty

  /** Finds the entry whose projection is closest to `prj`.
    *
    * When the closest entries above and below are equally far away, the lower one is returned.
    *
    * @param prj the projection value to search around
    * @return the stored entry minimizing `|dPrj - prj|`
    * @throws NoSuchElementException if the tree is empty
    */
  def nearest(prj: Float): SearchTreeEntry = {
    // floorEntry / ceilingEntry return null when there is no key on that side.
    val below = Option(treeMap.floorEntry(prj))
    val above = Option(treeMap.ceilingEntry(prj))
    (below, above) match {
      case (Some(lo), Some(hi)) =>
        if (hi.getKey - prj < prj - lo.getKey) hi.getValue else lo.getValue
      case (Some(lo), None) => lo.getValue
      case (None, Some(hi)) => hi.getValue
      case (None, None) =>
        throw new NoSuchElementException("nearest called on an empty SearchTree")
    }
  }

}
|
|
|
/** An entry scheduled to be visited next along one projection.
  *
  * @param dst   |qPrj - pPrj|; used as the priority
  * @param qProj value of the query vector (q) projected on the projection vector (p)
  * @param dProj value of the data vector (d) projected on the projection vector (p)
  * @param pIdx  index of the projection vector (p)
  * @param dIdx  index of the data vector (d)
  */
final case class NextVisitEntry(dst: Float, qProj: Float, dProj: Float, pIdx: Int, dIdx: Int)
|
|
|
/** A data vector whose true distance to the query has been computed.
  *
  * @param dst  evaluated euclidean distance to the query vector
  * @param dIdx index of the data vector (d)
  */
final case class CandidateEntry(dst: Float, dIdx: Int)
|
|
|
/** Approximate k-nearest-neighbor index built from random one-dimensional projections.
  *
  * NOTE(review): the name suggests "Prioritized DCI" (Li & Malik); confirm against the
  * intended reference.
  *
  * The index consists of `L` compound indexes, each made of `m` simple indexes. A simple index
  * is a random unit vector (a row of `U`) plus a [[SearchTree]] holding every data vector's
  * projection onto it. Simple index `j` of compound index `l` lives at position `l * m + j`.
  *
  * @param m number of simple indexes per compound index
  * @param L number of compound indexes
  * @param D data matrix; each row is one data vector
  * @param U projection matrix; each row is one unit vector (`m * L` rows once fitted)
  * @param T one search tree per row of `U` (`m * L` trees once fitted)
  */
case class PDCI(m: Int,
                L: Int,
                D: DenseMatrix[Float] = DenseMatrix.zeros[Float](0, 0),
                U: DenseMatrix[Float] = DenseMatrix.zeros[Float](0, 0),
                T: Vector[SearchTree] = Vector.empty) {

  private val twoOverPi: Float = (2f / math.Pi).toFloat

  /** Row of `U` / position in `T` for simple index `j` of compound index `l`. */
  private def offset(l: Int, j: Int): Int = l * m + j

  /** Each row is a random unit vector: entries are drawn from `rand`, then each row is normalized. */
  private def randomUnitVectors(
      rows: Int,
      cols: Int,
      rand: breeze.stats.distributions.Rand[Double]): DenseMatrix[Float] =
    DenseMatrix
      .rand[Double](rows, cols, rand)
      .apply(*, ::)
      .map(breeze.linalg.normalize(_))

  /** Builds a fitted index over the rows of `D`.
    *
    * @param D  data matrix; each row is one data vector
    * @param rb source of randomness for the projection directions (seeded with 0 by default)
    * @return a new `PDCI` with the same `m` and `L`, holding `D`, the sampled projection
    *         vectors and one populated search tree per projection vector
    */
  def fit(D: DenseMatrix[Float])(
      implicit rb: RandBasis = RandBasis.withSeed(0)): PDCI = {
    val (n, d) = (D.rows, D.cols)
    // Gaussian entries give isotropic directions after normalization. Uniform [0, 1) entries
    // would confine every projection vector to the positive orthant.
    val U = randomUnitVectors(m * L, d, rb.gaussian)
    val T = Vector.fill(m * L)(new SearchTree)
    for {
      uIdx <- 0 until (m * L)
      dIdx <- 0 until n
    } {
      val prj = D(dIdx, ::).dot(U(uIdx, ::))
      T(uIdx).add(SearchTreeEntry(prj, dIdx))
    }
    PDCI(m, L, D, U, T)
  }

  /** Finds approximate nearest neighbors of `q` among the rows of `D`.
    *
    * The search temporarily removes visited entries from the shared search trees and adds them
    * back before returning, so concurrent queries against the same instance interfere with
    * each other.
    *
    * @param q            query vector
    * @param k            number of neighbors to return
    * @param errorProb    the search stops once its running error-probability estimate is at or
    *                     below this value
    * @param maxDistEvals upper bound on the number of exact distance computations
    * @return indexes (rows of `D`) of up to `k` candidates, nearest first; fewer than `k` when
    *         the search produced fewer candidates
    * @throws IllegalArgumentException if `k` is negative or the index has not been fitted
    */
  def query(q: DenseVector[Float],
            k: Int,
            errorProb: Float = 0.1f,
            maxDistEvals: Int = 4000): Vector[Int] = {
    require(k >= 0, s"k must be non-negative, got $k")
    require(T.length == m * L,
            s"expected ${m * L} search trees but found ${T.length}; call fit first")
    if (k == 0) Vector.empty
    else search(q, k, errorProb, maxDistEvals)
  }

  /** Runs the prioritized search for `k > 0` neighbors; see `query` for the contract. */
  private def search(q: DenseVector[Float],
                     k: Int,
                     errorProb: Float,
                     maxDistEvals: Int): Vector[Int] = {

    // Counter maps: (compound index)(data index) -> number of simple indexes it was visited in.
    val C = Vector.fill(L)(mutable.Map.empty[Int, Int].withDefault(_ => 0))

    // Priority queues, one per compound index. The head has the smallest projected distance.
    val P = Vector.fill(L)(
      mutable.PriorityQueue[NextVisitEntry]()(
        Ordering.fromLessThan[NextVisitEntry](_.dst > _.dst)))

    // Entries removed from the trees during this query, so they can be added back afterwards.
    val V = new mutable.ListBuffer[NextVisitEntry]()

    // The k closest candidates across all compound indexes.
    // The head of the heap is the farthest of them, i.e. the k-th nearest once the heap is full.
    val kNearestHeap = mutable.PriorityQueue[CandidateEntry]()(
      Ordering.fromLessThan[CandidateEntry](_.dst < _.dst))
    val kNearestSet = mutable.Set.empty[Int]

    // The farthest candidate seen so far for each compound index.
    val farthest = mutable.Map
      .empty[Int, CandidateEntry]
      .withDefault(_ => CandidateEntry(Float.MinValue, -1))

    // Moves the entry nearest to qPrj from tree uIdx into the queue of compound index l.
    def visitNearest(l: Int, uIdx: Int, qPrj: Float): Unit =
      if (T(uIdx).nonEmpty) {
        val SearchTreeEntry(dPrj, dIdx) = T(uIdx).nearest(qPrj)
        val nxt = NextVisitEntry((dPrj - qPrj).abs, qPrj, dPrj, uIdx, dIdx)
        P(l) += nxt
        T(uIdx).remove(dPrj)
        V += nxt
      }

    try {
      // Seed every queue with the entry closest to the query along each projection.
      for {
        l <- 0 until L
        j <- 0 until m
      } {
        val uIdx = offset(l, j)
        val qPrj = U(uIdx, ::).t.dot(q)
        visitNearest(l, uIdx, qPrj)
      }

      var numDistEvals: Int = 0
      var curErrorProb: Float = 1.0f
      var compoundIdx: Int = 0

      // Also stop when every queue is exhausted; otherwise the loop would never terminate.
      while (curErrorProb > errorProb &&
             numDistEvals < maxDistEvals &&
             P.exists(_.nonEmpty)) {

        if (P(compoundIdx).nonEmpty) {

          // Take the next entry that needs to be visited.
          val NextVisitEntry(_, qPrj, _, pIdx, dIdx) = P(compoundIdx).dequeue()

          // Update the counter to denote that it's being visited.
          val numVisits = C(compoundIdx)(dIdx) + 1
          C(compoundIdx).put(dIdx, numVisits)

          // If this vector has been visited on all m simple indices, then it becomes a candidate.
          // In this case, compute its distance and offer it to the k-nearest heap.
          if (numVisits == m && !kNearestSet.contains(dIdx)) {
            val cand =
              CandidateEntry(breeze.linalg.norm(q - D(dIdx, ::).t).toFloat, dIdx)
            numDistEvals += 1

            val newFarthest = cand.dst > farthest(compoundIdx).dst
            if (newFarthest) farthest(compoundIdx) = cand

            val heapChanged =
              if (kNearestHeap.size < k) {
                kNearestHeap.enqueue(cand)
                kNearestSet.add(cand.dIdx)
                true
              } else if (cand.dst < kNearestHeap.head.dst) {
                kNearestSet.remove(kNearestHeap.dequeue().dIdx)
                kNearestHeap.enqueue(cand)
                kNearestSet.add(cand.dIdx)
                true
              } else false

            // The estimate only depends on the k-th nearest distance and the per-index
            // farthest distances, so recompute it only when one of those may have changed.
            if (kNearestHeap.size == k && (newFarthest || heapChanged)) {
              val ck = kNearestHeap.head.dst
              curErrorProb = (0 until L).foldLeft(1.0f) { (acc, i) =>
                val cm = farthest(i).dst
                if (cm > ck) acc * (1 - pow(twoOverPi * acos(ck / cm), m).toFloat)
                else acc
              }
            }
          }

          // Find the next closest entry along this projection and add it to the priority queue.
          visitNearest(compoundIdx, pIdx, qPrj)
        }

        // Increment the compound index being looked at next.
        compoundIdx = (compoundIdx + 1) % L
      }

      // dequeueAll yields the farthest candidate first; reverse to get nearest first.
      kNearestHeap.dequeueAll.reverse.map(_.dIdx).toVector
    } finally {
      // Add the visited entries back into the trees, even if the search failed part-way.
      V.foreach(visited =>
        T(visited.pIdx).add(SearchTreeEntry(visited.dProj, visited.dIdx)))
    }
  }

}
|
|
|
/** Example entry point: fits a PDCI index on the MNIST training set from `AnnBenchmarks` and
  * queries it with the first test vector.
  */
object PDCIExample {

  // An explicit main method avoids the initialization-order pitfalls of the App trait.
  def main(args: Array[String]): Unit = {
    val data: Dataset = AnnBenchmarks.euclidean.mnist784d60k
    val pdci = PDCI(m = 10, L = 3).fit(data.train)
    val neighbors = pdci.query(data.test(0, ::).t, k = 10, errorProb = 0.2f)
    println(s"Nearest neighbors: ${neighbors.mkString(", ")}")
    println("Done")
  }

}