@bitwalker
Forked from pthariensflame/IndexedState.md
Last active December 14, 2015 18:38

The Indexed State Monad in Haskell, Scala, and C#

Have you ever had to write code that made a complex series of successive modifications to a single piece of mutable state? (Almost certainly yes.)

Did you ever wish you could make the compiler tell you if a particular operation on the state was illegal at a given point in the modifications? (If you're a fan of static typing, probably yes.)

If that's the case, the indexed state monad can help!

Motivation

If you're familiar with the regular state monad, you might wonder what makes the indexed version better. The answer is that the indexed state monad lets you change the type of the state as you go while remaining completely type-safe. The compiler will stop you if you try to perform an operation that is invalid for the current state's type, yet you remain free to turn that state into a value of a different type. Every indexed state computation records three types: the type of the monadic result (as with the regular state monad), the type of the input state, and, separately, the type of the output state. All three can be completely unrelated. The regular state monad cannot express this, because it fixes a single state type for the whole computation; the indexed state monad does not.

Implementation

First, the necessary preliminary boilerplate:

Haskell:

module IndexedState where
import Prelude hiding (fmap, (>>=), (>>), return, fail)

infixl 1 >>=, >>

Scala:

package indexedState

C#:

using System;
using Unit = System.Reactive.Unit;

namespace IndexedState
{
    // C# classes cannot be variant, so the standard Tuple<T1, T2> is invariant;
    // a covariant interface is needed so that IState can be covariant in O and A.
    public interface Pair<out A, out B>
    {
        A V1 { get; }
        B V2 { get; }
    }
    public static class Pair
    {
        public static Pair<A, B> Create<A, B>(A a, B b)
        {
            return new PairImpl<A, B>(a, b);
        }
        private class PairImpl<A, B> : Pair<A, B>
        {
            private readonly A v1;
            private readonly B v2;
            public PairImpl(A a, B b)
            {
                v1 = a;
                v2 = b;
            }
            public A V1
            {
                get { return v1; }
            }
            public B V2
            {
                get { return v2; }
            }
        }
    }

The indexed state monad is implemented as a function from the initial state to a tuple of the result and the final state.

Haskell:

newtype IState i o a = IState { runIState :: i -> (a, o) }

evalIState :: IState i o a -> i -> a
evalIState st i = fst $ runIState st i

execIState :: IState i o a -> i -> o
execIState st i = snd $ runIState st i

Scala:

object IState extends IStateMonadFuncs with IStateFuncs {
  def apply[I, O, A](run: I => (A, O)): IState[I, O, A] = new IState[I, O, A](run)
}
final class IState[-I, +O, +A](val run: I => (A, O)) extends IStateMonadOps[I, O, A] {
  def eval(i: I): A = this.run(i)._1
  def exec(i: I): O = this.run(i)._2
}

C#:

    public delegate Pair<A, O> IState<in I, out O, out A>(I i);
    public static class IState
    {
        public static Pair<A, O> Run<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i);
        }
        public static A Eval<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i).V1;
        }
        public static O Exec<I, O, A>(this IState<I, O, A> st, I i)
        {
            return st(i).V2;
        }

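As a quick check of the representation, here is a minimal, self-contained Haskell sketch that repeats the IState newtype and its two runners and applies them to a computation whose input and output states have different types. The computation describe is a hypothetical example, not part of the original code.

```haskell
module Main where

-- The indexed state type: a function from the input state to (result, output state).
newtype IState i o a = IState { runIState :: i -> (a, o) }

-- Keep only the result.
evalIState :: IState i o a -> i -> a
evalIState st i = fst (runIState st i)

-- Keep only the final state.
execIState :: IState i o a -> i -> o
execIState st i = snd (runIState st i)

-- Hypothetical example: reports whether the Int state is even and
-- replaces the state with its String rendering (Int in, String out).
describe :: IState Int String Bool
describe = IState (\n -> (even n, show n))

main :: IO ()
main = do
  print (runIState describe 42)
  print (evalIState describe 7)
  putStrLn (execIState describe 7)
```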
In order to make IState into an indexed monad, and to be able to take advantage of each language's syntactic support for monads (do notation in Haskell, for comprehensions in Scala, and LINQ query expressions in C#), we need to implement the fundamental unit, map, join, and bind combinators, which Haskell, Scala, and C# each name differently. Additionally, Haskell's do notation uses the auxiliary convenience combinators then (>>) and fail, and C# LINQ uses an auxiliary convenience combinator we'll call bindMap: the overload of SelectMany that also takes a result selector, which the compiler calls for every from clause after the first.

Haskell:

-- unit
return :: a -> IState s s a
return a = IState $ \s -> (a, s)

-- map
fmap :: (a -> b) -> IState i o a -> IState i o b
fmap f v = IState $ \i -> let (a, o) = runIState v i
                          in (f a, o)

-- join
join :: IState i m (IState m o a) -> IState i o a
join v = IState $ \i -> let (w, m) = runIState v i
                        in runIState w m

-- bind
(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState $ \i -> let (a, m) = runIState v i
                         in runIState (f a) m

-- then
(>>) :: IState i m a -> IState m o b -> IState i o b
v >> w = v >>= \_ -> w

-- fail
fail :: String -> IState i o a
fail str = error str

Scala:

private[indexedState] sealed trait IStateMonadFuncs { this: IState.type =>
  // unit
  def point[S, A](a: A): IState[S, S, A] = IState { s => (a, s) }
}
private[indexedState] sealed trait IStateMonadOps[-I, +O, +A] { this: IState[I, O, A] =>
  // map
  def map[B](f: A => B): IState[I, O, B] = IState { i =>
    val (a, o) = this.run(i)
    (f(a), o)
  }
  
  // join
  def flatten[E, B](implicit ev: A <:< IState[O, E, B]): IState[I, E, B] = IState { i =>
    val (n, o) = this.run(i)
    ev(n).run(o)
  }
  
  // bind
  def flatMap[E, B](f: A => IState[O, E, B]): IState[I, E, B] = IState { i =>
    val (n, o) = this.run(i)
    f(n).run(o)
  }
}

C#:

        // unit
        public static IState<S, S, A> ToIState<S, A>(this A a)
        {
            return (s => Pair.Create<A, S>(a, s));
        }
        
        // map
        public static IState<I, O, B> Select<I, O, A, B>(this IState<I, O, A> st, Func<A, B> func)
        {
            return (i =>
            {
                var ao = st.Run(i);
                return Pair.Create<B, O>(func(ao.V1), ao.V2);
            });
        }
        
        // join
        public static IState<I, O, A> Flatten<I, M, O, A>(this IState<I, M, IState<M, O, A>> st)
        {
            return (i =>
            {
                var qm = st.Run(i);
                return qm.V1.Run(qm.V2);
            });
        }
        
        // bind
        public static IState<I, O, B> SelectMany<I, M, O, A, B>(this IState<I, M, A> st, Func<A, IState<M, O, B>> func)
        {
            return (i =>
            {
                var am = st.Run(i);
                return func(am.V1).Run(am.V2);
            });
        }
        
        // bindMap
        public static IState<I, O, C> SelectMany<I, M, O, A, B, C>(this IState<I, M, A> st, Func<A, IState<M, O, B>> func, Func<A, B, C> selector)
        {
            return (i =>
            {
                var am = st.Run(i);
                var a = am.V1;
                var bo = func(a).Run(am.V2);
                return Pair.Create<C, O>(selector(a, bo.V1), bo.V2);
            });
        }

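To see how bind threads the state while letting its type change at each step, here is a minimal Haskell sketch that repeats the combinators above (with the Prelude versions hidden) and chains two hypothetical steps, toStr and strLen, which take the state from Int to String and back to Int. It also shows that map and join can be recovered from bind and unit; fmapViaBind and joinViaBind are illustrative names, not part of the original code.

```haskell
module Main where

import Prelude hiding (return, (>>=), fmap)

infixl 1 >>=

newtype IState i o a = IState { runIState :: i -> (a, o) }

return :: a -> IState s s a
return a = IState (\s -> (a, s))

(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState (\i -> let (a, m) = runIState v i in runIState (f a) m)

fmap :: (a -> b) -> IState i o a -> IState i o b
fmap f v = IState (\i -> let (a, o) = runIState v i in (f a, o))

join :: IState i m (IState m o a) -> IState i o a
join v = IState (\i -> let (w, m) = runIState v i in runIState w m)

-- map and join expressed through bind and unit.
fmapViaBind :: (a -> b) -> IState i o a -> IState i o b
fmapViaBind f v = v >>= (\a -> return (f a))

joinViaBind :: IState i m (IState m o a) -> IState i o a
joinViaBind v = v >>= id

-- Hypothetical step: Int state becomes String state; result is the old Int.
toStr :: IState Int String Int
toStr = IState (\n -> (n, show n))

-- Hypothetical step: String state becomes its length; result is the old String.
strLen :: IState String Int String
strLen = IState (\s -> (s, length s))

-- Bind lines up the intermediate state type: Int -> String -> Int.
pipeline :: IState Int Int (Int, String)
pipeline = toStr >>= (\n -> fmap (\s -> (n, s)) strLen)

main :: IO ()
main = print ( runIState pipeline 12345
             , runIState (fmapViaBind (+ 1) toStr) 9
             , runIState (joinViaBind (return strLen)) "abc" )
```

Swapping the two steps (strLen >>= \_ -> toStr) would be rejected, because strLen outputs an Int state while toStr also expects an Int input but the overall chain would then need a String input; the indices must line up end to end.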
Now we just need some IState-specific primitives:

Haskell:

get :: IState s s s
get = IState $ \s -> (s, s)

put :: o -> IState i o ()
put o = IState $ \_ -> ((), o)

modify :: (i -> o) -> IState i o ()
modify f = IState $ \i -> ((), f i)

Scala:

private[indexedState] sealed trait IStateFuncs { this: IState.type =>
  def get[S]: IState[S, S, S] = IState { s => (s, s) }
  
  def put[O](o: O): IState[Any, O, Unit] = IState { _ => ((), o) }
  
  def modify[I, O](f: I => O): IState[I, O, Unit] = IState { i => ((), f(i)) }
}

C#:

        public static IState<S, S, S> Get<S>()
        {
            return (s => Pair.Create<S, S>(s, s));
        }
        public static IState<I, O, Unit> Put<I, O>(O o)
        {
            return (_ => Pair.Create<Unit, O>(Unit.Default, o));
        }
        public static IState<I, O, Unit> Modify<I, O>(Func<I, O> func)
        {
            return (i => Pair.Create<Unit, O>(Unit.Default, func(i)));
        }
    }
}
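Here is a minimal Haskell sketch of the three primitives in use, written with explicit operators rather than do notation so it compiles without RebindableSyntax. The computation steps is hypothetical: it reads an Int, turns it into a String, and finally replaces it with a Bool, so the state has three different types over the course of one computation.

```haskell
module Main where

import Prelude hiding (return, (>>=), (>>))

infixl 1 >>=, >>

newtype IState i o a = IState { runIState :: i -> (a, o) }

return :: a -> IState s s a
return a = IState (\s -> (a, s))

(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState (\i -> let (a, m) = runIState v i in runIState (f a) m)

(>>) :: IState i m a -> IState m o b -> IState i o b
v >> w = v >>= \_ -> w

get :: IState s s s
get = IState (\s -> (s, s))

put :: o -> IState i o ()
put o = IState (\_ -> ((), o))

modify :: (i -> o) -> IState i o ()
modify f = IState (\i -> ((), f i))

-- Hypothetical computation: Int state -> String state -> Bool state.
-- Returns the original Int; the final state records whether the String
-- built from it has the expected length.
steps :: IState Int Bool Int
steps =
  get >>= \n ->
  modify (\k -> replicate k 'x') >>
  get >>= \s ->
  put (length s == n) >>
  return n

main :: IO ()
main = print (runIState steps 3)
```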

That's it! We can now use our indexed state monad with each language's built-in support for monad notation.

Usage

Here's how to take advantage of what we've just built:

Haskell:

{-# LANGUAGE RebindableSyntax #-}
import Prelude hiding ((>>=), (>>), return, fmap, fail)
import IndexedState

-- | Performs 'someIntToCharFunction' on the input state, returning the old state.
myIStateComputation :: IState Int Char Int
myIStateComputation = do original <- get
                         modify someIntToCharFunction
                         return original

Scala:

import indexedState._

object example {
  /**
   * Performs `someIntToCharFunction` on the input state, returning the old state.
   */
  val myIStateComputation: IState[Int, Char, Int] = for {
    original <- IState.get[Int]
    _        <- IState.modify[Int, Char](someIntToCharFunction)
  } yield original
}

C#:

using IndexedState;

public static class Example
{
    /// <summary>
    ///  Performs <c>SomeIntToCharFunction</c> on the input state, returning the old state.
    /// </summary>
    public static IState<int, char, int> MyStateComputation()
    {
        return (from original in IState.Get<int>()
                from _        in IState.Modify<int, char>(SomeIntToCharFunction)
                select original);
    }
}
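The Haskell usage example above depends on the IndexedState module and on an undefined someIntToCharFunction. Below is a single-file sketch that compiles on its own, assuming a hypothetical someIntToCharFunction that maps an Int onto a lowercase letter. Under RebindableSyntax every do block in the module desugars to the IState operators in scope, so main is written without do notation.

```haskell
{-# LANGUAGE RebindableSyntax #-}
module Main where

import Prelude hiding (return, (>>=), (>>), fmap, fail)

infixl 1 >>=, >>

newtype IState i o a = IState { runIState :: i -> (a, o) }

return :: a -> IState s s a
return a = IState (\s -> (a, s))

(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState (\i -> let (a, m) = runIState v i in runIState (f a) m)

(>>) :: IState i m a -> IState m o b -> IState i o b
v >> w = v >>= \_ -> w

get :: IState s s s
get = IState (\s -> (s, s))

modify :: (i -> o) -> IState i o ()
modify f = IState (\i -> ((), f i))

-- Hypothetical stand-in for the undefined someIntToCharFunction:
-- maps an Int onto a lowercase letter.
someIntToCharFunction :: Int -> Char
someIntToCharFunction n = toEnum (fromEnum 'a' + n `mod` 26)

-- Performs 'someIntToCharFunction' on the input state, returning the old state.
myIStateComputation :: IState Int Char Int
myIStateComputation = do
  original <- get
  modify someIntToCharFunction
  return original

-- A do block here would use the IState operators, so main avoids do notation.
main :: IO ()
main = print (runIState myIStateComputation 2)
```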

Imagine you're handling a network connection. The connection can be either open, temporarily closed, or permanently discarded. You want the compiler to tell you at compile time if you try to draw from or close a closed connection, if you try to open or discard an open connection, if you try to do anything with a discarded connection, or if you try to begin with a non-closed connection or end with a non-discarded connection.

You create an abstract base class Connection to hold any common code needed to use the underlying raw handle, placed in private and protected methods. You then create three subclasses of Connection: OpenConnection (with Close() and GetData() methods), ClosedConnection (with Open() and Discard() methods), and DiscardedConnection (with no additional methods). All the methods secretly mutate the underlying handle, but each state-changing method returns a new Connection of the appropriate subclass that wraps the same underlying handle. You must never call more than one state-changing method on a given Connection object, because the old object would then describe a state the handle is no longer in; the indexed state monad takes care of that for you, because each Connection is threaded through the computation and consumed by exactly one step.

You then declare a method with no parameters (or just a variable, preferably immutable) that returns IState<ClosedConnection, DiscardedConnection, A>, where A is the type you actually want to return. When you implement that method (or give that variable its value), the compiler will do everything I said it would, and you can use LINQ syntax to make it look as though you're actually mutating a variable. You can even just pass a LINQ query expression directly to a method that expects an IState, so long as IState is the type you're manipulating inside the expression.

EDIT: The example above in code:

using System;
using IndexedState;
using Handle = ...; // import <c>Handle</c> from somewhere
using Unit = System.Reactive.Unit;

public abstract class Connection 
{
    protected readonly Handle handle;
    protected Connection(Handle h) { handle = h; }
    // any additional common code
}

public class OpenConnection : Connection
{
    public OpenConnection(Handle h) : base(h) { }
    public ClosedConnection Close() { return new ClosedConnection(handle); } // dummy
    public byte[] GetData()
    {
        // implement using the handle
        throw new NotImplementedException();
    }
}

public class ClosedConnection : Connection
{
    public ClosedConnection(Handle h) : base(h) { }
    public OpenConnection Open() { return new OpenConnection(handle); } // dummy
    public DiscardedConnection Discard() { return new DiscardedConnection(handle); } // dummy
}

public class DiscardedConnection : Connection
{
    public DiscardedConnection(Handle h) : base(h) { }
}

public static class Conn
{
    public static IState<OpenConnection, OpenConnection, byte[]> GetData() { return (c => Pair.Create(c.GetData(), c)); }
    // Explicit type arguments are required: C# cannot infer I from an untyped lambda.
    public static IState<ClosedConnection, OpenConnection, Unit> Open() { return IState.Modify<ClosedConnection, OpenConnection>(c => c.Open()); }
    public static IState<OpenConnection, ClosedConnection, Unit> Close() { return IState.Modify<OpenConnection, ClosedConnection>(c => c.Close()); }
    public static IState<ClosedConnection, DiscardedConnection, Unit> Discard() { return IState.Modify<ClosedConnection, DiscardedConnection>(c => c.Discard()); }
    public static A RunWithConnection<A>(this IState<ClosedConnection, DiscardedConnection, A> st, Handle h) { return st.Eval(new ClosedConnection(h)); }
}


// Methods must live inside a type in C#; the class name here is illustrative.
public static class ConnectionExample
{
    public static String GetFirstString_CompleteExample(Handle h)
    {
        return (
            from _u1        in Conn.Open()
            from stringData in Conn.GetData()
            from _u2        in Conn.Close()
            from _u3        in Conn.Discard()
            select Convert.ToBase64String(stringData)
        ).RunWithConnection(h);
    }
}
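For comparison, here is a minimal Haskell sketch of the same connection protocol. Handle is a hypothetical stand-in that just carries a canned payload, and open, getData, close, discard, runWithConnection, and firstData are illustrative names mirroring the C# Conn members; the newtype wrappers play the role of the Connection subclasses. Reordering the steps in firstData, for example calling getData before open, is rejected by the type checker because the state indices no longer line up.

```haskell
module Main where

import Prelude hiding (return, (>>=), (>>))

infixl 1 >>=, >>

newtype IState i o a = IState { runIState :: i -> (a, o) }

evalIState :: IState i o a -> i -> a
evalIState st i = fst (runIState st i)

return :: a -> IState s s a
return a = IState (\s -> (a, s))

(>>=) :: IState i m a -> (a -> IState m o b) -> IState i o b
v >>= f = IState (\i -> let (a, m) = runIState v i in runIState (f a) m)

(>>) :: IState i m a -> IState m o b -> IState i o b
v >> w = v >>= \_ -> w

modify :: (i -> o) -> IState i o ()
modify f = IState (\i -> ((), f i))

-- Hypothetical stand-in for a raw handle: it just carries a canned payload.
newtype Handle = Handle String

-- One wrapper per protocol state, mirroring the C# Connection subclasses.
newtype ClosedConnection    = ClosedConnection Handle
newtype OpenConnection      = OpenConnection Handle
newtype DiscardedConnection = DiscardedConnection Handle

open :: IState ClosedConnection OpenConnection ()
open = modify (\(ClosedConnection h) -> OpenConnection h)

getData :: IState OpenConnection OpenConnection String
getData = IState (\c@(OpenConnection (Handle payload)) -> (payload, c))

close :: IState OpenConnection ClosedConnection ()
close = modify (\(OpenConnection h) -> ClosedConnection h)

discard :: IState ClosedConnection DiscardedConnection ()
discard = modify (\(ClosedConnection h) -> DiscardedConnection h)

-- Only a computation that starts closed and ends discarded can be run.
runWithConnection :: IState ClosedConnection DiscardedConnection a -> Handle -> a
runWithConnection st h = evalIState st (ClosedConnection h)

firstData :: Handle -> String
firstData = runWithConnection
  (open >> getData >>= \d -> close >> discard >> return d)

main :: IO ()
main = putStrLn (firstData (Handle "hello"))
```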