Skip to content

Instantly share code, notes, and snippets.

@mfp
Last active February 21, 2017 16:28
Show Gist options
  • Save mfp/0359cd3ad6494648ae63d2b49819b27e to your computer and use it in GitHub Desktop.
Save mfp/0359cd3ad6494648ae63d2b49819b27e to your computer and use it in GitHub Desktop.
Turn any iter_s into a bounded-concurrency iter_n with a dynamically adjustable limit
(*
* ocamlfind ocamlc -package lwt,lwt.unix -o region region.ml -linkpkg
* *)
(** Weighted concurrency limiter ("region") for Lwt threads.

    A region has a capacity ([size]) and tracks the total weight of the
    threads currently admitted ([count]).  A caller is admitted whenever
    [count < size]; its weight is then added to [count], so [count] may
    overshoot [size] by less than the weight of the last admitted caller.
    Callers that cannot be admitted wait in FIFO order. *)
module Region : sig
  (** A region.  Mutable: capacity and occupancy change over time. *)
  type t

  (** [make size] creates an empty region with capacity [size]. *)
  val make : int -> t

  (** [resize t size] sets the capacity of [t] to [size].  If this makes
      room, queued callers are admitted immediately (in FIFO order).
      Shrinking never interrupts threads that are already running. *)
  val resize : t -> int -> unit

  (** [enter t sz thr] waits until [t] has room, then runs [thr ()] with
      weight [sz] and returns its result.  The weight is released when the
      promise returned by [thr] terminates, whether it succeeds or fails. *)
  val enter : t -> int -> (unit -> 'a Lwt.t) -> 'a Lwt.t

  (** [enter_p t sz thr] waits until [t] has room, then starts [thr ()] in
      the background with weight [sz].  The returned promise resolves as
      soon as [thr] has been started, not when it finishes.

      If the background thread fails, its exception is delivered to every
      caller still waiting in the queue at that point, and the queue is
      emptied.  NOTE(review): the weight of the failed thread is released
      {e before} waiters are failed, so callers admitted by that release
      proceed normally and do not see the exception — confirm this is the
      intended policy. *)
  val enter_p : t -> int -> (unit -> unit Lwt.t) -> unit Lwt.t

  (** [await t] resolves once every thread that is running inside [t] at
      the time of the call has terminated.  Callers still queued and
      threads started later are not waited for.  The result is built with
      [Lwt.join], so a failure of one of those threads is propagated. *)
  val await : t -> unit Lwt.t
end =
struct
  open Lwt.Infix

  (* Existential wrapper so threads of any result type fit in one sequence. *)
  type th = T : _ Lwt.t -> th

  (* A caller parked until the region has room.  The promise is kept next to
     its resolver so that entries whose promise is no longer pending (e.g.
     canceled by the caller) can be recognised and skipped. *)
  type waiter =
    { promise : unit Lwt.t;
      wakener : unit Lwt.u;
      weight : int;
    }

  type t =
    { mutable size : int;             (* capacity *)
      mutable count : int;            (* total weight currently admitted *)
      waiters : waiter Queue.t;       (* callers waiting for admission, FIFO *)
      ths : th Lwt_sequence.t;        (* threads currently running *)
    }

  let make size =
    { size; count = 0; waiters = Queue.create ();
      ths = Lwt_sequence.create (); }

  (* Admit queued callers, oldest first, for as long as the region has room.
     This applies the same admission test as [acquire] ([count < size]), so
     after it returns either the region is full or nobody is waiting.

     A waiter whose promise is no longer pending is dropped without being
     charged: [Lwt.wakeup_later] has no effect on a canceled promise, so
     charging it would add weight that nothing ever releases. *)
  let rec admit_waiters reg =
    if reg.count < reg.size && not (Queue.is_empty reg.waiters) then begin
      let { promise; wakener; weight } = Queue.take reg.waiters in
      if Lwt.is_sleeping promise then begin
        reg.count <- reg.count + weight;
        (* [wakeup_later] defers the waiter's callbacks, so they cannot
           re-enter the region while this loop is still running. *)
        Lwt.wakeup_later wakener ()
      end;
      admit_waiters reg
    end

  let resize reg sz =
    reg.size <- sz;
    (* Growing the region must hand the new room to queued callers right
       away: if no thread is running, nothing else would ever wake them. *)
    admit_waiters reg

  (* Release weight [sz] and pass the freed room on to queued callers. *)
  let leave_region reg sz =
    reg.count <- reg.count - sz;
    admit_waiters reg

  (* Run [thr] as a member of the region.  The caller must already have
     charged [sz] to [reg.count]; it is released when the thread terminates.
     The thread is registered in [reg.ths] for the duration of its run so
     that [await] can find it. *)
  let run_in_region_1 reg sz thr =
    let th =
      Lwt.finalize
        thr
        (fun () ->
           leave_region reg sz;
           Lwt.return_unit) in
    let node = Lwt_sequence.add_l (T th) reg.ths in
    Lwt.on_termination th (fun () -> Lwt_sequence.remove node);
    th

  (* Obtain admission for weight [sz]: immediately if there is room,
     otherwise by queueing.  For a queued caller the weight is charged by
     [admit_waiters] at the moment it is woken. *)
  let acquire reg sz =
    if reg.count >= reg.size then begin
      let (promise, wakener) = Lwt.task () in
      Queue.add { promise; wakener; weight = sz } reg.waiters;
      promise
    end else begin
      reg.count <- reg.count + sz;
      Lwt.return_unit
    end

  (* Reject every queued caller with [exn] and empty the queue.  The queue
     is moved aside first because [Lwt.wakeup_exn] may run callbacks
     synchronously; a caller that queues up during that notification must
     neither disturb the iteration nor be silently discarded. *)
  let fail_waiters reg exn =
    let pending = Queue.create () in
    Queue.transfer reg.waiters pending;
    Queue.iter
      (fun { wakener; _ } ->
         (* [Invalid_argument] means the promise was already resolved;
            there is nobody left to notify for that entry. *)
         try Lwt.wakeup_exn wakener exn with Invalid_argument _ -> ())
      pending

  let enter reg sz thr =
    acquire reg sz >>= fun () -> run_in_region_1 reg sz thr

  let enter_p reg sz thr =
    let start () =
      (* The thread is deliberately detached: only its failure is observed,
         by forwarding the exception to the callers still queued. *)
      ignore
        (Lwt.catch
           (fun () -> run_in_region_1 reg sz thr)
           (fun exn -> fail_waiters reg exn; Lwt.return_unit)
         : unit Lwt.t);
      Lwt.return_unit
    in
    acquire reg sz >>= start

  let await reg =
    Lwt.join @@
    Lwt_sequence.fold_l
      (fun (T th) ths -> (th >>= fun _ -> Lwt.return_unit) :: ths) reg.ths []
end
(* Reference instant, captured when this binding is evaluated; [puts]
   reports timestamps relative to it. *)
let t0 = Unix.gettimeofday ()

(* [puts fmt arg...] prints one line to stdout: the number of seconds
   elapsed since [t0], followed by the printf-style message, then flushes. *)
let puts fmt =
  let elapsed = Unix.gettimeofday () -. t0 in
  Printf.printf (" [%3.1fs] " ^^ fmt ^^ "\n%!") elapsed
(* Demo: feed the integers 0..9 through a region of capacity 3, one unit of
   weight each.  [Lwt_list.iter_s] submits the jobs one after the other, and
   [Region.enter_p] makes each submission wait until the region has room, so
   the jobs overlap only as far as the region allows.  Every job sleeps for
   a random time below one second and logs the number of jobs in flight on
   entry and on exit. *)
let () =
  let open Lwt.Infix in
  let items = Array.to_list (Array.init 10 (fun i -> i)) in
  (* Number of jobs currently between their ENTER and EXIT lines. *)
  let in_flight = ref 0 in
  let work n =
    incr in_flight;
    puts "-> ENTER %d (count %02d)" n !in_flight;
    Lwt_unix.sleep (Random.float 1.0) >>= fun () ->
    decr in_flight;
    puts "<- EXIT %d (count %02d)" n !in_flight;
    Lwt.return_unit
  in
  let region = Region.make 3 in
  let submit n =
    puts "ITER %02d" n;
    Region.enter_p region 1 (fun () -> work n)
  in
  let demo =
    Lwt_list.iter_s submit items >>= fun () ->
    (* All jobs have been started; wait for the ones still running. *)
    Region.await region >>= fun () ->
    print_endline "DONE";
    Lwt.return_unit
  in
  Lwt_main.run demo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment