Skip to content

Instantly share code, notes, and snippets.

@NicolasT
Last active February 4, 2020 00:26
Show Gist options
  • Save NicolasT/5623368 to your computer and use it in GitHub Desktop.
Save NicolasT/5623368 to your computer and use it in GitHub Desktop.
RWST for state machines in OCaml
true: package(lwt)
true: package(lwt.unix)
<rwst.ml>: camlp4orf, use_monad
open Ocamlbuild_plugin;;
open Command;;
(* [ocamlfind_query pkg] builds the shell command "ocamlfind query <pkg>"
   (with [pkg] shell-quoted by [Filename.quote]) and passes it, together with
   [input_line], to [Ocamlbuild_pack.My_unix.run_and_open].
   NOTE(review): presumably the result is the first line printed by the
   command, i.e. the package's installation directory -- confirm against
   ocamlbuild's [My_unix]. *)
let ocamlfind_query pkg =
  let quoted = Filename.quote pkg in
  let cmd = Printf.sprintf "ocamlfind query %s" quoted in
  Ocamlbuild_pack.My_unix.run_and_open cmd input_line;;
(* Plugin entry point.  On the [After_rules] hook, register a flag for files
   tagged [ocaml], [pp] and [use_monad] that adds [pa_monad.cmo], looked up
   in the directory reported for the "monad-custom" findlib package.  Every
   other hook is ignored. *)
dispatch (fun hook ->
  match hook with
  | After_rules ->
    (* The lookup stays inside this branch so ocamlfind is only run when the
       hook actually fires. *)
    let pa_monad = ocamlfind_query "monad-custom" ^ "/pa_monad.cmo" in
    flag ["ocaml"; "pp"; "use_monad"] (S [A pa_monad])
  | _ -> ())
(* Monoids, for the Writer part *)

(* A carrier type [t] with a distinguished element [mempty] and a binary
   combining operation [mappend].  RWST uses [mempty] as the output of
   actions that emit nothing and [mappend] to join the outputs of two
   sequenced actions.  The signature itself does not enforce any laws. *)
module type MONOID = sig
  type t

  val mempty : t

  val mappend : t -> t -> t
end
(* [ListM (A)] is the list monoid over [A.t]: [mempty] is the empty list and
   [mappend] concatenates its two arguments in order. *)
module ListM (A : sig type t end) : MONOID with type t = A.t list = struct
  type t = A.t list

  let mempty = []

  let mappend xs ys = List.append xs ys
end
(* Some useful signatures *)

(* Minimal monad interface: a type constructor ['a t], [return] to inject a
   value and [bind] to feed the result of one computation into the next. *)
module type MONAD = sig
  type 'a t

  val return : 'a -> 'a t

  val bind : 'a t -> ('a -> 'b t) -> 'b t
end
(* A monad extended with reader, writer and state operations.
   [r] is the environment type, [w] the output type and [s] the state type. *)
module type RWS = sig
  include MONAD

  type r
  type w
  type s

  (* Reader: an action yielding a value of the environment type. *)
  val ask : r t

  (* Writer: an action taking a value of the output type. *)
  val tell : w -> unit t

  (* State: an action yielding a state value, and one taking a state value. *)
  val get : s t
  val put : s -> unit t
end
(* Implementation of RWST, the Reader/Writer/State Monad Transformer *)

(* [RWST (R) (W) (S) (M)] layers three effects over the base monad [M]: an
   environment of type [R.t], an accumulated output of type [W.t] and a
   threaded state of type [S.t].  An action is a function from the
   environment and the incoming state to an [M]-computation that produces the
   result, the outgoing state and the output of that action. *)
module RWST
    (R : sig type t end)
    (W : MONOID)
    (S : sig type t end)
    (M : MONAD) =
(struct
  (* Monad *)
  type 'a t = RWST of (R.t -> S.t -> ('a * S.t * W.t) M.t)

  (* Strip the constructor to get at the underlying function. *)
  let unRWST (RWST run) = run

  (* Yield [a], leave the state untouched, emit the empty output. *)
  let return a = RWST (fun _ s -> M.return (a, s, W.mempty))

  (* Run [m], pass its result to [k], run that with the same environment and
     the state left by [m]; the two outputs are joined with [W.mappend],
     first action's output on the left. *)
  let bind m k =
    RWST (fun env s0 ->
      M.bind (unRWST m env s0) (fun (a, s1, w1) ->
        M.bind (unRWST (k a) env s1) (fun (b, s2, w2) ->
          M.return (b, s2, W.mappend w1 w2))))

  (* MonadReader *)
  type r = R.t

  let ask = RWST (fun env s -> M.return (env, s, W.mempty))

  (* MonadWriter *)
  type w = W.t

  let tell out = RWST (fun _ s -> M.return ((), s, out))

  (* MonadState *)
  type s = S.t

  let get = RWST (fun _ s -> M.return (s, s, W.mempty))

  (* The incoming state is discarded and replaced by [s']. *)
  let put s' = RWST (fun _ _ -> M.return ((), s', W.mempty))

  (* MonadTrans *)
  (* Run a computation of the base monad; state is passed through unchanged
     and nothing is emitted. *)
  let lift m =
    RWST (fun _ s -> M.bind m (fun a -> M.return (a, s, W.mempty)))

  (* Apply the action to an environment and an initial state. *)
  let runRWST (action : 'a t) (env : R.t) (init : S.t) =
    unRWST action env init
end : sig
  include RWS

  val lift : 'a M.t -> 'a t

  val runRWST : 'a t -> R.t -> S.t -> ('a * S.t * W.t) M.t
end with type r = R.t and type w = W.t and type s = S.t)
(* Lenses! *)

(* A lens is a pair of functions over a structure ['s] and a part ['v]:
   the first extracts the part from the structure, the second takes a
   structure and a new part and returns a structure. *)
type ('s, 'v) lens = ('s -> 'v) * ('s -> 'v -> 's)
(* Lens-based accessors for any monad [M] implementing [RWS]:
   - [view l] applies the getter of [l] to the value produced by [M.ask];
   - [use l] applies the getter of [l] to the value produced by [M.get];
   - [l @= v] reads the state with [M.get], applies the setter of [l] to it
     and [v], and stores the result with [M.put]. *)
module LensUtils (M : RWS) : sig
  val view : (M.r, 'b) lens -> 'b M.t
  val use : (M.s, 'b) lens -> 'b M.t
  val (@=) : (M.s, 'b) lens -> 'b -> unit M.t
end = struct
  (* Local shorthand; hidden by the signature above. *)
  let ( >>= ) = M.bind

  let view (project, _) = M.ask >>= fun env -> M.return (project env)

  let use (project, _) = M.get >>= fun st -> M.return (project st)

  let ( @= ) (_, update) v = M.get >>= fun st -> M.put (update st v)
end
(* Application-specific datastructures *)

(* Configuration record; it is the environment type of the transition monad
   built by [TransitionUtils].  [_configElectionTimeout] is the value carried
   by the [ResetElectionTimeout] command. *)
type config = {
  _configNodeId : string;
  _configNodes : string list;
  _configElectionTimeout : int
}
(* Lenses for config *)

(* Lens onto the [_configNodeId] field: the getter reads the field, the
   setter returns a copy of the record with the field replaced. *)
let configNodeId : (config, string) lens =
  ( (fun cfg -> cfg._configNodeId),
    (fun cfg id -> { cfg with _configNodeId = id }) )
(* Lens onto the [_configNodes] field: the getter reads the field, the
   setter returns a copy of the record with the field replaced. *)
let configNodes : (config, string list) lens =
  ( (fun cfg -> cfg._configNodes),
    (fun cfg nodes -> { cfg with _configNodes = nodes }) )
(* Lens onto the [_configElectionTimeout] field: the getter reads the field,
   the setter returns a copy of the record with the field replaced. *)
let configElectionTimeout : (config, int) lens =
  ( (fun cfg -> cfg._configElectionTimeout),
    (fun cfg timeout -> { cfg with _configElectionTimeout = timeout }) )
(* Messages: a single constructor carrying an integer. *)
type message =
  | Accept of int

(* Render a message as the constructor name followed by its decimal
   payload. *)
let string_of_message (Accept n) = Printf.sprintf "Accept %d" n
(* Commands are the elements of the output list accumulated by a transition.
   [Send] carries one argument, a (string, message) pair. *)
type command =
  | Broadcast of message
  | Send of (string * message)
  | ResetElectionTimeout of int

(* Render a command for display.  The string carried by [Send] is printed
   with [%S], i.e. quoted and escaped as an OCaml string literal. *)
let string_of_command cmd =
  match cmd with
  | Broadcast msg ->
    Printf.sprintf "Broadcast %s" (string_of_message msg)
  | Send (target, msg) ->
    Printf.sprintf "Send (%S, %s)" target (string_of_message msg)
  | ResetElectionTimeout timeout ->
    Printf.sprintf "ResetElectionTimeout %d" timeout
(* Inputs accepted by a handler's [handle] function: either a [message] or
   the [ElectionTimeout] event. *)
type event =
  | Message of message
  | ElectionTimeout
(* State carried while in the [Slave] role: a single integer. *)
type slave_state = { _slaveI : int }

(* Render a slave state in record-literal notation. *)
let string_of_slave_state { _slaveI = n } =
  Printf.sprintf "{ _slaveI = %d }" n
(* State carried while in the [Master] role: a single integer. *)
type master_state = { _masterI : int }

(* Render a master state in record-literal notation. *)
let string_of_master_state { _masterI = n } =
  Printf.sprintf "{ _masterI = %d }" n
(* Overall node state: the current role together with that role's data. *)
type state =
  | Slave of slave_state
  | Master of master_state

(* Render a state as the role name followed by the role-specific record. *)
let string_of_state st =
  match st with
  | Slave data -> Printf.sprintf "Slave %s" (string_of_slave_state data)
  | Master data -> Printf.sprintf "Master %s" (string_of_master_state data)
(* [TransitionUtils (S) (M)] builds the monad in which state transitions are
   written: an RWST stack over [M] whose environment is [config], whose
   output is a [command list] and whose state is [S.t].  On top of the raw
   monad it re-exports the lens accessors from [LensUtils] and adds small
   helpers that emit commands. *)
module TransitionUtils = functor(S: sig type t end) -> functor(M : MONAD) -> (struct
  module Transition =
    RWST (struct type t = config end)
         (ListM (struct type t = command end))
         (struct type t = S.t end)
         (M)
  include Transition

  module LU = LensUtils(Transition)
  include LU

  let (>>=) = bind

  let runTransition = runRWST

  (* Emit a single [Broadcast] command carrying [m]. *)
  let broadcast m = tell [Broadcast m]

  (* Emit a single [Send] command carrying the pair [(n, m)]. *)
  let send n m = tell [Send (n, m)]

  (* We can combine things: fetch something from config, and use it to emit a
   * command *)
  let resetElectionTimeout =
    view configElectionTimeout >>= fun t ->
    tell [ResetElectionTimeout t]

  let currentState = get

  (* [isMajority l] is [true] when strictly more than half of the configured
     nodes ([_configNodes]) occur in [l].  Counting configured nodes rather
     than the elements of [l] means duplicates in [l] and names that are not
     configured cannot inflate the count.  With no configured nodes the
     result is [false].
     NOTE(review): the previous body was a placeholder that ignored both [l]
     and the node list and always returned [true]; strict majority is the
     assumed intent -- confirm. *)
  let isMajority l =
    view configNodes >>= fun nodes ->
    let present = List.filter (fun node -> List.mem node l) nodes in
    return (2 * List.length present > List.length nodes)
end : sig
  include MONAD
  type r
  type w
  type s
  val (>>=) : 'a t -> ('a -> 'b t) -> 'b t
  val view : (r, 'b) lens -> 'b t
  val use : (s, 'b) lens -> 'b t
  val (@=) : (s, 'b) lens -> 'b -> unit t
  val broadcast : message -> unit t
  val send : string -> message -> unit t
  val resetElectionTimeout : unit t
  val currentState : s t
  val isMajority : string list -> bool t
  val lift : 'a M.t -> 'a t
  val runTransition : 'a t -> r -> s -> ('a * s * w) M.t
end with type r = config
     and type w = command list
     and type s = S.t)
(* Interface of a per-role event handler.
   ['a t] is the transition monad, ['a m] the monad that running a
   transition produces, and [s] the role-specific state type. *)
module type HANDLER = sig
  type 'a t
  type 'a m
  type s

  (* The transition to perform for an event; its result is a [state]. *)
  val handle : event -> state t

  (* Run a transition against a configuration and a role state, producing
     the result, a role state and a list of commands. *)
  val runTransition : 'a t -> config -> s -> ('a * s * command list) m
end
(* Handler for the [Slave] role, abstracted over the base monad [M]. *)
module Slave (M : MONAD)
  : HANDLER with type s = slave_state and type 'a m = 'a M.t = struct
  type s = slave_state

  (* Lens onto the [_slaveI] counter of the slave state. *)
  let i =
    ( (fun st -> st._slaveI),
      (fun st v -> { st with _slaveI = v }) )

  module TU = TransitionUtils(struct type t = s end)(M)
  open TU

  type 'a t = 'a TU.t
  type 'a m = 'a M.t

  (* - [ElectionTimeout]: yield [Master] with the stored counter plus one.
     - [Message (Accept n)]: when the stored counter is greater than [n],
       emit a reset of the election timeout, store the counter plus one,
       broadcast [Accept] with the counter as it was before the update, and
       yield [Master] with the freshly stored counter.  Otherwise yield
       [Slave] with counter 0. *)
  let handle evt =
    match evt with
    | ElectionTimeout ->
      use i >>= fun current ->
      return (Master { _masterI = current + 1 })
    | Message (Accept proposed) ->
      use i >>= fun current ->
      if current > proposed then begin
        resetElectionTimeout >>= fun () ->
        (i @= current + 1) >>= fun () ->
        broadcast (Accept current) >>= fun () ->
        (* Re-read the counter so the result reflects the update above. *)
        use i >>= fun updated ->
        return (Master { _masterI = updated })
      end
      else return (Slave { _slaveI = 0 })

  let runTransition = TU.runTransition
end
(* For some odd (well, demonstrational) reason, this one is not abstracted over
 * some monad, but is hard-coded to Lwt
 *)
module Master
  : HANDLER with type s = master_state and type 'a m = 'a Lwt.t = struct
  type s = master_state

  (* Lens onto the [_masterI] counter of the master state. *)
  let i =
    ( (fun st -> st._masterI),
      (fun st v -> { st with _masterI = v }) )

  module TU = TransitionUtils(struct type t = s end)(Lwt)
  open TU

  type 'a t = 'a TU.t
  type 'a m = 'a Lwt.t

  (* - [ElectionTimeout]: yield [Slave] with the stored counter.
     - [Message _]: the message payload is not inspected.  Store the counter
       plus five, emit a [Send] to "node0", lift an Lwt value, emit a reset
       of the election timeout, broadcast [Accept] with the updated counter
       plus the lifted value, and yield [Master] with the current state. *)
  let handle evt =
    match evt with
    | ElectionTimeout ->
      use i >>= fun current ->
      return (Slave { _slaveI = current })
    | Message _ ->
      use i >>= fun before ->
      (i @= before + 5) >>= fun () ->
      send "node0" (Accept 1) >>= fun () ->
      (* Underlying monad is Lwt, so we can lift actions from it *)
      lift (Lwt.return 4) >>= fun j ->
      resetElectionTimeout >>= fun () ->
      use i >>= fun after ->
      broadcast (Accept (after + j)) >>= fun () ->
      currentState >>= fun st ->
      return (Master st)

  let runTransition = TU.runTransition
end
(* Top-level dispatcher: selects the handler matching the current role and
   runs its transition in Lwt. *)
module Handle = struct
  module S = Slave(Lwt)

  let (>>=) = Lwt.bind

  (* Keep the transition's result and its commands; the role-specific state
     in the middle of the triple is dropped. *)
  let select (result, _, commands) = Lwt.return (result, commands)

  (* [handle cfg s evt] runs the handler for the role of [s] on [evt] with
     configuration [cfg].  The two branches cannot share one [>>= select]
     because the intermediate triples carry different state types. *)
  let handle cfg s evt =
    match s with
    | Slave slave ->
      let transition = S.handle evt in
      S.runTransition transition cfg slave >>= select
    | Master master ->
      let transition = Master.handle evt in
      Master.runTransition transition cfg master >>= select
end
;;
(* Demo driver: run one event through the dispatcher and print the resulting
   state followed by every command the transition emitted. *)
let () =
  let cfg =
    { _configNodeId = "node0"
    ; _configNodes = ["node0"; "node1"]
    ; _configElectionTimeout = 10
    }
  in
  let state0 = Slave { _slaveI = 10 } in
  let event0 = Message (Accept 1) in
  let (next_state, commands) =
    Lwt_main.run (Handle.handle cfg state0 event0)
  in
  Printf.printf "New state: %s\n" (string_of_state next_state);
  print_endline "Commands:";
  List.iter
    (fun cmd -> Printf.printf " - %s\n" (string_of_command cmd))
    commands
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment