Last active
October 15, 2017 18:07
-
-
Save eshioji/8c89da5b809274be363c21bb39fa22ec to your computer and use it in GitHub Desktop.
GroupedIterator (very useful to use with Spark's repartitionAndSortWithinPartitions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Adapted from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
 */
object GroupedIterator {

  /**
   * Groups a key/value iterator in which equal keys are contiguous (for example, sorted by key).
   *
   * This is the only way to construct a [[GroupedIterator]]. It also handles empty input, which
   * the class itself does not.
   *
   * @param input key/value pairs in which all pairs sharing a key are adjacent
   * @tparam K key type; keys are compared with `==`
   * @tparam V value type
   * @return an iterator of `(key, pairsWithThatKey)`, or an empty iterator for empty input
   */
  def apply[K, V](
      input: Iterator[(K, V)]): Iterator[(K, Iterator[(K, V)])] = {
    if (input.hasNext) {
      new GroupedIterator(input.buffered)
    } else {
      Iterator.empty
    }
  }
}

/**
 * Iterates over a key/value iterator in which equal keys are contiguous, chunking it up by key.
 * Each call to `next()` returns a pair of the current key and an iterator over all the input
 * pairs that carry that key. Group iterators are constructed lazily and pull pairs directly from
 * the input, so full groups are never materialized by this class.
 *
 * Example:
 * {{{
 *   Input:                  ("a", 1), ("b", 2), ("b", 3)
 *   First call to next():   ("a", Iterator(("a", 1)))
 *   Second call to next():  ("b", Iterator(("b", 2), ("b", 3)))
 * }}}
 *
 * All group iterators read from the same underlying input. Once `hasNext` or `next()` is
 * called on this iterator, any rows of the previously returned group that were not yet
 * consumed are skipped, and that group iterator becomes empty. Consume each group before
 * advancing.
 *
 * The class does not handle empty input, because its constructor reads the first row eagerly.
 * Construct instances through the companion object's `apply`.
 *
 * @param input pairs in which all pairs sharing a key are adjacent (e.g. sorted by key);
 *              otherwise the same key may appear in more than one group
 */
class GroupedIterator[K, V] private (input: BufferedIterator[(K, V)])
    extends Iterator[(K, Iterator[(K, V)])] {

  /** Extracts the key of a row. Not used internally; kept for source compatibility. */
  val keyProjection = (k: K, v: V) => k

  /**
   * The row the current group iterator will return next, or null when the next row still has
   * to be pulled from `input`.
   */
  var currentRow: (K, V) = input.next()

  /** Key of the group currently being iterated. */
  var currentGroup: K = currentRow._1

  /**
   * Incremented every time this iterator moves past a group. Each group iterator records the
   * value at its creation and reports itself exhausted once the value changes. A stale group
   * iterator can therefore never yield rows that belong to a later group.
   * Declared before `currentIterator` so it is initialized when the first group iterator is
   * created.
   */
  private var generation: Long = 0L

  /** Iterator for `currentGroup` that has not yet been handed out by `next()`, or null. */
  var currentIterator: Iterator[(K, V)] = createGroupValuesIterator()

  /**
   * Returns true if a group iterator is ready or the next group could be fetched.
   *
   * Fetching the next group skips any rows of the previous group that were not consumed and
   * empties the previously returned group iterator. Consume that iterator before calling
   * `hasNext`.
   */
  def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator()

  /**
   * Returns the next `(key, groupIterator)` pair.
   *
   * @throws NoSuchElementException if there are no more groups
   */
  def next(): (K, Iterator[(K, V)]) = {
    // Not `assert(hasNext)`: `hasNext` is what fetches the next group, and asserts may be elided.
    if (!hasNext) throw new NoSuchElementException("next on exhausted GroupedIterator")
    val ret = (currentGroup, currentIterator)
    currentIterator = null
    ret
  }

  /**
   * Advances `input` to the first row of the next group and prepares that group's iterator.
   *
   * @return true if a new group was found and `currentIterator` was set; false if input is done
   */
  private def fetchNextGroupIterator(): Boolean = {
    assert(currentIterator == null)
    // Whatever happens below, the iterator previously handed out for `currentGroup` is stale.
    generation += 1
    if (currentRow == null && input.hasNext) {
      currentRow = input.next()
    }
    if (currentRow == null) {
      // There is no data left.
      false
    } else {
      // Skip the unconsumed rest of the current group. `currentRow` may be a leftover row of the
      // current group or already the first row of the next one, so test it before reading more.
      while (currentGroup == currentRow._1 && input.hasNext) {
        currentRow = input.next()
      }
      if (currentGroup == currentRow._1) {
        // The current group was the last one. Drop its final row so nothing can return it.
        currentRow = null
        false
      } else {
        // `currentRow` is now the first row of the next group.
        currentGroup = currentRow._1
        currentIterator = createGroupValuesIterator()
        true
      }
    }
  }

  /** Creates the iterator over the rows of `currentGroup`, starting with `currentRow`. */
  private def createGroupValuesIterator(): Iterator[(K, V)] = {
    new Iterator[(K, V)] {
      // Generation this iterator belongs to. Once the outer iterator advances, it is exhausted.
      private val owner = generation

      def hasNext: Boolean =
        owner == generation && (currentRow != null || fetchNextRowInGroup())

      def next(): (K, V) = {
        // Not `assert(hasNext)`: `hasNext` may pull the next row, and asserts may be elided.
        if (!hasNext) throw new NoSuchElementException("next on exhausted group iterator")
        val res = currentRow
        currentRow = null
        res
      }

      /** Moves the next input row into `currentRow` if it still belongs to `currentGroup`. */
      private def fetchNextRowInGroup(): Boolean = {
        assert(currentRow == null)
        // Peek with `head` so that a row of the next group stays in `input` for the outer
        // iterator instead of being consumed here.
        if (input.hasNext && currentGroup == input.head._1) {
          currentRow = input.next()
          true
        } else {
          false
        }
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment