Created
November 11, 2014 13:07
-
-
Save gsg/936d74c3fbc1f75ad623 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(*
 * GADT friendly hash tables.
 *)
(** Input signature of {!Make}: a family of keys and values that share a
    type index, plus the hashing and typed-equality operations the table
    needs. *)
module type GadtHashEq = sig
  (** Keys, indexed by a type parameter that also indexes the value stored
      under the key. *)
  type 'k key

  (** Values, indexed by the same parameter as their key. *)
  type 'v value

  (** [hash k] is the hash code of [k]. [Make] reduces it to a bucket index
      with [land], and a lookup only inspects the bucket selected by the
      hash of the key being looked up, so keys that [eq_assoc] treats as
      equal must hash to the same value. *)
  val hash : 'k key -> int

  (** [eq_assoc wanted stored v] is called by [Make] with the key being
      looked up, then a stored key and the value stored with it. A result
      of [Some v'] is taken to mean that the two keys are the same, with
      [v'] being the stored value seen at the type index of [wanted];
      [None] is taken to mean that the keys differ. *)
  val eq_assoc : 'a key -> 'b key -> 'b value -> 'a value option
end
(** [Make (O)] builds a mutable hash table whose bindings may each have a
    different type index: a key of type ['a O.key] is bound to a value of
    type ['a O.value]. *)
module Make (O : GadtHashEq) : sig
  (** A mutable table of heterogeneously indexed bindings. *)
  type t

  (** [create ()] is a fresh, empty table. *)
  val create : unit -> t

  (** [clear tbl] removes every binding and shrinks the bucket array back
      to a single bucket. *)
  val clear : t -> unit

  (** [length tbl] is the number of bindings in [tbl], counting bindings
      that are hidden by a later [add] of the same key. *)
  val length : t -> int

  (** [add tbl k v] adds a binding of [k] to [v]. Any earlier binding of
      [k] is kept but hidden: [find] reports the most recently added
      one. *)
  val add : t -> 'a O.key -> 'a O.value -> unit

  (** [alter tbl k f init] replaces the value [v] of the visible binding of
      [k] with [f v]. If [k] is unbound, it is bound to [init] instead
      ([f] is not applied) and [length tbl] grows by one. *)
  val alter : t -> 'a O.key -> ('a O.value -> 'a O.value) -> 'a O.value -> unit

  (** [mem tbl k] is [true] iff [k] is bound in [tbl]. *)
  val mem : t -> 'a O.key -> bool

  (** [find tbl k] is [Some v] where [v] is the visible value bound to [k],
      or [None] if [k] is unbound. *)
  val find : t -> 'a O.key -> 'a O.value option

  (** Like [find], but returns the value directly.
      @raise Not_found if the key is unbound. *)
  val find_exn : t -> 'a O.key -> 'a O.value

  (** [find_default tbl k d] is the visible value bound to [k], or [d] if
      [k] is unbound. *)
  val find_default : t -> 'a O.key -> 'a O.value -> 'a O.value

  (** Callback for [iter]; it is wrapped in a record so that it can be
      polymorphic in the type index of each binding. *)
  type iter_arg = { f : 'a . 'a O.key -> 'a O.value -> unit }

  (** [iter visitor tbl] applies [visitor.f] to every binding of [tbl],
      hidden ones included. The order is unspecified. *)
  val iter : iter_arg -> t -> unit

  (** Callback for [fold]; polymorphic in the type index of each binding,
      with an accumulator of type ['a]. *)
  type 'a fold_arg = { f : 'b . 'b O.key -> 'b O.value -> 'a -> 'a }

  (** [fold folder tbl init] threads an accumulator, starting at [init],
      through [folder.f] for every binding of [tbl], hidden ones included.
      The order is unspecified. *)
  val fold : 'a fold_arg -> t -> 'a -> 'a

  (** Bucket statistics: the length of every chain, and the longest one. *)
  type info = {
    max_chain_length : int;
    chain_lengths : int array;
  }

  (** [info tbl] measures the current bucket chains of [tbl]. *)
  val info : t -> info
end = struct
  (* A bucket chain. Each cell hides the type index shared by its key and
     value; the most recently added binding sits at the front. *)
  type chain =
    | Nil : chain
    | Cons : 'a O.key * 'a O.value * chain -> chain

  (* Invariant: [Array.length data] is a power of two and
     [mask = Array.length data - 1], so [hash land mask] is a valid,
     non-negative bucket index even for negative hashes. *)
  type t = {
    mutable len : int;
    mutable mask : int;
    mutable data : chain array;
  }

  let create () = { len = 0; mask = 0; data = [|Nil|] }

  let clear tbl =
    tbl.len <- 0;
    tbl.mask <- 0;
    tbl.data <- [|Nil|]

  let length tbl = tbl.len

  (* [rev_chain acc c] is [c] reversed and prepended to [acc]
     (tail-recursive). *)
  let rec rev_chain acc = function
    | Nil -> acc
    | Cons (k, v, rest) -> rev_chain (Cons (k, v, acc)) rest

  (* Move every binding into a fresh bucket array of [newlen] buckets
     ([newlen] must be a power of two). Each old chain is walked from its
     oldest cell to its newest, so that consing onto the new bucket leaves
     the newest binding in front; walking it front to back would reverse
     the chain and let an old binding shadow a newer one. *)
  let resize tbl newlen =
    let newdata = Array.make newlen Nil in
    let newmask = newlen - 1 in
    let rec rehash_oldest_first = function
      | Nil -> ()
      | Cons (k, v, rest) ->
        let i = O.hash k land newmask in
        newdata.(i) <- Cons (k, v, newdata.(i));
        rehash_oldest_first rest in
    Array.iter
      (fun chain -> rehash_oldest_first (rev_chain Nil chain))
      tbl.data;
    tbl.data <- newdata;
    tbl.mask <- newmask

  let is_pow2 n = n land (n - 1) = 0

  let add tbl k v =
    let len = succ tbl.len in
    (* Double the bucket array each time the binding count reaches a power
       of two; this is done before inserting so the new cell is hashed
       once, with the final mask. *)
    if is_pow2 len then resize tbl (Array.length tbl.data * 2);
    let i = O.hash k land tbl.mask in
    tbl.data.(i) <- Cons (k, v, tbl.data.(i));
    tbl.len <- len

  let find : type a . t -> a O.key -> a O.value option =
    fun tbl key ->
      let rec loop = function
        | Nil -> None
        | Cons (k, v, rest) ->
          (match O.eq_assoc key k v with
           | None -> loop rest
           | Some _ as hit -> hit) in
      loop tbl.data.(O.hash key land tbl.mask)

  let find_exn tbl key =
    match find tbl key with
    | Some x -> x
    | None -> raise Not_found

  let find_default tbl key default =
    match find tbl key with
    | Some x -> x
    | None -> default

  let mem tbl key =
    match find tbl key with
    | Some _ -> true
    | None -> false

  let alter : type a . t -> a O.key -> (a O.value -> a O.value) -> a O.value -> unit =
    fun tbl key f init ->
      let found = ref false in
      (* Rebuild the chain with the first matching cell updated; cells
         after the match are shared with the old chain. *)
      let rec loop = function
        | Nil -> Nil
        | Cons (k, v, rest) ->
          (match O.eq_assoc key k v with
           | None -> Cons (k, v, loop rest)
           | Some v' -> found := true; Cons (key, f v', rest)) in
      let i = O.hash key land tbl.mask in
      let updated = loop tbl.data.(i) in
      if !found then tbl.data.(i) <- updated
      (* Absent key: go through [add] so that the binding count and the
         resize policy stay in step with the stored bindings. *)
      else add tbl key init

  type iter_arg = { f : 'a . 'a O.key -> 'a O.value -> unit }

  let iter : iter_arg -> t -> unit =
    fun visitor tbl ->
      let rec iter_chain = function
        | Nil -> ()
        | Cons (k, v, rest) -> visitor.f k v; iter_chain rest in
      Array.iter iter_chain tbl.data

  type 'a fold_arg = { f : 'b . 'b O.key -> 'b O.value -> 'a -> 'a }

  let fold : 'a fold_arg -> t -> 'a -> 'a =
    fun folder tbl init ->
      let rec fold_chain acc = function
        | Nil -> acc
        | Cons (k, v, rest) -> fold_chain (folder.f k v acc) rest in
      Array.fold_left fold_chain init tbl.data

  type info = {
    max_chain_length : int;
    chain_lengths : int array;
  }

  let info tbl =
    let rec chainlen len = function
      | Nil -> len
      | Cons (_, _, rest) -> chainlen (len + 1) rest in
    let lengths = Array.map (chainlen 0) tbl.data in
    {
      max_chain_length = Array.fold_left max 0 lengths;
      chain_lengths = lengths;
    }
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment