Skip to content

Instantly share code, notes, and snippets.

@eshioji
Last active October 15, 2017 18:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eshioji/8c89da5b809274be363c21bb39fa22ec to your computer and use it in GitHub Desktop.
Save eshioji/8c89da5b809274be363c21bb39fa22ec to your computer and use it in GitHub Desktop.
GroupedIterator (very useful to use with Spark's repartitionAndSortWithinPartitions)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
*/
object GroupedIterator {

  /**
   * Groups consecutive rows of `input` that share the same key.
   *
   * @param input key/value rows in which rows with equal keys are adjacent (e.g. sorted by key).
   *              If equal keys are not adjacent, the same key is reported as several groups.
   * @return an iterator of `(key, rows having that key)` pairs; empty when `input` is empty.
   */
  def apply[K, V](input: Iterator[(K, V)]): Iterator[(K, Iterator[(K, V)])] = {
    // The class reads the first row in its constructor, so the empty case is handled here.
    if (input.hasNext) new GroupedIterator(input.buffered) else Iterator.empty
  }
}

/**
 * Iterates over a presorted sequence of `(key, value)` rows, chunking it up by key. Each call to
 * `next()` returns a pair containing the key of the current group and an iterator over all the
 * rows of that group. The per-group iterators pull rows lazily from the input iterator, so full
 * groups are never materialized by this class.
 *
 * Example:
 * {{{
 * Input: (a, 1), (b, 2), (b, 3)
 *
 * First call to next():  (a, Iterator((a, 1)))
 * Second call to next(): (b, Iterator((b, 2), (b, 3)))
 * }}}
 *
 * The outer iterator and the group iterators share one cursor over the input. Consume a group
 * iterator before calling `hasNext`/`next()` on the outer iterator again: advancing the outer
 * iterator skips the unread rows of the current group, and once a later group has started every
 * previously returned group iterator reports itself as empty.
 *
 * Note, the class does not handle the case of an empty input for simplicity of implementation.
 * Use the factory in the companion object to construct a new instance.
 *
 * @param input An iterator of rows. Rows with equal keys must be adjacent, otherwise the same
 *              key may appear as more than one group.
 */
class GroupedIterator[K, V] private (input: BufferedIterator[(K, V)])
    extends Iterator[(K, Iterator[(K, V)])] {

  /** Returns the key of a row. Not used by the iteration logic below; kept for compatibility. */
  val keyProjection = (k: K, v: V) => k

  /**
   * Holds null or the row that will be returned on next call to `next()` in the inner iterator.
   */
  var currentRow: (K, V) = input.next()

  /** Holds the key of the group that is currently being iterated. */
  var currentGroup: K = currentRow._1

  /**
   * Incremented every time a new group starts. Each group iterator remembers the value it was
   * created under, so an iterator belonging to an earlier group can detect that it is stale.
   * Declared before `currentIterator` because that initializer reads it.
   */
  private var generation: Long = 0L

  /** The group iterator to hand out on the next call to `next()`, or null once handed out. */
  var currentIterator: Iterator[(K, V)] = createGroupValuesIterator()

  /**
   * Return true if we already have the next iterator or fetching a new iterator is successful.
   *
   * Note that, if we get the iterator by `next`, we should consume it before calling `hasNext`,
   * because fetching a new iterator consumes input data to skip to the next group, which makes
   * the previous iterator empty.
   */
  def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator()

  /**
   * Returns the key of the next group together with an iterator over that group's rows.
   *
   * @throws NoSuchElementException if there are no more groups.
   */
  def next(): (K, Iterator[(K, V)]) = {
    // `hasNext` fetches the next group as a side effect, so it must not live inside an `assert`
    // (assertions can be elided at compile time, which would skip the fetch).
    if (!hasNext) throw new NoSuchElementException("next on empty GroupedIterator")
    val ret = (currentGroup, currentIterator)
    currentIterator = null
    ret
  }

  /**
   * Skips whatever is left of the current group and positions on the first row of the next one.
   *
   * @return true if a new group was found and `currentIterator` was set; false otherwise.
   */
  private def fetchNextGroupIterator(): Boolean = {
    assert(currentIterator == null)

    if (currentRow == null && input.hasNext) {
      currentRow = input.next()
    }

    if (currentRow == null) {
      // There is no data left, return false.
      false
    } else {
      // Skip to next group.
      // currentRow may have been pre-fetched by the inner `hasNext`, so compare it first.
      while (currentGroup == currentRow._1 && input.hasNext) {
        currentRow = input.next()
      }

      if (currentGroup == currentRow._1) {
        // We are in the last group, there are no more groups, return false.
        false
      } else {
        // Now the `currentRow` is the first row of next group.
        currentGroup = currentRow._1
        // Invalidate iterators handed out for earlier groups before creating the new one.
        generation += 1
        currentIterator = createGroupValuesIterator()
        true
      }
    }
  }

  /**
   * Creates the iterator over the rows of the current group. It reads through the shared
   * `currentRow`/`input` state and stops at the first row whose key differs from `currentGroup`.
   */
  private def createGroupValuesIterator(): Iterator[(K, V)] = {
    // Snapshot of the generation this iterator belongs to.
    val ownGeneration = generation

    new Iterator[(K, V)] {
      // A stale iterator (a later group has started) is empty; without this check it would read
      // rows of the later group through the shared state.
      def hasNext: Boolean =
        ownGeneration == generation && (currentRow != null || fetchNextRowInGroup())

      def next(): (K, V) = {
        // Explicit check instead of `assert(hasNext)`: the call pre-fetches the row to return.
        if (!hasNext) throw new NoSuchElementException("next on empty group iterator")
        val res = currentRow
        currentRow = null
        res
      }

      private def fetchNextRowInGroup(): Boolean = {
        assert(currentRow == null)
        // The inner iterator should NOT consume the input into the next group, so `head` is used
        // to peek at the next input row before deciding whether to take it.
        if (input.hasNext && currentGroup == input.head._1) {
          // Next input is in the current group. Continue the inner iterator.
          currentRow = input.next()
          true
        } else {
          // Either there is no more data or the next row belongs to another group.
          false
        }
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment