Last active
July 21, 2024 13:25
-
-
Save brendanzab/3b56f900248ed70ce9be6f9c4021c548 to your computer and use it in GitHub Desktop.
Attempts at encoding state monads using mutable references in OCaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** Indexed monads, binding-operator notation for them, and a state monad
    indexed by a region parameter. *)
module IndexedMonad = struct
  (** A monad carrying an extra type index ['i]. [bind] keeps the index
      fixed, so only computations with the same index can be sequenced. *)
  module type S = sig
    type ('i, 'a) t

    (** [pure x] lifts the value [x] into a computation at any index. *)
    val pure : 'a -> ('i, 'a) t

    (** [bind t f] sequences [t] with the computation [f] builds from the
        result of [t]. *)
    val bind : ('i, 'a) t -> ('a -> ('i, 'b) t) -> ('i, 'b) t
  end

  (** Operators to make working with indexed monads more pleasant. *)
  module type Notation = sig
    type ('i, 'a) t

    (** Monadic binding: [let* x = t in k]. *)
    val ( let* ) : ('i, 'a) t -> ('a -> ('i, 'b) t) -> ('i, 'b) t

    (** Runs both computations and pairs their results. *)
    val ( and* ) : ('i, 'a) t -> ('i, 'b) t -> ('i, 'a * 'b) t

    (** Mapping: [let+ x = t in e] applies a pure function to the result. *)
    val ( let+ ) : ('i, 'a) t -> ('a -> 'b) -> ('i, 'b) t

    (** Same as [( and* )], for use together with [( let+ )]. *)
    val ( and+ ) : ('i, 'a) t -> ('i, 'b) t -> ('i, 'a * 'b) t
  end

  (** Derives the [Notation] operators from any implementation of [S]. *)
  module Notation (M : S) : Notation with type ('i, 'a) t = ('i, 'a) M.t =
  struct
    type ('i, 'a) t = ('i, 'a) M.t

    let ( let* ) = M.bind

    (* Built from [bind]: [t] is bound first, then [n]. *)
    let ( and* ) t n =
      let* x = t in
      let* y = n in
      M.pure (x, y)

    let ( let+ ) t f = M.bind t (fun x -> M.pure (f x))
    let ( and+ ) = ( and* )
  end

  (** A monad indexed by a region parameter ['r], allowing for a more
      efficient implementation of mutable state: references are plain
      [Stdlib.ref] cells tagged with the region that created them. *)
  module State : sig
    (* [S] is referenced directly: the name [IndexedMonad] is not bound
       until the enclosing [struct] is closed. *)
    include S

    (** A mutable reference, tied to the region ['r] that created it.

        NOTE(review): this is a [private] abbreviation, so callers can
        coerce [(r :> _ Stdlib.ref)] and then read or mutate the cell
        outside the monad, including after [run] returns. An abstract type
        would enforce the region discipline, at the cost of breaking any
        caller that relies on the coercion. *)
    type ('r, 'a) ref = private 'a Stdlib.ref

    (** [ref x] creates a mutable reference in the current region,
        initialised to [x]. *)
    val ref : 'a -> ('r, ('r, 'a) ref) t

    (** [read r] returns the current contents of [r]. *)
    val read : ('r, 'a) ref -> ('r, 'a) t

    (** [write x r] replaces the contents of [r] with [x]. *)
    val write : 'a -> ('r, 'a) ref -> ('r, unit) t

    (** A computation that is polymorphic in its region ['r]. The [unit ->]
        thunk makes the field's value a syntactic function, so ['r] can be
        generalised despite the value restriction. *)
    type 'a region = { region : 'r. unit -> ('r, 'a) t }

    (** [run r] runs the computation [r] and returns its result. Because
        ['r] is bound inside the record field, the result type ['a] cannot
        mention it, so region-tagged references cannot be returned without
        a coercion. *)
    val run : 'a region -> 'a
  end = struct
    (* A computation is a thunk; applying it to [()] performs its effects. *)
    type ('r, 'a) t = unit -> 'a

    (* Force [t] first, then feed its result to [f] and force what [f]
       returns. *)
    let bind t f = fun () -> f (t ()) ()

    let pure x = fun _ -> x

    type ('r, 'a) ref = 'a Stdlib.ref

    let ref x = fun () -> Stdlib.ref x
    let read x = fun () -> !x
    let write x rx = fun () -> rx := x

    type 'a region = { region : 'r. unit -> ('r, 'a) t }

    (* The first [()] unwraps the region thunk, the second runs the
       computation. *)
    let run { region } = region () ()
  end

  (** Usage example: allocate a reference holding [1], overwrite it with
      [3], then read it back, so [test] is [3]. *)
  module Example = struct
    open Notation (State)

    let test =
      State.run
        {
          State.region =
            (fun () ->
              let* x = State.ref 1 in
              let* () = x |> State.write 3 in
              State.read x);
        }
  end
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.