Skip to content

Instantly share code, notes, and snippets.

@elizarov
Last active October 23, 2023 20:19
Show Gist options
  • Save elizarov/5bbbe5a3b88985ae577d8ec3706e85ef to your computer and use it in GitHub Desktop.
Save elizarov/5bbbe5a3b88985ae577d8ec3706e85ef to your computer and use it in GitHub Desktop.
Delimited Continuations shift/reset in Kotlin
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
/**
 * Runs [body] as a delimited computation: the `reset` half of the `shift`/`reset`
 * delimited-continuation primitives, implemented on top of Kotlin coroutines.
 * See https://en.wikipedia.org/wiki/Delimited_continuation.
 *
 * The LISP expression
 *
 * ```
 * (* 2 (reset (+ 1 (shift k (k 5)))))
 * ```
 *
 * corresponds to this Kotlin code:
 *
 * ```
 * 2 * reset<Int> {
 *     1 + shift<Int> { k -> k(5) }
 * }
 * ```
 *
 * @param body the computation to delimit; it can call `shift` on its [DelimitedScope] receiver.
 * @return the final result produced by the scope's trampoline once [body] and every pending
 * `shift` block have finished; a failure recorded along the way is rethrown.
 */
fun <T> reset(body: suspend DelimitedScope<T>.() -> T): T {
    val scope = DelimitedScopeImpl<T>()
    // The same object serves as the receiver of the body and as the completion that records its outcome.
    body.startCoroutine(scope, scope)
    return scope.runReset()
}
/**
 * Handle for the continuation that `shift` passes to its block.
 *
 * The interface declares no members; a continuation is resumed through the `invoke`
 * operator that [DelimitedScope] declares as an extension on this type.
 *
 * @param T the type returned when the continuation is invoked (the result type of the enclosing scope).
 * @param R the type of the value handed to the continuation, i.e. what the `shift` call itself evaluates to.
 */
interface DelimitedContinuation<T, R>
/**
 * Receiver scope of a `reset` body and of every `shift` block.
 *
 * It is annotated with `@RestrictsSuspension`, so code running with this receiver is limited to
 * the suspending members declared here.
 *
 * @param T the result type of the enclosing `reset`.
 */
@RestrictsSuspension
abstract class DelimitedScope<T> {
/**
 * Suspends the caller and hands the rest of the delimited computation to [block] as a
 * [DelimitedContinuation].
 *
 * @param block receives the captured continuation and produces a value of type [T]; when it does
 * not invoke the continuation, the value it returns becomes the result of the whole `reset`.
 * @return the value that [block] passes to the continuation.
 */
abstract suspend fun <R> shift(block: suspend DelimitedScope<T>.(DelimitedContinuation<T, R>) -> T): R
/**
 * Invokes a captured continuation with [value], making it the result of the corresponding
 * `shift` call.
 *
 * @param value the value the suspended `shift` call evaluates to.
 * @return the result of type [T] delivered back to the caller once the resumed computation has finished.
 */
abstract suspend operator fun <R> DelimitedContinuation<T, R>.invoke(value: R): T
}
// Plain-function view of a `shift { ... }` block, used so the trampoline can call the block with an
// explicit completion continuation. Arguments: the receiver scope, the delimited continuation and
// the completion continuation. The call returns either the block's value or COROUTINE_SUSPENDED.
private typealias ShiftedFun<T> = (DelimitedScope<T>, DelimitedContinuation<T, Any?>, Continuation<T>) -> Any?
/**
 * Single-shot implementation of [DelimitedScope] driven by a trampoline loop.
 *
 * One instance plays three roles at once: it is the receiver scope of the `reset` body and of
 * every `shift` block, the completion [Continuation] that records the final outcome, and the
 * [DelimitedContinuation] object handed to each `shift` block.
 *
 * `shift` and `invoke` never run anything themselves: they store their arguments in the fields
 * below and suspend, and [runReset] picks the stored state up and decides what to run next.
 */
@Suppress("UNCHECKED_CAST")
private class DelimitedScopeImpl<T> : DelimitedScope<T>(), Continuation<T>, DelimitedContinuation<T, Any?> {
    // Pending `shift { ... }` block, stored by shift() and consumed by the trampoline
    private var shifted: ShiftedFun<T>? = null

    // Continuation of the code that called shift(); resumed with the value given to the delimited continuation
    private var shiftCont: Continuation<Any?>? = null

    // Continuation of the shift block that called `k(value)`; resumed later with the result
    private var invokeCont: Continuation<T>? = null

    // Value passed to the most recent `k(value)` call
    private var invokeValue: Any? = null

    // Most recently recorded outcome; null until the body or a shift block has produced one
    private var result: Result<T>? = null

    override val context: CoroutineContext
        get() = EmptyCoroutineContext

    // Completion callback: only records the outcome, the trampoline reads it afterwards
    override fun resumeWith(result: Result<T>) {
        this.result = result
    }

    // Stores the block and the caller's continuation, then always suspends so that runReset() takes over
    override suspend fun <R> shift(block: suspend DelimitedScope<T>.(DelimitedContinuation<T, R>) -> T): R =
        suspendCoroutineUninterceptedOrReturn {
            // The suspend block is viewed as a plain function that takes its completion explicitly
            this.shifted = block as ShiftedFun<T>
            this.shiftCont = it as Continuation<Any?>
            COROUTINE_SUSPENDED
        }

    // Stores the shift block's continuation and the passed value, then always suspends
    override suspend fun <R> DelimitedContinuation<T, R>.invoke(value: R): T =
        suspendCoroutineUninterceptedOrReturn {
            check(invokeCont == null) { "A delimited continuation invocation is already pending" }
            invokeCont = it
            invokeValue = value
            COROUTINE_SUSPENDED
        }

    /**
     * Drives the computation started by `reset` to completion and returns its result.
     *
     * @return the final value of the delimited computation.
     * @throws IllegalStateException if a delimited continuation is invoked more than once, or if
     * the computation stops without having produced a result.
     * @throws Throwable any failure recorded from the body or from a shift block is rethrown.
     */
    fun runReset(): T {
        // This is the stack of continuation in the `shift { ... }` after call to delimited continuation
        var currentCont: Continuation<T> = this
        // Trampoline loop to avoid call stack usage
        loop@ while (true) {
            // Call shift { ... } body or break if there are no more shift calls
            val shifted = takeShifted() ?: break
            // If shift does not call any continuation, then its value becomes the result -- break out of the loop
            try {
                val value = shifted.invoke(this, this, currentCont)
                if (value !== COROUTINE_SUSPENDED) {
                    result = Result.success(value as T)
                    break
                }
            } catch (e: Throwable) {
                result = Result.failure(e)
                break
            }
            // Shift has suspended - check if shift { ... } body had invoked continuation
            currentCont = takeInvokeCont() ?: continue@loop
            val shiftCont = takeShiftCont()
                ?: error("Delimited continuation is single-shot and cannot be invoked twice")
            shiftCont.resume(invokeValue)
        }
        // Propagate the result to all pending continuations in shift { ... } bodies
        val outcome = result ?: error("Delimited computation stopped without producing a result")
        currentCont.resumeWith(outcome)
        // A shift block that calls its continuation again while the result is being propagated would
        // stay suspended forever and its contribution would be silently dropped -- report it instead
        if (takeInvokeCont() != null) {
            error("Delimited continuation is single-shot and cannot be invoked twice")
        }
        // Return the final result; re-read it because resumed shift blocks may have replaced it
        val finalResult = result ?: error("Delimited computation stopped without producing a result")
        return finalResult.getOrThrow()
    }

    // Each take* helper returns the stored value (if any) and clears the field
    private fun takeShifted() = shifted?.also { shifted = null }
    private fun takeShiftCont() = shiftCont?.also { shiftCont = null }
    private fun takeInvokeCont() = invokeCont?.also { invokeCont = null }
}
import org.junit.*
import kotlin.test.*
/**
 * Tests for the `shift`/`reset` delimited-continuation primitives.
 */
class DelimitedTest {
    @Test
    fun testNoShift() {
        val x = reset<Int> { 42 }
        assertEquals(42, x)
    }

    @Test
    fun testShiftOnly() {
        val x = reset<Int> {
            shift<Int> { k -> k(42) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftRight() {
        val x = reset<Int> {
            40 + shift<Int> { k -> k(2) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftLeft() {
        val x = reset<Int> {
            shift<Int> { k -> k(40) } + 2
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftBoth() {
        val x = reset<Int> {
            shift<Int> { k -> k(40) } +
                shift<Int> { k -> k(2) }
        }
        assertEquals(42, x)
    }

    @Test
    fun testShiftToString() {
        val x = reset<String> {
            shift<Int> { k -> k(42) }.toString()
        }
        assertEquals("42", x)
    }

    @Test
    fun testShiftWithoutContinuationInvoke() {
        val x = reset<Int> {
            shift<String> {
                42 // does not call continuation, just override result
            }
            0 // this is not called
        }
        assertEquals(42, x)
    }

    // From: https://en.wikipedia.org/wiki/Delimited_continuation
    // (* 2 (reset (+ 1 (shift k (k 5)))))
    // k := (+ 1 [])
    @Test
    fun testWikiSample() {
        val x = 2 * reset<Int> {
            1 + shift<Int> { k -> k(5) }
        }
        assertEquals(12, x)
    }

    // It must be extension on DelimitedScope<Int> to be able to shift
    private suspend fun DelimitedScope<Int>.shiftFun(x: Int): Int =
        shift<Int> { k -> k(x) } * 2

    @Test
    fun testShiftFromFunction() {
        val x = reset<Int> {
            2 + shiftFun(20)
        }
        assertEquals(42, x)
    }

    // Ensure there's no stack overflow because of many "shift" calls
    @Test
    fun testManyShifts() {
        val res = reset<String> {
            for (x in 0..10000) {
                shift<Int> { k ->
                    k(x)
                }
            }
            "OK"
        }
        assertEquals("OK", res)
    }

    // See https://gist.github.com/elizarov/5bbbe5a3b88985ae577d8ec3706e85ef#gistcomment-3304535
    @Test
    fun testShiftRemainderCalled() {
        val log = mutableListOf<String>()
        val x = reset<Int> {
            val y = shift<Int> { k ->
                log += "before 1"
                val r = k(1)
                log += "after 1"
                r
            }
            log += y.toString()
            val z = shift<Int> { k ->
                log += "before 2"
                val r = k(2)
                log += "after 2"
                r
            }
            log += z.toString()
            y + z
        }
        assertEquals(3, x)
        // The remainder of each shift block runs after the body completes, innermost first
        assertEquals(
            listOf(
                "before 1",
                "1",
                "before 2",
                "2",
                "after 2",
                "after 1"
            ),
            log
        )
    }
}
@flash-gordon
Copy link

@elizarov I figured testNoShit is named incorrectly

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment