@pthariensflame
Last active June 15, 2022 18:42
An introduction to the indexed state monad in Haskell, Scala, and C#.

The Indexed State Monad in Haskell, Scala, and C#

Have you ever had to write code that made a complex series of successive modifications to a single piece of mutable state? (Almost certainly yes.)

Did you ever wish you could make the compiler tell you if a particular operation on the state was illegal at a given point in the modifications? (If you're a fan of static typing, probably yes.)

If that's the case, the indexed state monad can help!

Motivation

If you're familiar with the regular state monad, you might wonder what makes the indexed version better. The answer is that the indexed state monad lets the type of the state change as you go, while remaining completely type-safe the whole time. The compiler will still stop you if you try to perform a type-invalid operation on the current state, but you can then turn around and modify that state in a way that changes its type. The type of an indexed state computation records not only the type of the monadic result (as with the regular state monad) but also, separately, the input and output state types. All three types are distinct and can be completely unrelated. The regular state monad cannot express this for the two state types: it requires the state to keep the same type throughout the computation.
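
The difference is visible in the type of a "modify" operation alone. The following is a minimal sketch, using simplified stand-ins for both state types; the names State, modifyS, modifyI, and measure are assumptions made for illustration and are not part of the implementation below.

```haskell
-- Sketch only: simplified stand-ins for the regular and indexed state types.
newtype State s a = State { runState :: s -> (a, s) }
newtype IState i o a = IState { runIState :: i -> (a, o) }

-- Regular state: the function must preserve the state type.
modifyS :: (s -> s) -> State s ()
modifyS f = State $ \s -> ((), f s)

-- Indexed state: the function may map the state to a new type.
modifyI :: (i -> o) -> IState i o ()
modifyI f = IState $ \i -> ((), f i)

-- Accepted: the state starts as a String and ends as an Int.
measure :: IState String Int ()
measure = modifyI length

-- Rejected by the type checker, because String and Int do not unify:
-- measureS = modifyS (length :: String -> Int)
```

Running measure on the state "hello" leaves 5 as the final state.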

Implementation

First, the necessary preliminary boilerplate:

Haskell:

module IndexedState where
import Prelude hiding (fmap, (>>=), (>>), return, fail)

Scala:

package indexedState

C#:

using System;
using Unit = System.Reactive.Unit;

namespace IndexedState
{
    // Unfortunately, C#'s standard Tuple type is invariant, so this is necessary.
    public interface Pair<out A, out B>
    {
        A V1 { get; }
        B V2 { get; }
    }
    public static class Pair
    {
        public static Pair<A, B> Create<A, B>(A a, B b)
        {
            return new PairImpl<A, B>(a, b);
        }
        private class PairImpl<A, B> : Pair<A, B>
        {
            private readonly A v1;
            private readonly B v2;
            public PairImpl(A a, B b)
            {
                v1 = a;
                v2 = b;
            }
            public A V1
            {
                get { return v1; }
            }
            public B V2
            {
                get { return v2; }
            }
        }
    }

The indexed state monad is implemented as a function from the initial state to a tuple of the result and the final state.

Haskell:

newtype IState i o a = IState { runIState :: i -> (a, o) }

evalIState :: IState i o a -> i -> a
evalIState st i = fst $ runIState st i

execIState :: IState i o a -> i -> o
execIState st i = snd $ runIState st i

Scala:

object IState extends IStateMonadFuncs with IStateFuncs {
  def apply[I, O, A](run: I => (A, O)): IState[I, O, A] = new IState[I, O, A](run)
}
final class IState[-I, +O, +A](val run: I => (A, O)) extends IStateMonadOps[I, O, A] {
  def eval(i: I): A = this.run(i)._1
  def exec(i: I): O = this.run(i)._2
}

C#:

    public delegate Pair<A, O> IState<in I, out O, out A>(I i);
    public static class IState
    {
        public static Pair<A, O> Run<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i);
        }
        public static A Eval<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i).V1;
        }
        public static O Exec<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i).V2;
        }
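
To make the three ways of running a computation concrete, here is a small Haskell sketch. It repeats the definitions above so that it stands alone; describe and the three named results are hypothetical and exist only for this illustration.

```haskell
newtype IState i o a = IState { runIState :: i -> (a, o) }

evalIState :: IState i o a -> i -> a
evalIState st i = fst (runIState st i)

execIState :: IState i o a -> i -> o
execIState st i = snd (runIState st i)

-- Hypothetical computation: reads an Int state, returns whether it is even,
-- and leaves a String rendering of it as the new state.
describe :: IState Int String Bool
describe = IState $ \n -> (even n, show n)

-- Result and final state together: (True, "42")
bothParts :: (Bool, String)
bothParts = runIState describe 42

-- Result only: True
resultOnly :: Bool
resultOnly = evalIState describe 42

-- Final state only: "42"
stateOnly :: String
stateOnly = execIState describe 42
```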

In order to make IState into an indexed monad, and be able to take advantage of each language's syntactic support for monads (do notation in Haskell, for comprehensions in Scala, and LINQ query expressions in C#), we need to implement the fundamental bind, unit, map, and join combinators, which Haskell, Scala and C# all call different things. Additionally, Haskell requires auxiliary convenience combinators called then and fail, and C# requires an auxiliary convenience combinator called bindMap.

Haskell:

-- unit
return :: a -> IState s s a
return a = IState $ \s -> (a, s)

-- map
fmap :: (a -> b) -> IState i o a -> IState i o b
fmap f v = IState $ \i -> let (a, o) = runIState v i
                          in (f a, o)

-- join
join :: IState i m (IState m o a) -> IState i o a
join v = IState $ \i -> let (w, m) = runIState v i
                        in runIState w m

-- bind
(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState $ \i -> let (a, m) = runIState v i
                         in runIState (f a) m

-- then
(>>) :: IState i m a -> IState m o b -> IState i o b
v >> w = v >>= \_ -> w

-- fail
fail :: String -> IState i o a
fail str = error str

Scala:

private[indexedState] sealed trait IStateMonadFuncs { this: IState.type =>
  // unit
  def point[S, A](a: A): IState[S, S, A] = IState { s => (a, s) }
}
private[indexedState] sealed trait IStateMonadOps[-I, +O, +A] { this: IState[I, O, A] =>
  // map
  def map[B](f: A => B): IState[I, O, B] = IState { i =>
    val (a, o) = this.run(i)
    (f(a), o)
  }
  
  // join
  def flatten[E, B](implicit ev: A <:< IState[O, E, B]): IState[I, E, B] = IState { i =>
    val (n, o) = this.run(i)
    ev(n).run(o)
  }
  
  // bind
  def flatMap[E, B](f: A => IState[O, E, B]): IState[I, E, B] = IState { i =>
    val (n, o) = this.run(i)
    f(n).run(o)
  }
}

C#:

        // unit
        public static IState<S, S, A> ToIState<S, A>(this A a)
        {
            return (s => Pair.Create<A, S>(a, s));
        }
        
        // map
        public static IState<I, O, B> Select<I, O, A, B>(this IState<I, O, A> st, Func<A, B> func)
        {
            return (i =>
            {
                var ao = st.Run(i);
                return Pair.Create<B, O>(func(ao.V1), ao.V2);
            });
        }
        
        // join
        public static IState<I, O, A> Flatten<I, M, O, A>(this IState<I, M, IState<M, O, A>> st)
        {
            return (i =>
            {
                var qm = st.Run(i);
                return qm.V1.Run(qm.V2);
            });
        }
        
        // bind
        public static IState<I, O, B> SelectMany<I, M, O, A, B>(this IState<I, M, A> st, Func<A, IState<M, O, B>> func)
        {
            return (i =>
            {
                var am = st.Run(i);
                return func(am.V1).Run(am.V2);
            });
        }
        
        // bindMap
        public static IState<I, O, C> SelectMany<I, M, O, A, B, C>(this IState<I, M, A> st, Func<A, IState<M, O, B>> func, Func<A, B, C> selector)
        {
            return (i =>
            {
                var am = st.Run(i);
                var a = am.V1;
                var bo = func(a).Run(am.V2);
                return Pair.Create<C, O>(selector(a, bo.V1), bo.V2);
            });
        }
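
The four combinators are not independent. As a sketch of how they relate, the Haskell below derives join and map from bind and unit. The names ireturn, ibind, ijoin, imap, and nested are assumptions chosen so that nothing from the Prelude has to be hidden; they are not the names used above.

```haskell
newtype IState i o a = IState { runIState :: i -> (a, o) }

-- unit
ireturn :: a -> IState s s a
ireturn a = IState $ \s -> (a, s)

-- bind
ibind :: IState i m a -> (a -> IState m o b) -> IState i o b
ibind v f = IState $ \i -> let (a, m) = runIState v i
                           in runIState (f a) m

-- join expressed through bind: run the outer computation, then the inner one.
ijoin :: IState i m (IState m o a) -> IState i o a
ijoin v = v `ibind` id

-- map expressed through bind and unit: the state types are left untouched.
imap :: (a -> b) -> IState i o a -> IState i o b
imap f v = v `ibind` (ireturn . f)

-- Hypothetical example: the outer step turns an Int state into a String state
-- and returns an inner step that turns the String state into its length.
nested :: IState Int String (IState String Int Bool)
nested = IState $ \n -> (inner, show n)
  where
    inner :: IState String Int Bool
    inner = IState $ \s -> (null s, length s)
```

Note how the intermediate state type (String here) has to line up between the outer and inner computations, and then disappears from the type of ijoin nested, which is IState Int Int Bool.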

Now we just need some IState-specific primitives:

Haskell:

get :: IState s s s
get = IState $ \s -> (s, s)

put :: o -> IState i o ()
put o = IState $ \_ -> ((), o)

modify :: (i -> o) -> IState i o ()
modify f = IState $ \i -> ((), f i)

Scala:

private[indexedState] sealed trait IStateFuncs { this: IState.type =>
  def get[S]: IState[S, S, S] = IState { s => (s, s) }
  
  def put[O](o: O): IState[Any, O, Unit] = IState { _ => ((), o) }
  
  def modify[I, O](f: I => O): IState[I, O, Unit] = IState { i => ((), f(i)) }
}

C#:

        public static IState<S, S, S> Get<S>()
        {
            return (s => Pair.Create<S, S>(s, s));
        }
        public static IState<I, O, Unit> Put<I, O>(O o)
        {
            return (_ => Pair.Create<Unit, O>(Unit.Default, o));
        }
        public static IState<I, O, Unit> Modify<I, O>(Func<I, O> func)
        {
            return (i => Pair.Create<Unit, O>(Unit.Default, func(i)));
        }
    }
}
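
A short Haskell sketch of the primitives working together, with the state changing type twice along the way. It uses explicit binds rather than do notation so that it stands alone, and the names ireturn, ibind, iget, iput, imodify, and pipeline are assumptions made for this illustration.

```haskell
newtype IState i o a = IState { runIState :: i -> (a, o) }

ireturn :: a -> IState s s a
ireturn a = IState $ \s -> (a, s)

ibind :: IState i m a -> (a -> IState m o b) -> IState i o b
ibind v f = IState $ \i -> let (a, m) = runIState v i
                           in runIState (f a) m

iget :: IState s s s
iget = IState $ \s -> (s, s)

iput :: o -> IState i o ()
iput o = IState $ \_ -> ((), o)

imodify :: (i -> o) -> IState i o ()
imodify f = IState $ \i -> ((), f i)

-- The state moves Int -> String -> Bool. Each step is checked against the
-- state type in force at that point in the chain.
pipeline :: IState Int Bool String
pipeline =
  imodify show `ibind` \_ ->                 -- state: Int -> String
  iget `ibind` \rendered ->                  -- read the String state
  iput (length rendered > 2) `ibind` \_ ->   -- state: String -> Bool
  ireturn rendered
```

Swapping the first two steps would not compile, because iget would then return an Int and length could not be applied to it.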

That's it! We can now use our indexed state monad with each language's built-in support for monad notation.

Usage

Here's how to take advantage of what we've just built:

Haskell:

{-# LANGUAGE RebindableSyntax #-}
import Prelude hiding ((>>=), (>>), return, fmap, fail)
import IndexedState

-- | Performs 'someIntToCharFunction' on the input state, returning the old state.
myIStateComputation :: IState Int Char Int
myIStateComputation = do original <- get
                         modify someIntToCharFunction
                         return original

Scala:

import indexedState._

object example {
  /**
   * Performs `someIntToCharFunction` on the input state, returning the old state.
   */
  val myIStateComputation: IState[Int, Char, Int] = for {
    original <- IState.get[Int]
    _        <- IState.modify[Int, Char](someIntToCharFunction)
  } yield original
}

C#:

using IndexedState;

public static class Example
{
    /// <summary>
    ///  Performs <c>SomeIntToCharFunction</c> on the input state, returning the old state.
    /// </summary>
    public static IState<int, char, int> MyIStateComputation()
    {
        return (from original in IState.Get<int>()
                from _        in IState.Modify<int, char>(SomeIntToCharFunction)
                select original);
    }
}
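
The examples above leave someIntToCharFunction undefined. To try the Haskell version end to end in a single file, something like the following sketch should work. It avoids RebindableSyntax by using explicitly named combinators (ireturn, >>>=, iget, and imodify are assumed names), and the body given for someIntToCharFunction is a made-up stand-in, not the author's.

```haskell
import Data.Char (chr, ord)

newtype IState i o a = IState { runIState :: i -> (a, o) }

ireturn :: a -> IState s s a
ireturn a = IState $ \s -> (a, s)

infixl 1 >>>=
(>>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>>= f = IState $ \i -> let (a, m) = runIState v i
                          in runIState (f a) m

iget :: IState s s s
iget = IState $ \s -> (s, s)

imodify :: (i -> o) -> IState i o ()
imodify f = IState $ \i -> ((), f i)

-- Assumed stand-in: the n-th lowercase letter, wrapping around after 'z'.
someIntToCharFunction :: Int -> Char
someIntToCharFunction n = chr (ord 'a' + n `mod` 26)

-- Same shape as the do-notation version: remember the old state,
-- change the state's type from Int to Char, and return the old state.
myIStateComputation :: IState Int Char Int
myIStateComputation =
  iget >>>= \original ->
  imodify someIntToCharFunction >>>= \_ ->
  ireturn original
```

With these definitions, runIState myIStateComputation 2 evaluates to (2, 'c'): the result is the original Int state, and the final state is a Char.
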
@paradigmatic

Nice writing ! I think there are two typos in the Scala examples. First, you define flatten twice and forget the flatMap method. Second, in the example usage, you yield x which is undefined.

@ivanopagano

It looks like another error slipped in the scala definition of exec, whose return type I expect to be O

def exec(i: I): A = this.run(i)._2

should be

def exec(i: I): O = this.run(i)._2

nice job!

@tdammers

Shouldn't you make the Haskell one an instance of Monad (and probably Applicative too) instead of hiding the monadic operators from the Prelude import and then reimplementing your own?

@jberryman

@tdammers take a look at the type for bind; you can't make a proper Monad instance (or: this isn't actually a monad, but more of a Category/arrowish sort of thing... maybe it has a proper name).

I played with this same construction recently too.

@sebastiaanvisser

@pthariensflame
Author

@sebastiaanvisser Yes, it is an IxMonad. All Monads are IxMonads that simply ignore their type parameters. :)

@Shimuuar

Shimuuar commented Jun 6, 2013

@jberryman An ordinary monad is an arrowish thing as well, and it has a newtype wrapper for the Arrow & co. instances:

(>=>) :: (a → m b) → (b → m c) → (a → m c)

newtype Kleisli m a b = Kleisli { runKleisli :: a → m b }
