@polytypic
Last active January 9, 2024 12:02
A hack to implement efficient TLS (thread-local-storage)

Currently OCaml 5 provides a Domain.DLS module for domain-local storage.

Unfortunately,

  1. there is no corresponding Thread.TLS for (sys)thread-local storage, and

  2. the current implementation of Domain.DLS is not thread-safe.

I won’t spend time motivating this in depth, but in many of the use cases of Domain.DLS what you actually want is a Thread.TLS. In other words, many uses of Domain.DLS are probably “wrong”: they implicitly assume that there is only one thread on the domain, but whether a domain runs multiple threads is typically decided at a higher level in the application, so making that assumption is typically not safe.
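
To make the point concrete, here is a minimal sketch of two threads on the same domain observing the same Domain.DLS slot. The names counter_key and write_from_other_thread are made up for illustration, and the snippet assumes it is built with the threads library (for example the findlib package threads.posix).

```ocaml
(* Needs the systhreads library, e.g. the findlib package threads.posix. *)

(* A domain-local key; the name is hypothetical. *)
let counter_key = Domain.DLS.new_key (fun () -> ref 0)

(* Store [v] into the DLS slot from a freshly created thread.
   The new thread runs on the same domain as its creator. *)
let write_from_other_thread v =
  let t = Thread.create (fun () -> Domain.DLS.get counter_key := v) () in
  Thread.join t

let () =
  write_from_other_thread 1;
  (* The main thread reads the slot the other thread wrote to. *)
  Printf.printf "main thread sees: %d\n" !(Domain.DLS.get counter_key)
```

Because both threads live on one domain, the main thread should see the value written by the other thread. Code that treats the slot as private per-thread state would be surprised by this.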

Domain.DLS is not thread-safe

I mentioned that the current implementation of Domain.DLS is not thread-safe. I mean that literally: unrelated concurrent Domain.DLS accesses from threads on the same domain can actually break the DLS. That is because the state updates performed by Domain.DLS contain safe-points at which the OCaml runtime may switch between (sys)threads.

Consider the implementation of Domain.DLS.get:

let get (idx, init) =
  let st = maybe_grow idx in
  let v = st.(idx) in
  if v == unique_value then
    let v' = Obj.repr (init ()) in
    st.(idx) <- (Sys.opaque_identity v');
    Obj.magic v'
  else Obj.magic v

If two (or more) threads on a single domain call get concurrently before the key has been initialized, then init may be called twice (or more) and the threads may get different values, which could, e.g., be pointers to two different mutable objects.
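
The bad schedule can be replayed by hand without any real threads. The following is only a model, not the Domain.DLS code: slot stands in for one DLS entry, None plays the role of unique_value, and all names are hypothetical.

```ocaml
(* Hand-interleaved model of the check-then-initialize race in [get]. *)
let slot : int ref option ref = ref None
let init_calls = ref 0

(* Stand-in for a key's [init]: counts calls, returns a fresh mutable cell. *)
let init () =
  incr init_calls;
  ref 0

(* The two halves of [get], separated so they can be interleaved by hand. *)
let read_slot () = !slot

let initialize_slot () =
  let v = init () in
  slot := Some v;
  v

(* Replay the schedule and return what each "thread" ends up holding. *)
let replay () =
  let seen_by_a = read_slot () in (* thread A: the slot is empty *)
  let seen_by_b = read_slot () in (* switch to thread B: still empty *)
  assert (seen_by_a = None && seen_by_b = None);
  let a = initialize_slot () in   (* A runs init and stores its value *)
  let b = initialize_slot () in   (* B runs init again and overwrites it *)
  (a, b)

let () =
  let a, b = replay () in
  Printf.printf "init calls: %d, same object: %b\n" !init_calls (a == b)
```

In this model init runs twice and the two callers hold physically different cells, which is exactly the situation described above.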

Consider the implementation of maybe_grow:

let maybe_grow idx =
  let st = get_dls_state () in
  let sz = Array.length st in
  if idx < sz then st
  else begin
    let rec compute_new_size s =
      if idx < s then s else compute_new_size (2 * s)
    in
    let new_sz = compute_new_size sz in
    let new_st = Array.make new_sz unique_value in
    Array.blit st 0 new_st 0 sz;
    set_dls_state new_st;
    new_st
  end

Imagine calling get (which calls maybe_grow) with two different keys from two different threads concurrently. The end result might be that two different arrays are allocated and only one of them “wins”. What this means, for example, is that effects of set calls may effectively be undone by concurrent calls of get.
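
The lost update can be modelled the same way, again without real threads. This is a sketch under simplifying assumptions: state stands in for the per-domain array, the integer unset plays the role of unique_value, and grown_copy mimics only the copy-and-double step of maybe_grow. All names are hypothetical.

```ocaml
(* Hand-interleaved model of two concurrent [maybe_grow] calls. *)
let unset = 0
let state = ref [| unset |]

(* Copy [old] into a fresh array large enough to hold index [idx]. *)
let grown_copy old idx =
  let rec size s = if idx < s then s else size (2 * s) in
  let fresh = Array.make (size (Array.length old)) unset in
  Array.blit old 0 fresh 0 (Array.length old);
  fresh

(* Replay the schedule and return what key 1 holds afterwards. *)
let replay () =
  state := [| unset |];
  let old_a = !state in             (* A: reads the current array *)
  let old_b = !state in             (* B: reads the same array *)
  let new_a = grown_copy old_a 1 in
  state := new_a;                   (* A: publishes its grown array *)
  new_a.(1) <- 42;                  (* A: sets key 1 *)
  let new_b = grown_copy old_b 2 in (* B: copies the stale array *)
  state := new_b;                   (* B: publishes, replacing A's array *)
  !state.(1)

let () = Printf.printf "key 1 after the race: %d\n" (replay ())
```

In this model the value stored under key 1 is gone after B publishes its array, because B copied the array it read before A's update.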

In other words, the Domain.DLS, as it is currently implemented, is not thread-safe.

How to implement an efficient Thread.TLS?

If you dig into the implementation of threads, you will notice that the opaque Thread.t type is actually a heap block (record) of three fields. You can see the Thread.t accessors:

#define Ident(v) Field(v, 0)
#define Start_closure(v) Field(v, 1)
#define Terminated(v) Field(v, 2)

and the Thread.t allocation:

static value caml_thread_new_descriptor(value clos)
{
  CAMLparam1(clos);
  CAMLlocal1(mu);
  value descr;
  /* Create and initialize the termination semaphore */
  mu = caml_threadstatus_new();
  /* Create a descriptor for the new thread */
  descr = caml_alloc_3(0, Val_long(atomic_fetch_add(&thread_next_id, +1)),
                       clos, mu);
  CAMLreturn(descr);
}

The second field, Start_closure, is used to pass the closure to the thread start:

static void * caml_thread_start(void * v)
{
  caml_thread_t th = (caml_thread_t) v;
  int dom_id = th->domain_id;
  value clos;
  void * signal_stack;

  caml_init_domain_self(dom_id);

  st_tls_set(caml_thread_key, th);

  thread_lock_acquire(dom_id);
  restore_runtime_state(th);
  signal_stack = caml_init_signal_stack();

  clos = Start_closure(Active_thread->descr);
  caml_modify(&(Start_closure(Active_thread->descr)), Val_unit);
  caml_callback_exn(clos, Val_unit);
  caml_thread_stop();
  caml_free_signal_stack(signal_stack);
  return 0;
}

and, as seen above, it is overwritten with the unit value before the closure is called.

What this means is that when you call Thread.self () and get a reference to the current Thread.t, the Start_closure field of that heap block will be the unit value:

assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())

Let’s hijack that field for the purpose of implementing an efficient TLS!

Here is the full hack:

module TLS : sig
  type 'a key
  val new_key : (unit -> 'a) -> 'a key
  val get : 'a key -> 'a
  val set : 'a key -> 'a -> unit
end = struct
  type 'a key = { index : int; compute : unit -> 'a }

  let counter = Atomic.make 0
  let unique () = Obj.repr counter

  let new_key compute =
    let index = Atomic.fetch_and_add counter 1 in
    { index; compute }

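  (* Overlays the first two fields of the runtime's Thread.t block:
     Ident and the hijacked Start_closure. *)
  type t = { _id : int; mutable tls : Obj.t }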

  let ceil_pow_2_minus_1 n =
    let n = n lor (n lsr 1) in
    let n = n lor (n lsr 2) in
    let n = n lor (n lsr 4) in
    let n = n lor (n lsr 8) in
    let n = n lor (n lsr 16) in
    if Sys.int_size > 32 then n lor (n lsr 32) else n

  let[@inline never] grow_tls t before index =
    let new_length = ceil_pow_2_minus_1 (index + 1) in
    let after = Array.make new_length (unique ()) in
    Array.blit before 0 after 0 (Array.length before);
    t.tls <- Obj.repr after;
    after

  let[@inline] get_tls index =
    let t = Obj.magic (Thread.self ()) in
    let tls = t.tls in
    if Obj.is_int tls then grow_tls t [||] index
    else
      let tls = (Obj.magic tls : Obj.t array) in
      if index < Array.length tls then tls else grow_tls t tls index

  let get key =
    let tls = get_tls key.index in
    let value = Array.unsafe_get tls key.index in
    if value != unique () then Obj.magic value
    else
      let value = key.compute () in
      Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
      value

  let set key value =
    let tls = get_tls key.index in
    Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value))
end

The above achieves about 80% of the performance of Domain.DLS allowing roughly 241M TLS.gets/s (vs 296M Domain.DLS.gets/s) on my laptop.
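
The growth policy used by grow_tls can be checked in isolation. The sketch below copies ceil_pow_2_minus_1 from the hack; new_length is a hypothetical helper that only names the expression grow_tls uses to size the new array.

```ocaml
(* Copied from the TLS module above: smears the highest set bit of [n]
   into all lower bits, giving the smallest 2^k - 1 that is >= n. *)
let ceil_pow_2_minus_1 n =
  let n = n lor (n lsr 1) in
  let n = n lor (n lsr 2) in
  let n = n lor (n lsr 4) in
  let n = n lor (n lsr 8) in
  let n = n lor (n lsr 16) in
  if Sys.int_size > 32 then n lor (n lsr 32) else n

(* Hypothetical helper: the array length chosen when [index] does not fit. *)
let new_length index = ceil_pow_2_minus_1 (index + 1)

let () =
  List.iter
    (fun index ->
      Printf.printf "index %d -> length %d\n" index (new_length index))
    [ 0; 1; 2; 3; 4; 7; 8; 100 ]
```

The chosen length is always of the form 2^k - 1 and strictly greater than the requested index, so the array roughly doubles each time it has to grow.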
