Skip to content

Instantly share code, notes, and snippets.

@oakkitten
Last active November 13, 2020 10:14
Show Gist options
  • Save oakkitten/820e4d7f502981abbfd48d1b4da323a7 to your computer and use it in GitHub Desktop.
Save oakkitten/820e4d7f502981abbfd48d1b4da323a7 to your computer and use it in GitHub Desktop.
Asynchronous iterator in Kotlin

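Like `sequence {}`, but asynchronous: the builder block may call suspending functions. Like `Flow`, but with inverted control: the consumer pulls each value (e.g. with a `for` loop) instead of the producer pushing values into a collector.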

/**
 * Usage example: prints 1 through 5, with the producer pausing 100 ms before each value.
 */
fun foo() = runBlocking {
    val numbers = asyncSequence {
        repeat(5) { index ->
            delay(100)
            yield(index + 1)
        }
    }

    for (n in numbers) println(n)
}
package asynciterator
import kotlinx.coroutines.*
import kotlin.coroutines.intrinsics.createCoroutineUnintercepted
import kotlin.coroutines.*
import kotlin.experimental.ExperimentalTypeInference
import kotlin.system.getTimeMillis
/**
 * Receiver of an `asyncSequence` block: a [CoroutineScope] that can also hand values to the consumer.
 *
 * Note: inside the block, this member `yield(value)` is what unqualified `yield(...)` calls resolve to.
 */
interface AsyncCoroutineScope<T> : CoroutineScope {
    /** Hands [value] to the consumer; suspends the producer until it is resumed again. */
    suspend fun yield(value: T)
}
/**
 * Iterator whose `hasNext`/`next` are suspending, so it can drive a `for` loop inside a coroutine.
 *
 * [iterator] returns the receiver itself, so the object can appear directly after `in` in a `for` loop.
 */
interface AsyncIterator<T> {
    operator fun iterator(): AsyncIterator<T> = this

    /** Returns `true` if another element is available; may suspend while it is being produced. */
    suspend operator fun hasNext(): Boolean

    /** Returns the next element; may suspend while it is being produced. */
    suspend operator fun next(): T
}
// State of the hand-off between the consumer and the producer block.
private sealed class Status<T>
// No value is buffered; the producer must be resumed to get one.
private class NotReady<T> : Status<T>()
// The producer yielded `value`, which the consumer has not taken with next() yet.
private class Ready<T>(val value: T) : Status<T>()
// The producer block ran to completion; no more values.
private class Done<T> : Status<T>()
// The producer block threw `exception`.
private class Failed<T>(val exception: Throwable) : Status<T>()
/**
 * Pull-based bridge between a consumer (code calling [hasNext]/[next]) and a producer coroutine
 * (the `asyncSequence` block, which calls [yield]).
 *
 * Control is handed back and forth explicitly:
 * - the consumer resumes the producer through [sequenceContinuation] and suspends in [mainContinuation];
 * - the producer hands a value back by resuming [mainContinuation] from [yield].
 *
 * This object is also the completion continuation of the producer coroutine (see `asyncSequence`),
 * so [resumeWith] runs when the block returns or throws.
 *
 * NOTE(review): no synchronization is performed on the mutable state below; presumably a single
 * consumer drives an instance at a time.
 */
class AsyncIteratorImpl<T>(
    override val coroutineContext: CoroutineContext
) : AsyncIterator<T>, AsyncCoroutineScope<T>, Continuation<Unit> {
    /** Consumer continuation; valid while the consumer waits in [ensureNext]. */
    private lateinit var mainContinuation: Continuation<Status<T>>

    /** Resumes the producer; set by `asyncSequence` initially and by every [yield] afterwards. */
    internal lateinit var sequenceContinuation: Continuation<Unit>

    private var status: Status<T> = NotReady()

    /** True while the consumer is suspended in [ensureNext] waiting for the producer. */
    private var consumerWaiting = false

    /**
     * Outcome of the producer when it finished while nobody was waiting for it. This happens, for
     * example, when it is resumed with a CancellationException while suspended in [yield].
     * It is handed to the consumer on its next [ensureNext] instead of resuming a finished producer.
     */
    private var finalStatus: Status<T>? = null

    override operator fun iterator() = this

    /**
     * Makes sure [status] is not [NotReady], running the producer until it yields or finishes if needed.
     *
     * @throws Throwable the producer's exception, the first time a failure is observed.
     */
    private suspend fun ensureNext() {
        if (status !is NotReady) return
        val next = finalStatus ?: suspendCoroutine<Status<T>> { consumer ->
            mainContinuation = consumer
            consumerWaiting = true
            sequenceContinuation.resume(Unit)
        }
        status = next
        if (next is Failed) throw next.exception
    }

    /** Throws the exception that describes why no value is available in the current [status]. */
    private fun throwBadStatus(): Nothing = when (val current = status) {
        is Done -> throw NoSuchElementException("Asynchronous iterator is exhausted")
        is Failed -> throw IllegalStateException("Asynchronous sequence block previously threw an exception", current.exception)
        is NotReady, is Ready -> throw IllegalStateException("This ain't possible")
    }

    /**
     * @return `true` if a value is buffered, `false` once the block has completed.
     * @throws IllegalStateException if the block failed earlier (its exception is the cause).
     */
    override suspend operator fun hasNext(): Boolean {
        ensureNext()
        return when (status) {
            is Ready -> true
            is Done -> false
            else -> throwBadStatus()
        }
    }

    /**
     * Returns the buffered value and marks the buffer empty.
     *
     * @throws NoSuchElementException if the block has completed.
     * @throws IllegalStateException if the block failed earlier (its exception is the cause).
     */
    override suspend operator fun next(): T {
        ensureNext()
        val current = status
        if (current is Ready) {
            status = NotReady()
            return current.value
        }
        throwBadStatus()
    }

    /** Producer side: hands [value] to the waiting consumer and parks until the consumer asks again. */
    override suspend fun yield(value: T) = suspendCancellableCoroutine<Unit> {
        sequenceContinuation = it
        consumerWaiting = false
        mainContinuation.resume(Ready(value))
    }

    // Completion-continuation part: resumeWith runs when the producer block returns or throws,
    // including CancellationException.
    override val context = coroutineContext

    override fun resumeWith(result: Result<Unit>) {
        val exception = result.exceptionOrNull()
        val outcome: Status<T> = if (exception != null) Failed(exception) else Done()
        if (consumerWaiting) {
            // The consumer resumed us and is parked in ensureNext(). It must always be woken, even on
            // cancellation; otherwise its non-cancellable suspendCoroutine would hang forever.
            consumerWaiting = false
            mainContinuation.resume(outcome)
        } else {
            // Finished without the consumer asking (e.g. cancelled while parked in yield()).
            // mainContinuation was already resumed, so just remember the outcome.
            finalStatus = outcome
        }
    }
}
/**
 * Creates an asynchronous sequence whose values are produced by [block].
 *
 * The block is created but not started. It first runs when the consumer calls `hasNext()`/`next()`
 * on the returned iterator, then runs until each `yield` and waits for the consumer to ask again.
 * The block runs with this scope's [coroutineContext].
 */
@OptIn(ExperimentalTypeInference::class)
fun <T> CoroutineScope.asyncSequence(@BuilderInference block: suspend AsyncCoroutineScope<T>.() -> Unit): AsyncIterator<T> {
    val iterator = AsyncIteratorImpl<T>(coroutineContext)
    // The same object is both the receiver of the block and its completion continuation.
    // Unintercepted: the consumer resumes this continuation directly.
    iterator.sequenceContinuation = block.createCoroutineUnintercepted(iterator, iterator)
    return iterator
}
@file:Suppress("ControlFlowWithEmptyBody")
package asynciterator
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlin.math.absoluteValue
import kotlin.system.measureTimeMillis
import kotlin.test.Test
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
/** Test fixture: yields 1 through 5, delaying [delayDuration] ms before each value. */
fun CoroutineScope.makeSequence(delayDuration: Long = 0) = asyncSequence {
    (1..5).forEach { x ->
        delay(delayDuration)
        yield(x)
    }
}
/** Test fixture: forwards every value of [makeSequence], then throws ArithmeticException. */
@Suppress("DIVISION_BY_ZERO")
fun CoroutineScope.makeFailingSequence() = asyncSequence {
    val inner = makeSequence()
    while (inner.hasNext()) yield(inner.next())
    // Intentional integer division by zero: fails only after all inner values have been handed out.
    1 / 0
}
/** Consuming the fixture sequence sees exactly 1..5. */
@Test fun sum() = runBlocking {
    val seq = makeSequence()
    var total = 0
    while (seq.hasNext()) total += seq.next()
    assertTrue { total == 15 }
}
/**
 * Checks that two sequences consumed in parallel in one `runBlocking` take about as long as one.
 * The two runs' wall-clock times must be within 10%, i.e. the producers' delays overlap.
 */
@Test fun concurrency() {
    val single = measureTimeMillis {
        runBlocking {
            launch { for (x in makeSequence(100)) {} }
        }
    }
    val double = measureTimeMillis {
        runBlocking {
            launch { for (x in makeSequence(100)) {} }
            launch { for (x in makeSequence(100)) {} }
        }
    }
    // Divide as Double: Long / Long truncates toward zero, which made the old check pass
    // for practically any timing.
    val relativeDifference = (double - single).absoluteValue.toDouble() / double
    assertTrue(
        relativeDifference < 0.1,
        "single=${single}ms, double=${double}ms, relative difference=$relativeDifference"
    )
}
/** An exception thrown by the block reaches the consumer unchanged. */
@Test fun throws() = runBlocking<Unit> {
    val failing = makeFailingSequence()
    assertFailsWith(ArithmeticException::class) {
        while (failing.hasNext()) failing.next()
    }
}
/** After the block has failed once, further iteration reports IllegalStateException caused by the original. */
@Test fun `throws if next() is called again`() = runBlocking<Unit> {
    val failing = makeFailingSequence()

    // First pass: the block's own exception surfaces.
    assertFailsWith(ArithmeticException::class) {
        while (failing.hasNext()) failing.next()
    }

    // Second pass: the iterator reports the earlier failure, keeping it as the cause.
    val repeated = assertFailsWith(IllegalStateException::class) {
        while (failing.hasNext()) failing.next()
    }
    assertTrue { repeated.cause is ArithmeticException }
}
/** Stopping at the last value must not let the block run on into its division by zero. */
@Test fun `doesn't automatically proceed after yield()`() = runBlocking<Unit> {
    val seq = makeFailingSequence()
    while (seq.hasNext()) {
        if (seq.next() == 5) break
    }
}
/** Abandoning iteration early still runs the block's `finally` once the enclosing `runBlocking` ends. */
@Test fun `runs finalizer`() {
    var finalizerRun = false
    runBlocking {
        val seq = asyncSequence {
            try {
                (1..5).forEach { yield(it) }
            } finally {
                finalizerRun = true
            }
        }
        while (seq.hasNext()) {
            if (seq.next() == 3) break
        }
    }
    assertTrue { finalizerRun }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment