Skip to content

Instantly share code, notes, and snippets.

@wesen
Created July 21, 2023 23:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wesen/69f140f74f84d73cc5762785de261f27 to your computer and use it in GitHub Desktop.
Save wesen/69f140f74f84d73cc5762785de261f27 to your computer and use it in GitHub Desktop.

2023 07 21 Go Monads

/**
 * Minimal monadic wrapper around a native Promise.
 *
 * `bind` chains a transformation that itself returns a PromiseMonad, and
 * `resolve` lifts a plain value into the monad.
 */
class PromiseMonad<T> {
    // Public (read-only) so callers can consume the result, e.g.
    // `monad.value.then(...)`; a `private` field is inaccessible outside the class.
    constructor(public readonly value: Promise<T>) {}

    /**
     * Sequence a dependent asynchronous computation.
     * @param transform called with the resolved value; the promise of the monad it
     *        returns becomes the result.
     * @returns a PromiseMonad that settles with the transformed value, or rejects if
     *          this promise rejects or `transform` throws/rejects.
     */
    bind<U>(transform: (value: T) => PromiseMonad<U>): PromiseMonad<U> {
        return new PromiseMonad(
            this.value.then(value => transform(value).value)
        );
    }

    /** Lift a plain value into an already-resolved PromiseMonad. */
    static resolve<U>(value: U): PromiseMonad<U> {
        return new PromiseMonad(Promise.resolve(value));
    }
}
// Example: lift 5 into PromiseMonad, double it in a bind step, then log the result.
let monad = PromiseMonad.resolve(5);
monad = monad.bind((n) => {
    const doubled = n * 2;
    return PromiseMonad.resolve(doubled);
});
// Logs 10 once the promise chain settles.
monad.value.then((result) => console.log(result));
// Type of the function wrapped by StateT: given an input state S, produce a monadic
// pair [result, nextState].
// NOTE(review): the conditional type infers R but never uses it, and the result is the
// generic Monad<[T, S]> rather than M itself, so the concrete monad type is lost.
type StateTransformer<S, M extends Monad<any>, T> = (state: S) => M extends Monad<infer R> ? Monad<[T, S]> : never;

/**
 * State monad transformer: threads a state value S through computations in an
 * underlying monad M that produce a T.
 *
 * NOTE(review): this does not compile as TypeScript. `M` is a type parameter, so
 * `M.resolve(...)` in lift/put/get refers to a runtime value that does not exist.
 * The underlying monad's `resolve` would have to be passed in explicitly.
 */
class StateT<S, M extends Monad<any>, T> {
    constructor(public runStateT: StateTransformer<S, M, T>) {}

    // Run this computation, then feed its result and the updated state into f.
    bind<U>(f: (value: T) => StateT<S, M, U>): StateT<S, M, U> {
        return new StateT((s: S) => this.runStateT(s).bind(([t, s1]) => f(t).runStateT(s1)));
    }

    // Lift a computation in M into StateT, passing the state through unchanged.
    static lift<M extends Monad<any>, T>(m: M): StateT<any, M, T> {
        return new StateT(s => m.bind(t => M.resolve([t, s])));
    }

    // Replace the state with s; the result value is undefined.
    static put<S, M extends Monad<any>>(s: S): StateT<S, M, void> {
        return new StateT(_ => M.resolve([undefined, s]));
    }

    // Return the current state as the result, leaving the state unchanged.
    static get<S, M extends Monad<any>>(): StateT<S, M, S> {
        return new StateT(s => M.resolve([s, s]));
    }
}
// Example: lift PromiseMonad.resolve(5) into StateT, then store value * 2 as the state.
// NOTE(review): not runnable as written. StateT's lift/put call `M.resolve` on a type
// parameter, and runStateT(0) is typed as Monad<...>, which declares no `value` member.
// If these snippets share one file, `let monad` is also redeclared. The output comment
// below is the intended result, not an observed one.
let monad = StateT.lift(PromiseMonad.resolve(5));
monad = monad.bind(value => StateT.put(value * 2));
monad.runStateT(0).value.then(console.log);  // Outputs: [undefined, 10]
// Minimal monad interface: `bind` sequences a dependent computation and `resolve`
// lifts a plain value.
// NOTE(review): TypeScript interfaces cannot declare `static` members, so this does not
// compile. A static `resolve` would need a separate constructor-side interface type.
interface Monad<T> {
    bind<U>(transform: (value: T) => Monad<U>): Monad<U>;
    static resolve<U>(value: U): Monad<U>;
}
// Type of the value wrapped by ErrorT: a monad carrying an Either<E, T>, i.e. either
// an error E (left) or a result T (right), as ErrorT's bind/throwError use it.
// NOTE(review): `Either` is not declared in this snippet, and the inferred R is unused.
type ErrorTransformer<E, M extends Monad<any>, T> = M extends Monad<infer R> ? Monad<Either<E, T>> : never;

/**
 * Error monad transformer: wraps an underlying monad whose value is an Either<E, T>
 * and short-circuits `bind` on the first error (left).
 *
 * NOTE(review): does not compile as written. `M.resolve` uses the type parameter M as
 * a runtime value, and `Either` (with `left`/`right`/`match`) is not defined here.
 */
class ErrorT<E, M extends Monad<any>, T> {
    constructor(public runErrorT: ErrorTransformer<E, M, T>) {}

    // On right(t) continue with f(t); on left(e) re-wrap the error without calling f.
    bind<U>(f: (value: T) => ErrorT<E, M, U>): ErrorT<E, M, U> {
        return new ErrorT(this.runErrorT.bind(either => either.match({
            left: e => M.resolve(Either.left(e)),
            right: t => f(t).runErrorT
        })));
    }

    // Wrap each value of m as a success (right).
    static lift<M extends Monad<any>, T>(m: M): ErrorT<any, M, T> {
        return new ErrorT(m.bind(t => M.resolve(Either.right(t))));
    }

    // Create a failed computation holding error e (left).
    static throwError<E, M extends Monad<any>, T>(e: E): ErrorT<E, M, T> {
        return new ErrorT(M.resolve(Either.left(e)));
    }
}
// Example: 5 is positive, so the bind step takes the throwError branch and the
// computation short-circuits with a left value.
// NOTE(review): not runnable as written. ErrorT relies on `M.resolve` and an undeclared
// `Either`, and runErrorT is typed as Monad<...>, which declares no `value` member.
// The output comment below is the intended result, not an observed one.
let monad = ErrorT.lift(PromiseMonad.resolve(5));
monad = monad.bind(value => value > 0 ? ErrorT.throwError("Value must be non-positive") : ErrorT.lift(PromiseMonad.resolve(value * 2)));
monad.runErrorT.value.then(console.log);  // Outputs: Left("Value must be non-positive")
/**
 * Promise-backed monad whose pending `bind` steps can be cancelled.
 *
 * Every monad derived via `bind` shares the AbortController of the monad it was
 * derived from. All monads descending from the same `resolve` therefore share one
 * controller, and calling `cancel()` on any of them aborts the whole chain. A bind
 * step that has not run yet rejects with `Error('Operation was cancelled')`.
 */
class CancelablePromiseMonad<T> {
    /**
     * @param value the wrapped promise; public so callers can consume the result.
     * @param controller cancellation source shared along a bind chain; a fresh one
     *        is created when omitted.
     */
    constructor(
        public readonly value: Promise<T>,
        private readonly controller: AbortController = new AbortController()
    ) {}

    /**
     * Sequence a dependent computation that is skipped if the chain was cancelled
     * before this step runs.
     * @returns a monad sharing this monad's controller.
     */
    bind<U>(transform: (value: T) => CancelablePromiseMonad<U>): CancelablePromiseMonad<U> {
        const controller = this.controller;
        return new CancelablePromiseMonad(
            this.value.then(value => {
                // Checked when the step actually runs, i.e. after the previous promise settles.
                if (controller.signal.aborted) {
                    throw new Error('Operation was cancelled');
                }
                return transform(value).value;
            }),
            // Pass the controller down so cancel() on the derived monad reaches this step.
            // Previously each instance had its own controller, so cancelling the derived
            // monad never affected the check above.
            controller
        );
    }

    /** Abort every monad that shares this monad's controller. */
    cancel() {
        this.controller.abort();
    }

    /** Lift a plain value into a new chain with its own controller. */
    static resolve<U>(value: U): CancelablePromiseMonad<U> {
        return new CancelablePromiseMonad(Promise.resolve(value));
    }
}
// Example: build a one-step chain, then cancel it before the step runs.
let monad = CancelablePromiseMonad.resolve(5);
monad = monad.bind(value => CancelablePromiseMonad.resolve(value * 2));
// cancel() runs synchronously, before the pending `then` callback of the bind step.
// NOTE(review): the step is only cancelled if cancel() on this (derived) monad aborts the
// controller that the bind step checks. Verify against CancelablePromiseMonad: with one
// controller per instance, nothing is cancelled and nothing is logged below. `value` must
// also be accessible (not private) for the last line to compile.
monad.cancel();
monad.value.catch(console.log);  // Outputs (if the chain was cancelled): Error: Operation was cancelled
package main

import "time"

// Monad carries an asynchronously produced result on a receive-only channel.
// Producers in this program close the channel once they are done.
type Monad[T any] struct {
	value <-chan T // source of the wrapped value(s)
}

// Bind sequences m into transform. A goroutine receives every value from m,
// calls transform on it, and forwards everything the resulting monad yields
// to the returned monad. The returned monad's channel is closed once m is exhausted.
//
// The output channel is unbuffered, so the goroutine advances only as fast as
// the consumer receives.
func Bind[T any, U any](m Monad[T], transform func(value T) Monad[U]) Monad[U] {
	out := make(chan U)

	go func() {
		defer close(out)
		// A monad is expected to yield a single value, but every value is forwarded.
		for t := range m.value {
			next := transform(t)
			for u := range next.value {
				out <- u
			}
		}
	}()

	return Monad[U]{value: out}
}

// Resolve returns an already-completed Monad. Its buffered channel holds value
// and is closed, so exactly one receive yields value with ok == true.
func Resolve[T any](value T) Monad[T] {
	ready := make(chan T, 1)
	defer close(ready) // runs before the caller receives the monad
	ready <- value
	return Monad[T]{value: ready}
}

func main() {

	m1 := Resolve(1)
	m2 := Bind(m1, func(value int) Monad[int] {
		return Resolve(value + 1)
	})
	m3 := Bind(m2, func(value int) Monad[int] {
		return Resolve(value + 1)
	})
	m4After2Seconds := Bind(m3, func(value int) Monad[int] {
		// wait 2 seconds before finishing the computation
		time.Sleep(2 * time.Second)
		return Resolve(value + 1)
	})

	v := <-m4After2Seconds.value
	println(v)
}
Since Go 1.18, Go supports generics. The syntax is as follows:

For simple functions:

// GMin returns the smaller of x and y. Whenever x < y is false (including ties)
// it returns y.
func GMin[T constraints.Ordered](x, y T) T {
	smaller := y
	if x < y {
		smaller = x
	}
	return smaller
}

For structs and interfaces:

// Tree is a node of a generic binary tree whose payload has type T.
type Tree[T any] struct {
	left  *Tree[T] // left subtree
	right *Tree[T] // right subtree
	value T
}

Methods cannot declare type parameters of their own; they can only use the type parameters of their receiver type:

// Illustrative signature only (body elided): Lookup reuses the receiver's type
// parameter T rather than declaring new ones.
func (t *Tree[T]) Lookup(x T) *Tree[T] { ... }
cat "$filepath"
// Result is the outcome of one step of a computation. A non-nil Err marks a
// failure, in which case Value carries no meaningful data.
type Result[T any] struct {
	Value T     // computed value; meaningful only when Err is nil
	Err   error // non-nil when the computation failed
}

// Error is a string type that implements the error interface. Because it is a
// string type, values can also be declared as constants.
type Error string

// Error returns the message unchanged.
func (e Error) Error() string {
	msg := string(e)
	return msg
}

// Monad delivers the outcome of an asynchronous computation as Result values
// on a receive-only channel. A Result with a non-nil Err reports failure.
type Monad[T any] struct {
	value <-chan Result[T] // closed by its producer once it is done
}

// Bind sequences m into transform and returns the combined computation.
//
// A goroutine receives Results from m. Each successful Result's Value is passed
// to transform, and every Result the transformed monad yields is forwarded to
// the returned monad. The first failed Result from m is forwarded and ends the
// stream.
//
// Cancellation of ctx is observed both while waiting on m and while waiting on a
// transformed monad. In either case a Result carrying Error("context done") is
// sent and the stream ends. Results of a transformed monad that are already
// available are forwarded before cancellation is considered. Sends block until
// received, so consumers should keep receiving until the channel is closed.
func Bind[T any, U any](ctx context.Context, m Monad[T], transform func(value T) Monad[U]) Monad[U] {
	c := make(chan Result[U])

	go func() {
		defer close(c)

		// forward copies every Result from inner to c. It returns false if ctx
		// is cancelled while waiting for inner's next Result.
		forward := func(inner <-chan Result[U]) bool {
			for {
				var (
					u  Result[U]
					ok bool
				)
				// Prefer a Result that is already available over cancellation.
				select {
				case u, ok = <-inner:
				default:
					select {
					case u, ok = <-inner:
					case <-ctx.Done():
						return false
					}
				}
				if !ok {
					return true
				}
				c <- u
			}
		}

		for {
			select {
			case r, ok := <-m.value:
				if !ok {
					return
				}
				if r.Err != nil {
					c <- Result[U]{Err: r.Err}
					return
				}
				inner := transform(r.Value).value
				if !forward(inner) {
					// Keep draining the abandoned monad so its producer can finish
					// instead of blocking forever on a send nobody receives.
					go func() {
						for range inner {
						}
					}()
					c <- Result[U]{Err: Error("context done")}
					return
				}
			case <-ctx.Done():
				c <- Result[U]{Err: Error("context done")}
				return
			}
		}
	}()

	return Monad[U]{value: c}
}

// Resolve returns a Monad that has already succeeded with value. Its channel
// holds a single successful Result and is closed.
func Resolve[T any](value T) Monad[T] {
	out := make(chan Result[T], 1)
	defer close(out) // runs before the caller receives the monad
	out <- Result[T]{Value: value}
	return Monad[T]{value: out}
}

// Reject returns a Monad that has already failed with err. Its channel holds a
// single Result whose Err is err, and is closed.
func Reject[T any](err error) Monad[T] {
	failed := Result[T]{Err: err}
	ch := make(chan Result[T], 1)
	ch <- failed
	close(ch)
	return Monad[T]{value: ch}
}
package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"time"
)

// Monad carries an asynchronously produced value on a receive-only channel,
// together with the logger Bind uses to report cancellation.
type Monad[T any] struct {
	value  <-chan T
	logger *log.Logger // copied into monads derived via Bind and Resolve
}

// NewMonad returns a Monad that only carries logger. Its value channel is nil,
// so receiving from it directly blocks forever. In main it serves as the seed
// whose logger Resolve copies into a ready monad.
func NewMonad[T any](logger *log.Logger) Monad[T] {
	var m Monad[T]
	m.logger = logger
	return m
}

// Bind sequences m into transform. A goroutine receives each value from m,
// passes it to transform, and forwards every value of the resulting monad to
// the returned monad, which shares m's logger.
//
// Cancellation of ctx is observed both while waiting on m and while waiting on a
// transformed monad. Values of a transformed monad that are already available
// are forwarded first. On cancellation "context done" is logged and the returned
// monad's channel is closed without a further value, so a consumer doing a single
// receive gets U's zero value.
func Bind[T any, U any](ctx context.Context, m Monad[T], transform func(value T) Monad[U]) Monad[U] {
	c := make(chan U)

	go func() {
		defer close(c)

		// forward copies inner's values to c. It returns false if ctx is
		// cancelled while waiting for inner's next value.
		forward := func(inner <-chan U) bool {
			for {
				var (
					u  U
					ok bool
				)
				// Prefer a value that is already available over cancellation.
				select {
				case u, ok = <-inner:
				default:
					select {
					case u, ok = <-inner:
					case <-ctx.Done():
						return false
					}
				}
				if !ok {
					return true
				}
				c <- u
			}
		}

		for {
			select {
			case v, ok := <-m.value:
				if !ok {
					return
				}
				inner := transform(v).value
				if !forward(inner) {
					m.logger.Println("context done")
					// Drain the abandoned monad so its producer is not stuck on a send.
					go func() {
						for range inner {
						}
					}()
					return
				}
			case <-ctx.Done():
				m.logger.Println("context done")
				return
			}
		}
	}()

	return Monad[U]{value: c, logger: m.logger}
}

// Resolve returns a completed Monad holding value and inheriting m's logger.
// Only m's logger is used; its channel is never read.
func Resolve[T any](m Monad[T], value T) Monad[T] {
	done := make(chan T, 1)
	defer close(done) // runs before the caller receives the monad
	done <- value
	return Monad[T]{logger: m.logger, value: done}
}

// main builds a chain of three increments on 1 under a one-second timeout.
// The last step would wait two seconds. Because the timeout fires first, it logs
// "context done in sleep" and yields 69 instead, which is expected to be printed.
func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	logger := log.New(os.Stdout, "Async Monad: ", log.Ltime)

	// m1 doubles as the seed whose logger every Resolve below copies.
	m1 := NewMonad[int](logger)
	m1 = Resolve(m1, 1)
	m2 := Bind(ctx, m1, func(value int) Monad[int] {
		return Resolve(m1, value+1)
	})
	m3 := Bind(ctx, m2, func(value int) Monad[int] {
		return Resolve(m1, value+1)
	})
	m4After2Seconds := Bind(ctx, m3, func(value int) Monad[int] {
		select {
		case <-time.After(2 * time.Second):
		case <-ctx.Done():
			m1.logger.Println("context done in sleep")
			return Resolve(m1, 69)
		}

		return Resolve(m1, value+1)
	})

	v := <-m4After2Seconds.value
	// fmt.Println writes to stdout next to the logger output (the builtin println
	// goes to stderr). It is also the use of the imported "fmt" package, which Go
	// otherwise rejects as "imported and not used".
	fmt.Println(v)
}
package main

import (
	"context"
	"fmt"
	"log"
	"os"
	"sync"
	"time"
)

// Monad pairs a receive-only channel of results with a logger that is carried
// over when new monads are derived by Bind and Resolve.
type Monad[T any] struct {
	value  <-chan T    // closed by its producer once all values are sent
	logger *log.Logger // copied into derived monads
}

// NewMonad returns a Monad holding only logger. Its value channel is nil, so
// receiving from it directly would block forever. Resolve builds a ready monad
// from it.
func NewMonad[T any](logger *log.Logger) Monad[T] {
	m := Monad[T]{}
	m.logger = logger
	return m
}

// Bind applies transform to every value of m concurrently (one goroutine per
// value) and merges all values of the resulting monads into the returned monad,
// in no particular order. The returned monad shares m's logger. Its channel is
// closed once m is exhausted and every per-value goroutine has finished.
//
// ctx bounds the operation; previously it was accepted but never used. Once ctx
// is cancelled, Bind stops receiving from m, the per-value goroutines stop
// waiting on their transformed monads and on the consumer, and the channel is
// closed without further values. Goroutines therefore do not block forever when
// the consumer stops receiving. A transform call already in progress is not
// interrupted.
func Bind[T any, U any](ctx context.Context, m Monad[T], transform func(value T) Monad[U]) Monad[U] {
	c := make(chan U)

	go func() {
		var wg sync.WaitGroup
		// Deferred calls run last-in first-out: wait for every sender, then close.
		defer close(c)
		defer wg.Wait()

		for {
			var (
				v  T
				ok bool
			)
			select {
			case v, ok = <-m.value:
			case <-ctx.Done():
				return
			}
			if !ok {
				return
			}

			wg.Add(1)
			go func(v T) {
				defer wg.Done()
				inner := transform(v).value
				for {
					select {
					case u, more := <-inner:
						if !more {
							return
						}
						select {
						case c <- u:
						case <-ctx.Done():
							return
						}
					case <-ctx.Done():
						return
					}
				}
			}(v)
		}
	}()

	return Monad[U]{value: c, logger: m.logger}
}

// Resolve returns a completed Monad that yields values in order and then
// closes. m only supplies the logger that the new monad inherits.
func Resolve[T any](m Monad[T], values ...T) Monad[T] {
	buffered := make(chan T, len(values))
	for i := range values {
		buffered <- values[i]
	}
	close(buffered)
	return Monad[T]{logger: m.logger, value: buffered}
}

// main fans each of 1, 2, 3 out into three successors via Bind and prints every
// resulting value under a one-second timeout.
func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	logger := log.New(os.Stdout, "Async Monad: ", log.Ltime)

	m1 := NewMonad[int](logger)
	m1 = Resolve(m1, 1, 2, 3)
	m2 := Bind(ctx, m1, func(value int) Monad[int] {
		return Resolve(m1, value+1, value+2, value+3)
	})

	// Barring the timeout, nine values are printed. Their order varies because
	// Bind transforms each input value in its own goroutine. fmt.Println is also
	// the use of the imported "fmt" package, which Go otherwise rejects as
	// "imported and not used".
	for v := range m2.value {
		fmt.Println(v)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment