Skip to content

Instantly share code, notes, and snippets.

@marc0der
Last active November 15, 2020 21:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marc0der/d1b89b6077639fdd6ffd6df7b9b3855a to your computer and use it in GitHub Desktop.
Save marc0der/d1b89b6077639fdd6ffd6df7b9b3855a to your computer and use it in GitHub Desktop.
ST Monad
/**
 * Marker ("witness") type that stands in for the [ST] type constructor in Arrow's
 * `Kind` encoding (see `STOf` / `STPartialOf`).
 *
 * It has a private constructor, so it is never instantiated. The empty companion object
 * presumably exists as an anchor for extension functions on `ForST`.
 */
class ForST private constructor() {
    companion object
}
/** Higher-kinded view of `ST<S, A>`: Arrow's `Kind2` applied to the [ForST] witness. */
typealias STOf<S, A> = arrow.Kind2<ForST, S, A>
/** `ST` partially applied to its state-thread type `S`; used as the `F` of type-class instances such as `Monad<STPartialOf<S>>`. */
typealias STPartialOf<S> = arrow.Kind<ForST, S>
/**
 * Downcasts the higher-kinded representation back to the concrete [ST] type.
 *
 * The cast is unchecked with respect to `S` and `A` (generics are erased). It assumes every
 * `STOf` value in play was actually created as an [ST].
 */
@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <S, A> STOf<S, A>.fix(): ST<S, A> = this as ST<S, A>
/**
 * A state-thread action. Given a state token of type [S], it produces a value of type [A]
 * and passes the token along.
 *
 * The constructor is `internal`. Actions are built with the companion `invoke`, with
 * [map] / [flatMap], or by subclasses within this module.
 */
abstract class ST<S, A> internal constructor() : STOf<S, A> {

    /** Executes this action with token [s], returning its result paired with the token. */
    protected abstract fun run(s: S): Pair<A, S>

    /** Applies [f] to this action's result; the state token is threaded through unchanged. */
    fun <B> map(f: (A) -> B): ST<S, B> = object : ST<S, B>() {
        override fun run(s: S): Pair<B, S> =
            this@ST.run(s).let { (value, next) -> Pair(f(value), next) }
    }

    /** Runs this action, then runs the action that [f] builds from its result. */
    fun <B> flatMap(f: (A) -> ST<S, B>): ST<S, B> = object : ST<S, B>() {
        override fun run(s: S): Pair<B, S> =
            this@ST.run(s).let { (value, next) -> f(value).run(next) }
    }

    companion object {
        /**
         * Wraps the deferred computation [a] as an action.
         *
         * The value is memoized through `lazy`: [a] is evaluated at most once, on the first
         * run. Every later run of the same action returns that cached value.
         */
        operator fun <S, A> invoke(a: () -> A): ST<S, A> {
            val cached: Lazy<A> = lazy(a)
            return object : ST<S, A>() {
                override fun run(s: S): Pair<A, S> = Pair(cached.value, s)
            }
        }

        /** Instantiates [st] with `Unit` as the state-thread type, runs it, and returns its result. */
        fun <A> runST(st: RunnableST<A>): A {
            val action: ST<Unit, A> = st.invoke()
            return action.run(Unit).first
        }
    }
}
/**
 * An [ST] program that is generic in its state-thread type `S`.
 *
 * Implementations must produce an action for any `S` the caller chooses. [ST.runST]
 * picks `Unit` and runs it, returning only the result value of type [A].
 */
interface RunnableST<A> {
fun <S> invoke(): ST<S, A>
}
/**
 * A mutable reference cell whose reads and writes are sequenced as [ST] actions,
 * tagged with the state-thread type [S].
 *
 * The constructor is private: references are only created through the companion `invoke`.
 */
abstract class STRef<S, A> private constructor() {

    /** Current contents of the reference. */
    protected abstract var cell: A

    /**
     * Action yielding the cell's current contents.
     *
     * The cell is read each time the action runs, not once and cached. Re-running the same
     * action after a [write] therefore observes the new value. The memoizing `ST { ... }`
     * factory would return the first value seen forever.
     */
    fun read(): ST<S, A> = object : ST<S, A>() {
        override fun run(s: S): Pair<A, S> = Pair(cell, s)
    }

    /** Action that overwrites the cell with [a] when run. */
    fun write(a: A): ST<S, Unit> = object : ST<S, Unit>() {
        override fun run(s: S): Pair<Unit, S> {
            cell = a
            return Pair(Unit, s)
        }
    }

    companion object {
        /**
         * Action that allocates a new reference initialised to [a].
         *
         * Every run allocates a fresh reference. Running the same allocation action twice
         * yields two independent cells rather than one shared, memoized cell.
         */
        operator fun <S, A> invoke(a: A): ST<S, STRef<S, A>> = object : ST<S, STRef<S, A>>() {
            override fun run(s: S): Pair<STRef<S, A>, S> = Pair(newRef(a), s)
        }

        /** Creates the concrete reference object holding [initial]. */
        private fun <S, A> newRef(initial: A): STRef<S, A> = object : STRef<S, A>() {
            override var cell: A = initial
        }
    }
}
/**
 * [Monad] instance for [ST], fixed to the state-thread type [S].
 *
 * NOTE(review): the interface-level type parameter `A` is unused. Every member declares its
 * own `A`, which shadows it. It is kept so that the `@extension`-generated accessor (used
 * as `ST.monad<S, A>()`) keeps its current signature.
 */
@extension
interface STMonad<S, A> : Monad<STPartialOf<S>> {

    /** Lifts the already computed value [a] into [ST]; the state token is untouched. */
    override fun <A> just(a: A): STOf<S, A> = ST { a }

    /** Sequences this action with [f] by delegating to [ST.flatMap]. */
    override fun <A, B> STOf<S, A>.flatMap(
        f: (A) -> STOf<S, B>
    ): STOf<S, B> =
        this.fix().flatMap { a -> f(a).fix() }

    /**
     * Runs [f] starting from [a]. Each `Left(next)` feeds `next` into another step, until
     * a `Right(b)` is produced, whose `b` becomes the result.
     *
     * Implemented by recursion through [ST.flatMap]. Every iteration nests another `run`
     * call when the resulting action executes, so very long loops can overflow the stack.
     */
    override fun <A, B> tailRecM(
        a: A,
        f: (A) -> STOf<S, Either<A, B>>
    ): STOf<S, B> =
        f(a).fix().flatMap { step: Either<A, B> ->
            step.fold(
                { next -> tailRecM(next, f).fix() },
                { b -> just(b).fix() }
            )
        }
}
// TODO: would read more linearly as a monad comprehension (`fx`) than as explicit flatMap chains.
/**
 * Allocates two references holding 10 and 20. It then stores each one's original value
 * plus one into the other, and finally reads both back. Running it yields `(21, 11)`.
 */
val prog = object : RunnableST<Pair<Int, Int>> {
    override fun <S> invoke(): ST<S, Pair<Int, Int>> =
        STRef<S, Int>(10).flatMap { left ->
            STRef<S, Int>(20).flatMap { right ->
                swapIncremented(left, right).flatMap { readBoth(left, right) }
            }
        }

    /** Reads both cells, then writes `right + 1` into [left] and `left + 1` into [right]. */
    private fun <S> swapIncremented(left: STRef<S, Int>, right: STRef<S, Int>): ST<S, Unit> =
        left.read().flatMap { x ->
            right.read().flatMap { y ->
                left.write(y + 1).flatMap { right.write(x + 1) }
            }
        }

    /** Reads [left] then [right] and pairs the values in that order. */
    private fun <S> readBoth(left: STRef<S, Int>, right: STRef<S, Int>): ST<S, Pair<Int, Int>> =
        left.read().flatMap { a ->
            right.read().map { b -> Pair(a, b) }
        }
}
/** Runs [prog] through [ST.runST] and prints the resulting pair. */
fun main() {
    val result = ST.runST(prog)
    println(result)
}
@marc0der
Copy link
Author

marc0der commented Nov 15, 2020

@raulraja This is what I'd like to write instead of `prog` above, using a monad comprehension:

val prog2 = object : RunnableST<Pair<Int, Int>> {
    // Same program as `prog`, written as an `fx` comprehension with explicit `bind()` calls.
    override fun <S> invoke(): ST<S, Pair<Int, Int>> =
        ST.fx {
            val first = STRef<S, Int>(10).bind()
            val second = STRef<S, Int>(20).bind()
            val x = first.read().bind()
            val y = second.read().bind()
            first.write(y + 1).bind()
            second.write(x + 1).bind()
            first.read().bind() to second.read().bind()
        }
}

@marc0der
Copy link
Author

In the end, I was only missing some boilerplate. All I needed was a bit of sugar on `ST`'s companion object that delegates the `fx` call to the `fx` function of the `Monad` type-class instance. Go figure!

/**
 * Monad-comprehension entry point for [ST]. It runs the suspended block [c] through the
 * `fx.monad` builder of the `ST` `Monad` instance and narrows the result back to [ST].
 */
fun <S, A> ST.Companion.fx(
    c: suspend MonadSyntax<STPartialOf<S>>.() -> A
): ST<S, A> {
    val instance = ST.monad<S, A>()
    return instance.fx.monad(c).fix()
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment