Created
October 4, 2018 05:47
-
-
Save nicola007b/454bc77c435cff65e5cdd73ced316e1c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Copyright 2018 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license.
 * See LICENSE in the project root for license information.
 */
package com.linkedin.nn.utils | |
import com.linkedin.nn.Types.{ItemId, ItemIdDistancePair} | |
import java.util.{PriorityQueue => JPriorityQueue} | |
import scala.collection.JavaConverters._ | |
import scala.collection.mutable | |
/**
 * A bounded collection that retains the `maxCapacity` pairs with the smallest distance (the second tuple
 * element) seen so far. It can be used to get the smallest-n elements in a streaming fashion.
 *
 * Internally this keeps a [[java.util.PriorityQueue]] ordered by '''descending''' distance, i.e. a max-heap whose
 * head is the farthest (largest-distance) pair currently retained. Once the queue is full, a new pair is admitted
 * only if its distance is strictly smaller than that head's, and the head is evicted to make room.
 *
 * Contents are deduplicated on the first tuple element ([[ItemId]]): while an id is held in the queue, further
 * pairs with the same id are ignored, even if they carry a smaller distance. An id whose pair was evicted may be
 * admitted again by a later pair.
 *
 * @param maxCapacity max number of elements the queue will hold; must be positive
 */
class TopNQueue(maxCapacity: Int) extends Serializable {
  require(maxCapacity > 0, s"maxCapacity must be positive, got $maxCapacity")

  // java.util.PriorityQueue is a min-heap w.r.t. its comparator. Reversing the distance ordering makes
  // peek()/poll() return the LARGEST distance, which is exactly the element to evict when a closer one arrives.
  private val ord = Ordering.by[ItemIdDistancePair, Double](_._2).reverse
  // The queue never grows beyond maxCapacity (we poll before offering when full), so this avoids resizing.
  private val priorityQ = new JPriorityQueue[ItemIdDistancePair](maxCapacity, ord)
  // Ids of the pairs currently held in priorityQ; used for deduplication.
  private val elements: mutable.HashSet[ItemId] = mutable.HashSet[ItemId]()

  /**
   * Offers each element to the queue, in argument order.
   *
   * An element is skipped if its id is already held. Otherwise it is added while there is spare capacity;
   * when the queue is full it replaces the farthest retained element only if it is strictly closer.
   *
   * @param elems The elements to enqueue
   */
  def enqueue(elems: ItemIdDistancePair*): Unit = {
    elems.foreach { x =>
      if (!elements.contains(x._1)) {
        if (priorityQ.size < maxCapacity) {
          priorityQ.offer(x)
          elements.add(x._1)
        } else if (priorityQ.peek._2 > x._2) {
          // x is closer than the farthest retained pair: evict that pair (and forget its id) to make room.
          elements.remove(priorityQ.poll()._1)
          priorityQ.offer(x)
          elements.add(x._1)
        }
      }
    }
  }

  /** @return true if at least one element is currently held */
  def nonEmpty(): Boolean = !priorityQ.isEmpty

  /**
   * Returns the retained elements ordered by ascending distance (closest first).
   *
   * The elements are copied before iteration, so later calls to [[enqueue]] do not affect an iterator that has
   * already been returned. The relative order of elements with equal distance is unspecified.
   */
  def iterator(): Iterator[ItemIdDistancePair] = priorityQ.asScala.toList.sortBy(_._2).iterator
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment