Last active
October 23, 2023 20:19
-
-
Save elizarov/5bbbe5a3b88985ae577d8ec3706e85ef to your computer and use it in GitHub Desktop.
Delimited Continuations shift/reset in Kotlin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kotlin.coroutines.* | |
import kotlin.coroutines.intrinsics.* | |
/**
 * Delimited Continuations `shift`/`reset` primitives implemented on top of Kotlin Coroutines.
 * See [https://en.wikipedia.org/wiki/Delimited_continuation].
 *
 * `reset` marks the delimiter: every [DelimitedScope.shift] called inside [body] captures the rest of
 * the computation up to (and including) this `reset` call.
 *
 * The following LISP code:
 *
 * ```
 * (* 2 (reset (+ 1 (shift k (k 5)))))
 * ```
 *
 * translates to:
 *
 * ```
 * 2 * reset<Int> {
 *     1 + shift<Int> { k -> k(5) }
 * }
 * ```
 *
 * The whole computation is driven synchronously on the calling thread before this function returns.
 *
 * @param body the delimited computation; it runs with a [DelimitedScope] receiver so it can call `shift`.
 * @return the value the delimited computation produced; an exception thrown by it is rethrown here.
 */
fun <T> reset(body: suspend DelimitedScope<T>.() -> T): T {
    val scope = DelimitedScopeImpl<T>()
    // The same object is both the receiver of `body` and the continuation that records its outcome.
    body.startCoroutine(receiver = scope, completion = scope)
    return scope.runReset()
}
/**
 * The continuation captured by [DelimitedScope.shift]: the rest of the computation from the `shift`
 * call up to the enclosing `reset`. [T] is the result type of that `reset`; [R] is the type of the value
 * the continuation is resumed with (it becomes the result of the `shift` call).
 *
 * It declares no members of its own; it is called as `k(value)` through [DelimitedScope.invoke].
 */
interface DelimitedContinuation<T, R>

/**
 * Receiver of the `reset { ... }` body and of every `shift { ... }` body of a `reset` whose result type is [T].
 *
 * Marked [RestrictsSuspension]: suspending extensions on this scope may only suspend by calling members or
 * extensions of this same scope, so all suspension ultimately goes through [shift] and [invoke].
 */
@RestrictsSuspension
abstract class DelimitedScope<T> {
    /**
     * Resumes the captured continuation with [value], which becomes the result of the matching [shift] call,
     * and returns what the rest of the computation produced up to the enclosing `reset`.
     */
    abstract suspend operator fun <R> DelimitedContinuation<T, R>.invoke(value: R): T

    /**
     * Captures the rest of the computation up to the enclosing `reset` as `k` and runs [block] with it.
     * The value returned by [block] becomes the result of the enclosing `reset` (or of the `k(...)` call
     * that is currently running the rest of the computation); if [block] never invokes `k`, the rest of
     * the computation is skipped.
     */
    abstract suspend fun <R> shift(
        block: suspend DelimitedScope<T>.(DelimitedContinuation<T, R>) -> T
    ): R
}
/**
 * Erased shape under which a pending `shift { ... }` block is stored and later started by the trampoline:
 * the scope receiver, the continuation `k` handed to the block, and the completion that receives the
 * block's result. A call returns either the block's value or [COROUTINE_SUSPENDED].
 *
 * NOTE(review): relies on an unchecked cast of a suspend lambda to a plain three-argument function,
 * i.e. on the compiled representation of suspend lambdas.
 */
private typealias ShiftedFun<T> =
    (scope: DelimitedScope<T>, k: DelimitedContinuation<T, Any?>, completion: Continuation<T>) -> Any?
/**
 * State behind a single [reset] call.
 *
 * One instance plays three roles at once:
 * - the [DelimitedScope] receiver of the `reset { ... }` body and of every `shift { ... }` body;
 * - the completion [Continuation] of the `reset { ... }` body, recording its outcome in [result];
 * - the [DelimitedContinuation] (`k`) handed to every `shift { ... }` body.
 *
 * [shift] and [invoke] never resume anything themselves: they only record what has to happen next and
 * suspend, so [runReset] can perform every resumption from a single trampoline loop.
 * Continuations are single-shot: each `k` may be invoked at most once.
 */
@Suppress("UNCHECKED_CAST")
private class DelimitedScopeImpl<T> : DelimitedScope<T>(), Continuation<T>, DelimitedContinuation<T, Any?> {
    // Block of the latest `shift { ... }` call that has not been started yet.
    private var shifted: ShiftedFun<T>? = null
    // Continuation of the code that called `shift`; this is what `k(value)` resumes.
    private var shiftCont: Continuation<Any?>? = null
    // Continuation of a `shift { ... }` body just after its `k(value)` call.
    private var invokeCont: Continuation<T>? = null
    // Argument of the pending `k(value)` call; becomes the result of the matching `shift`.
    private var invokeValue: Any? = null
    // Outcome of the delimited computation; null until something completes.
    private var result: Result<T>? = null

    override val context: CoroutineContext
        get() = EmptyCoroutineContext

    // Called when the `reset { ... }` body completes, and when a `shift { ... }` body started
    // with this object as its completion completes.
    override fun resumeWith(result: Result<T>) {
        this.result = result
    }

    override suspend fun <R> shift(block: suspend DelimitedScope<T>.(DelimitedContinuation<T, R>) -> T): R =
        suspendCoroutineUninterceptedOrReturn { cont ->
            shifted = block as ShiftedFun<T>
            shiftCont = cont as Continuation<Any?>
            COROUTINE_SUSPENDED
        }

    override suspend fun <R> DelimitedContinuation<T, R>.invoke(value: R): T =
        suspendCoroutineUninterceptedOrReturn { cont ->
            check(invokeCont == null) { "Delimited continuation invocation is already pending" }
            invokeCont = cont
            invokeValue = value
            COROUTINE_SUSPENDED
        }

    /**
     * Drives the computation started by [reset] to completion and returns its result.
     *
     * Each iteration starts the pending `shift { ... }` block. If the block returns without invoking `k`,
     * its value becomes the result. If it invokes `k(value)`, the code after the `shift` call is resumed
     * with `value` until it calls `shift` again or completes. When nothing is pending any more, the outcome
     * is delivered to the suspended `shift { ... }` bodies, innermost first.
     *
     * @throws IllegalStateException if a delimited continuation is invoked more than once, if a
     *   `shift { ... }` body calls `shift` after invoking its continuation, or if the computation is left
     *   suspended without a result (e.g. `shift` called inside a `shift { ... }` body before it invokes `k`);
     *   none of these are supported by this trampoline.
     */
    fun runReset(): T {
        // Innermost `shift { ... }` body suspended in `k(...)`; `this` while no such body exists.
        var currentCont: Continuation<T> = this
        // Trampoline loop: all resumptions happen here, so many shifts do not grow the call stack.
        loop@ while (true) {
            val block = takeShifted() ?: break
            try {
                val value = block.invoke(this, this, currentCont)
                if (value !== COROUTINE_SUSPENDED) {
                    // The block returned without invoking `k`: its value replaces the rest of the computation.
                    result = Result.success(value as T)
                    break
                }
            } catch (e: Throwable) {
                result = Result.failure(e)
                break
            }
            // The block suspended without invoking `k` (e.g. on another `shift`): start the next pending block.
            currentCont = takeInvokeCont() ?: continue@loop
            val shiftCont = takeShiftCont() ?: error(SINGLE_SHOT_MESSAGE)
            shiftCont.resume(takeInvokeValue())
        }
        val outcome = checkNotNull(result) { "Delimited computation suspended without producing a result" }
        // Deliver the outcome to the pending `shift { ... }` bodies; their completion chain ends in resumeWith.
        currentCont.resumeWith(outcome)
        // A `k(...)` or `shift` issued while unwinding can no longer be served by the loop above:
        // fail loudly instead of silently returning a stale result.
        check(invokeCont == null) { SINGLE_SHOT_MESSAGE }
        check(shifted == null) {
            "shift cannot be called from a shift { ... } body after it has invoked its continuation"
        }
        return checkNotNull(result).getOrThrow()
    }

    private fun takeShifted() = shifted?.also { shifted = null }
    private fun takeShiftCont() = shiftCont?.also { shiftCont = null }
    private fun takeInvokeCont() = invokeCont?.also { invokeCont = null }
    private fun takeInvokeValue(): Any? = invokeValue.also { invokeValue = null }

    private companion object {
        const val SINGLE_SHOT_MESSAGE = "Delimited continuation is single-shot and cannot be invoked twice"
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.* | |
import kotlin.test.* | |
/** Behavioral tests for [reset] and [DelimitedScope.shift]. */
class DelimitedTest {
    @Test
    fun testNoShift() {
        val x = reset<Int> { 42 }
        assertEquals(42, x)
    }

    @Test
    fun testShiftOnly() {
        val x = reset<Int> {
            shift<Int> { k -> k(42) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftRight() {
        val x = reset<Int> {
            40 + shift<Int> { k -> k(2) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftLeft() {
        val x = reset<Int> {
            shift<Int> { k -> k(40) } + 2
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftBoth() {
        val x = reset<Int> {
            shift<Int> { k -> k(40) } +
                shift<Int> { k -> k(2) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftToString() {
        val x = reset<String> {
            shift<Int> { k -> k(42) }.toString()
        }
        assertEquals("42", x)
    }

    @Test
    fun testShiftWithoutContinuationInvoke() {
        val x = reset<Int> {
            shift<String> {
                42 // does not call continuation, just override result
            }
            0 // this is not called
        }
        assertEquals(42, x)
    }

    // From: https://en.wikipedia.org/wiki/Delimited_continuation
    // (* 2 (reset (+ 1 (shift k (k 5)))))
    // k := (+ 1 [])
    @Test
    fun testWikiSample() {
        val x = 2 * reset<Int> {
            1 + shift<Int> { k -> k(5) }
        }
        assertEquals(12, x)
    }

    // It must be extension on DelimitedScope<Int> to be able to shift
    private suspend fun DelimitedScope<Int>.shiftFun(x: Int): Int =
        shift<Int> { k -> k(x) } * 2

    @Test
    fun testShiftFromFunction() {
        val x = reset<Int> {
            2 + shiftFun(20)
        }
        assertEquals(42, x)
    }

    // Ensure there's no stack overflow because of many "shift" calls
    @Test
    fun testManyShifts() {
        val res = reset<String> {
            for (x in 0..10000) {
                shift<Int> { k ->
                    k(x)
                }
            }
            "OK"
        }
        assertEquals("OK", res)
    }

    // See https://gist.github.com/elizarov/5bbbe5a3b88985ae577d8ec3706e85ef#gistcomment-3304535
    @Test
    fun testShiftRemainderCalled() {
        val log = mutableListOf<String>()
        val x = reset<Int> {
            val y = shift<Int> { k ->
                log += "before 1"
                val r = k(1)
                log += "after 1"
                r
            }
            log += y.toString()
            val z = shift<Int> { k ->
                log += "before 2"
                val r = k(2)
                log += "after 2"
                r
            }
            log += z.toString()
            y + z
        }
        assertEquals(3, x)
        assertEquals(
            listOf(
                "before 1",
                "1",
                "before 2",
                "2",
                "after 2",
                "after 1"
            ),
            log
        )
    }

    // An exception thrown directly by a shift body is rethrown from reset
    @Test
    fun testShiftBodyException() {
        assertFailsWith<IllegalStateException> {
            reset<Int> {
                shift<Int> { error("boom") }
            }
        }
    }

    // An exception thrown by the rest of the computation surfaces at the k(...) call inside the shift body
    @Test
    fun testShiftBodyCatchesException() {
        val x = reset<String> {
            shift<Int> { k ->
                try {
                    k(1)
                } catch (e: IllegalStateException) {
                    "caught ${e.message}"
                }
            }
            error("boom")
        }
        assertEquals("caught boom", x)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@b-studios @elizarov thanks for the conversation and the proposed suggestions. Jannis, Simon, and I have been working on multi-prompt and similar techniques, and we now have a version of reset / shift that will allow us to build effect handlers and remove the state-label reflection tricks in the Arrow continuations. In case you are interested, this will keep evolving, but this is the first attempt: https://github.com/arrow-kt/arrow-core/pull/226/files. If you have any comments on it, feel free to leave them directly in the PR, even though it is closed, and we will address them there or as issues. Thanks again for all your help.