Skip to content

Instantly share code, notes, and snippets.

@vog
Created September 23, 2019 16:34
Show Gist options
  • Save vog/c0ee4cff2aeff4840b4e6c947f7a7ec2 to your computer and use it in GitHub Desktop.
Save vog/c0ee4cff2aeff4840b4e6c947f7a7ec2 to your computer and use it in GitHub Desktop.
; Native executable for the test module, i.e. the reproduction program
; run by the build script as ./_build/default/test.exe.  lwt_ppx provides
; the match%lwt / let%lwt syntax used in that module.
(executable
(name test)
(libraries logs logs.lwt lwt lwt.unix sqlite3 uri)
(preprocess (pps lwt_ppx))
(modes native))
opam-version: "2.0"
synopsis: "Dummy"
homepage: "Dummy"
bug-reports: "Dummy"
maintainer: "Dummy"
authors: "Dummy"
# Direct dependencies of the test program.  lwt is listed explicitly
# because the code uses Lwt, Lwt_io, Lwt_mvar and Lwt_preemptive itself
# instead of relying on it arriving transitively through lwt_ppx.
depends: [
"dune"
"logs"
"lwt"
"lwt_ppx"
"sqlite3"
"uri"
]
#!/bin/sh
# Build the reproduction program in a local opam switch and run it.
#
# Abort on the first failing step, so that a failed switch creation,
# install or build does not go on to run a stale or missing binary.
set -e

# A local switch lives in ./_opam; create it only on the first run so the
# script can be re-run without failing on an existing switch.
if [ ! -d _opam ]; then
  opam switch create -wy --solver=mccs . 4.07.0
fi
opam install -wy --solver=mccs --set-root .
# Use the development version of lwt.
opam pin -y add lwt --dev-repo
opam exec -- dune build --root . test.exe
./_build/default/test.exe
module Caqti_prereq = struct
  (* Identity function. *)
  let ident x = x

  (* Left-to-right composition: [(f %> g) x = g (f x)]. *)
  let ( %> ) f g = fun x -> g (f x)

  module List = struct
    include List

    (* [fold f xs acc] threads [acc] through [f] for each element of [xs]
       from left to right, e.g. [fold f [a; b] acc = f b (f a acc)].  Each
       element first builds the stage [f x]; the stages are then composed. *)
    let rec fold f = function
      | [] -> ident
      | x :: rest -> f x %> fold f rest
  end

  (* Logs source named caqti, shared by the modules in this file. *)
  let default_log_src = Logs.Src.create "caqti"
end
(* Row multiplicity of a request.  Only the single-row case is present in
   this reduced copy. *)
module Caqti_mult = struct
(* The parameter ['m] is a phantom tag: no constructor carries it, it only
   records the multiplicity at the type level. *)
type +'m t = (* not GADT due to variance *)
| One
(* Multiplicity used by requests built with Caqti_request.find. *)
let one : [> `One] t = One
end
module Caqti_driver_info = struct
  (* SQL dialects known to this reduced copy. *)
  type dialect_tag = [`Sqlite]

  (* How query parameters are written in the SQL text sent to a driver. *)
  type parameter_style =
    [ `None
    | `Linear of string
    | `Indexed of (int -> string) ]

  (* Static description of a driver's capabilities. *)
  type t = {
    index: int;
    uri_scheme: string;
    dialect_tag: dialect_tag;
    parameter_style: parameter_style;
    describe_has_typed_params: bool;
    describe_has_typed_fields: bool;
    can_transact: bool;
    can_pool: bool;
    can_concur: bool;
  }

  (* Index that the next call to [create] will hand out. *)
  let next_backend_index = ref 0

  (* Build a driver description.  Every call receives a fresh [index],
     counting up from 0 in creation order. *)
  let create
      ~uri_scheme ~dialect_tag ~parameter_style
      ~can_pool ~can_concur ~can_transact
      ~describe_has_typed_params ~describe_has_typed_fields () =
    let index = !next_backend_index in
    next_backend_index := index + 1;
    { index; uri_scheme; dialect_tag; parameter_style;
      describe_has_typed_params; describe_has_typed_fields;
      can_transact; can_pool; can_concur }
end
(* Typed descriptions of request parameters and result rows.  This
   reduced copy only supports unit. *)
module Caqti_type = struct
(* GADT indexed by the OCaml type the description stands for. *)
type _ t =
| Unit : unit t
(* Named descriptions, also included at the top of the module. *)
module Std = struct
let unit = Unit
end
include Std
end
module Caqti_request = struct
  open Printf

  (* Parsed query template: literal SQL fragments [L], parameter
     references [P] (0-based index) and sequences [S]. *)
  type query =
    | L of string
    | P of int
    | S of query list

  (* A request.  [id] is [None] for one-shot requests and otherwise a
     unique id under which drivers may cache the prepared statement.
     [query] produces the template for a given driver. *)
  type ('a, 'b, +'m) t = {
    id: int option;
    query: Caqti_driver_info.t -> query;
    param_type: 'a Caqti_type.t;
    row_type: 'b Caqti_type.t;
    row_mult: 'm Caqti_mult.t;
  } constraint 'm = [< `Zero | `One | `Many]

  (* Last id handed out to a cacheable request; -1 means none yet. *)
  let last_id = ref (-1)

  (* Build a request.  Unless [oneshot] is true, a fresh id is allocated
     so the statement can be cached. *)
  let create ?(oneshot = false) param_type row_type row_mult query =
    let id = if oneshot then None else (incr last_id; Some !last_id) in
    {id; query; param_type; row_type; row_mult}

  let param_type request = request.param_type
  let row_type request = request.row_type
  let query_id request = request.id
  let query request = request.query

  (* Convenience *)

  let invalid_arg_f fmt = ksprintf invalid_arg fmt

  (* Parse the query string [qs] into a template.
     - ? placeholders become [P 0], [P 1], ... in order of appearance;
     - $n placeholders (1-based in the text) become [P (n - 1)];
     - mixing the two styles raises Invalid_argument;
     - $$ is a literal dollar sign;
     - $(name), $. and $name. are expanded by [env];
     - single-quoted SQL literals are copied through unparsed.
     Raises Invalid_argument on malformed input. *)
  let format_query ~env qs =
    let n = String.length qs in
    (* Index just past the quote closing a literal whose body starts at
       [j]; a doubled quote inside the literal is an escaped quote. *)
    let rec skip_quoted j =
      if j = n then
        invalid_arg_f "Caqti_request.create_p: Unmatched quote in %S" qs
      else if qs.[j] = '\'' then
        if j + 1 < n && qs.[j + 1] = '\'' then
          skip_quoted (j + 2)
        else
          j + 1
      else
        skip_quoted (j + 1) in
    (* Read decimal digits from [i] onto [p]; return the index of the
       first non-digit together with the value read. *)
    let rec scan_int i p =
      if i = n then (i, p) else
      (match qs.[i] with
       | '0'..'9' as ch ->
          scan_int (i + 1) (p * 10 + Char.code ch - Char.code '0')
       | _ -> (i, p)) in
    (* Index just past the parenthesis closing an already opened one,
       skipping nested parentheses. *)
    let rec skip_end_paren j =
      if j = n then invalid_arg_f "Unbalanced end-parenthesis in %S" qs else
      if qs.[j] = '(' then skip_end_paren (skip_end_paren (j + 1)) else
      if qs.[j] = ')' then j + 1 else
      skip_end_paren (j + 1) in
    (* Identifiers are ASCII letters, digits and underscores, optionally
       followed by one terminating dot. *)
    let check_idr s =
      let l = String.length s in
      for i = 0 to (if l > 1 && s.[l - 1] = '.' then l - 2 else l - 1) do
        (match s.[i] with
         | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' -> ()
         | _ -> invalid_arg_f "Invalid character %C in identifier %S." s.[i] s)
      done in
    (* [p]: next ?-parameter index, or -1 once a $n parameter was seen;
       [i]: start of the pending literal; [j]: scan position;
       [acc]: fragments found so far, in reverse order. *)
    let rec loop p i j acc =
      if j = n then L (String.sub qs i (j - i)) :: acc else
      (match qs.[j] with
       | '\'' ->
          let k = skip_quoted (j + 1) in
          loop p i k acc
       | '?' ->
          if p < 0 then invalid_arg "Mixed ? and $i style parameters." else
          let acc = L (String.sub qs i (j - i)) :: acc in
          loop (p + 1) (j + 1) (j + 1) (P p :: acc)
       | '$' ->
          if j + 1 = n then invalid_arg "$ at end of query" else
          let acc = L (String.sub qs i (j - i)) :: acc in
          (match qs.[j + 1] with
           | '$' ->
              let acc = L"$" :: acc in
              loop p (j + 2) (j + 2) acc
           | '0'..'9' ->
              if p > 0 then invalid_arg "Mixed ? and $i style parameters." else
              let k, p' = scan_int (j + 1) 0 in
              let acc = P (p' - 1) :: acc in
              loop (-1) k k acc
           | '(' ->
              let k = skip_end_paren (j + 2) in
              let acc = env (String.sub qs (j + 2) (k - j - 3)) :: acc in
              loop p k k acc
           | '.' ->
              let acc = env "." :: acc in
              loop p (j + 2) (j + 2) acc
           | 'a'..'z' ->
              (* Look for the terminating dot after this $, not from the
                 start of the query.  An earlier dot (as in 1.5 or t.col)
                 would give a negative length to String.sub. *)
              (match String.index_from qs (j + 1) '.' with
               | exception Not_found -> invalid_arg "Unterminated '$'."
               | k ->
                  let idr = String.sub qs (j + 1) (k - j) in
                  check_idr idr;
                  let acc = env idr :: acc in
                  loop p (k + 1) (k + 1) acc)
           | _ ->
              invalid_arg "Unescaped $ in query string.")
       | _ ->
          loop p i (j + 1) acc) in
    (match loop 0 0 0 [] with
     | [] -> invalid_arg "Caqti_request.create_p: Empty query string."
     | [frag] -> frag
     | rev_frags -> S (List.rev rev_frags))

  (* Environment that expands nothing. *)
  let no_env _ _ = raise Not_found

  (* Drop empty literals and empty sequences from a template. *)
  let rec simplify = function
    | L "" -> S []
    | S frags -> S (frags |> List.map simplify |> List.filter ((<>) (S [])))
    | L _ | P _ as frag -> frag

  (* Create a request from a query string, expanding $-references with
     [env].  A key ending in a dot that [env] does not know falls back to
     the key without the dot, followed by a literal dot. *)
  let create_p ?(env = no_env) ?oneshot param_type row_type row_mult qs =
    create ?oneshot param_type row_type row_mult
      (fun di ->
        let env k =
          (match simplify (env di k) with
           | exception Not_found ->
              let l = String.length k in
              if l = 0 || k.[l - 1] <> '.' then
                invalid_arg_f "No expansion provided for $(%s) \
                               as needed by query %S." k (qs di) else
              let k' = String.sub k 0 (l - 1) in
              (match simplify (env di k') with
               | exception Not_found ->
                  invalid_arg_f "No expansion provided for $(%s) or $(%s) \
                                 as needed by query %S." k k' (qs di)
               | S[] as v -> v
               | v -> S[v; L"."])
           | v -> v)
        in
        format_query ~env (qs di))

  (* Request for exactly one row, with a fixed query string. *)
  let find ?env ?oneshot pt rt qs =
    create_p ?env ?oneshot pt rt Caqti_mult.one (fun _ -> qs)
end
module Caqti_driver_lib = struct
  open Caqti_prereq

  (* One more than the highest parameter index used in [templ]; 0 if the
     template has no parameters. *)
  let nonlinear_param_length templ =
    let rec max_index frag acc =
      match frag with
      | Caqti_request.L _ -> acc
      | Caqti_request.P n -> max (n + 1) acc
      | Caqti_request.S frags -> List.fold max_index frags acc
    in
    max_index templ 0

  (* For each parameter index i, the 0-based positions of the ?
     placeholders that bind it in the linearised query, most recent
     position first. *)
  let linear_param_order templ =
    let slots = Array.make (nonlinear_param_length templ) [] in
    let rec assign frag pos =
      match frag with
      | Caqti_request.L _ -> pos
      | Caqti_request.P i ->
        slots.(i) <- pos :: slots.(i);
        pos + 1
      | Caqti_request.S frags -> List.fold assign frags pos
    in
    ignore (assign templ 0 : int);
    Array.to_list slots

  (* Render [templ] as SQL text with every parameter written as ?. *)
  let linear_query_string templ =
    let out = Buffer.create 64 in
    let rec emit = function
      | Caqti_request.L s -> Buffer.add_string out s
      | Caqti_request.P _ -> Buffer.add_char out '?'
      | Caqti_request.S frags -> List.iter emit frags
    in
    emit templ;
    Buffer.contents out
end
(* Error values produced by the connection and driver layers, grouped by
   the stage (load, connect, call, retrieve) at which they occur. *)
module Caqti_error = struct
(* Extensible message payload; drivers add their own constructors (the
   sqlite3 driver in this file adds Rc). *)
type msg = ..
(* Printers for message constructors, keyed by extension constructor and
   filled in by define_msg.  NOTE(review): nothing in this reduced copy
   reads this table. *)
let msg_pp = Hashtbl.create 7
let define_msg ~pp ec = Hashtbl.add msg_pp ec pp
(* Plain-text message. *)
type msg += Msg : string -> msg
(* Register the printer for Msg, which prints the string as-is. *)
let () =
let pp ppf = function
| Msg s -> Format.pp_print_string ppf s
| _ -> assert false in
define_msg ~pp [%extension_constructor Msg]
(* Error payloads.  Every payload carries the database URI; query errors
   also carry the query string. *)
type load_error = {
uri : Uri.t;
msg : msg;
}
type connection_error = {
uri : Uri.t;
msg : msg;
}
type query_error = {
uri : Uri.t;
query : string;
msg : msg;
}
type coding_error = {
uri : Uri.t;
msg : msg;
}
(* Constructors for the individual error variants. *)
(* Load *)
let load_rejected ~uri msg = `Load_rejected ({uri; msg} : load_error)
let load_failed ~uri msg = `Load_failed ({uri; msg} : load_error)
(* Connect *)
let connect_rejected ~uri msg =
`Connect_rejected ({uri; msg} : connection_error)
let connect_failed ~uri msg =
`Connect_failed ({uri; msg} : connection_error)
(* Call *)
let request_failed ~uri ~query msg =
`Request_failed ({uri; query; msg} : query_error)
(* Retrieve *)
let response_failed ~uri ~query msg =
`Response_failed ({uri; query; msg} : query_error)
let response_rejected ~uri ~query msg =
`Response_rejected ({uri; query; msg} : query_error)
(* Common *)
(* Errors a request call may produce. *)
type call =
[ `Encode_rejected of coding_error
| `Encode_failed of coding_error
| `Request_rejected of query_error
| `Request_failed of query_error
| `Response_rejected of query_error ]
(* Errors produced while retrieving rows. *)
type retrieve =
[ `Decode_rejected of coding_error
| `Response_failed of query_error
| `Response_rejected of query_error ]
type call_or_retrieve = [call | retrieve]
(* Errors when locating or loading a driver. *)
type load =
[ `Load_rejected of load_error
| `Load_failed of load_error ]
(* Errors while connecting; Post_connect wraps a call or retrieve error
   raised after the connection was established. *)
type connect =
[ `Connect_rejected of connection_error
| `Connect_failed of connection_error
| `Post_connect of call_or_retrieve ]
type _load_or_connect = [load | connect]
(* Every error in this module. *)
type t = [load | connect | call | retrieve]
(* URI carried by any error, looking through Post_connect.  The explicit
   polymorphic annotation allows the recursive call at a different
   instance of the row type. *)
let rec _uri : 'a. ([< t] as 'a) -> Uri.t = function
| `Load_rejected ({uri; _} : load_error) -> uri
| `Load_failed ({uri; _} : load_error) -> uri
| `Connect_rejected ({uri; _} : connection_error) -> uri
| `Connect_failed ({uri; _} : connection_error) -> uri
| `Post_connect err -> _uri err
| `Encode_rejected ({uri; _} : coding_error) -> uri
| `Encode_failed ({uri; _} : coding_error) -> uri
| `Request_rejected ({uri; _} : query_error) -> uri
| `Request_failed ({uri; _} : query_error) -> uri
| `Decode_rejected ({uri; _} : coding_error) -> uri
| `Response_failed ({uri; _} : query_error) -> uri
| `Response_rejected ({uri; _} : query_error) -> uri
end
(* Signature of the response of an executed request. *)
module Caqti_response_sig = struct
module type S = sig
type +'b future
(* A response whose rows have type ['b]; ['m] is the multiplicity tag. *)
type ('b, +'m) t
(* Retrieve the single row of a response whose multiplicity is One. *)
val find :
('b, [< `One]) t -> ('b, [> Caqti_error.retrieve]) result future
end
end
(* Signatures of a database connection. *)
module Caqti_connection_sig = struct
(* What a driver's connection must provide. *)
module type Base = sig
type +'a future
module Response : Caqti_response_sig.S with type 'a future := 'a future
(* [call ~f req param] executes [req] with [param] and passes the
   resulting response to [f]; the error type of [f] is shared with the
   call's own errors. *)
val call :
f: (('b, 'm) Response.t -> ('c, 'e) result future) ->
('a, 'b, 'm) Caqti_request.t -> 'a ->
('c, [> Caqti_error.call] as 'e) result future
end
(* Full connection interface exposed to applications. *)
module type S = sig
include Base
(* Execute a single-row request and return that row. *)
val find :
('a, 'b, [< `One]) Caqti_request.t -> 'a ->
('b, [> Caqti_error.call_or_retrieve] as 'e) result future
end
end
(* Signatures of the concurrency backend (System) that drivers run on, and
   of the drivers themselves. *)
module Caqti_driver_sig = struct
(* Backend-independent part of a System: a future monad with join, a
   single-value mailbox, and logging returning futures. *)
module type System_common = sig
type +'a future
val (>>=) : 'a future -> ('a -> 'b future) -> 'b future
val (>|=) : 'a future -> ('a -> 'b) -> 'b future
val return : 'a -> 'a future
val join : unit future list -> unit future
module Mvar : sig
type 'a t
val create : unit -> 'a t
val store : 'a -> 'a t -> unit
val fetch : 'a t -> 'a future
end
module Log : sig
type 'a log = ('a, unit future) Logs.msgf -> unit future
val err : ?src: Logs.src -> 'a log
val warn : ?src: Logs.src -> 'a log
val info : ?src: Logs.src -> 'a log
val debug : ?src: Logs.src -> 'a log
end
end
(* System_common plus Unix facilities.  In Caqti_lwt, Preemptive is
   Lwt_preemptive; the sqlite3 driver uses detach to run blocking Sqlite3
   calls. *)
module type System_unix = sig
include System_common
module Unix : sig
type file_descr
val wrap_fd : (file_descr -> 'a future) -> Unix.file_descr -> 'a future
end
module Preemptive : sig
val detach : ('a -> 'b) -> 'a -> 'b future
val run_in_main : (unit -> 'a future) -> 'a
end
end
(* A driver: opens connections for URIs of its scheme. *)
module type S = sig
type +'a future
module type CONNECTION =
Caqti_connection_sig.Base with type 'a future := 'a future
val connect :
Uri.t -> ((module CONNECTION), [> Caqti_error.connect]) result future
end
(* What a Unix driver registers: a functor from a System to a driver. *)
module type Of_system_unix =
functor (System : System_unix) ->
S with type 'a future := 'a System.future
end
module Caqti_connect = struct
  open Printf

  (* Hook for loading a driver library by name at runtime.  The default
     always fails, because no dynamic linker is linked in. *)
  let dynload_library = ref @@ fun lib ->
    Error (sprintf "Neither %s nor the dynamic linker is linked into the \
                    application." lib)

  (* Driver functors registered with [define_unix_driver], keyed by URI
     scheme. *)
  let drivers = Hashtbl.create 11

  let define_unix_driver scheme p = Hashtbl.add drivers scheme p

  (* Find the driver functor for [scheme].  If it is not registered, try
     [dynload_library] once and look again. *)
  let load_driver_functor ~uri scheme =
    (try Ok (Hashtbl.find drivers scheme) with
     | Not_found ->
        (match !dynload_library ("caqti-driver-" ^ scheme) with
         | Ok () ->
            (try Ok (Hashtbl.find drivers scheme) with
             | Not_found ->
                let msg = sprintf "The driver for %s did not register itself \
                                   after apparently loading." scheme in
                Error (Caqti_error.load_failed ~uri (Caqti_error.Msg msg)))
         | Error msg ->
            Error (Caqti_error.load_failed ~uri (Caqti_error.Msg msg))))

  (* Connection front-end for a given System implementation. *)
  module Make_unix (System : Caqti_driver_sig.System_unix) = struct
    open System

    module type DRIVER =
      Caqti_driver_sig.S with type 'a future := 'a System.future

    (* Drivers already instantiated for this System, keyed by scheme. *)
    let drivers : (string, (module DRIVER)) Hashtbl.t = Hashtbl.create 11

    (* Driver for the scheme of [uri], instantiating and caching it on
       first use. *)
    let load_driver uri =
      (match Uri.scheme uri with
       | None ->
          let msg = "Missing URI scheme." in
          Error (Caqti_error.load_rejected ~uri (Caqti_error.Msg msg))
       | Some scheme ->
          (try Ok (Hashtbl.find drivers scheme) with
           | Not_found ->
              (match load_driver_functor ~uri scheme with
               | Ok make_driver ->
                  let module Make_driver =
                    (val make_driver : Caqti_driver_sig.Of_system_unix) in
                  let module Driver = Make_driver (System) in
                  let driver = (module Driver : DRIVER) in
                  Hashtbl.add drivers scheme driver;
                  Ok driver
               | Error _ as r -> r)))

    module type CONNECTION_BASE =
      Caqti_connection_sig.Base with type 'a future := 'a System.future

    module type CONNECTION =
      Caqti_connection_sig.S with type 'a future := 'a System.future

    type _connection = (module CONNECTION)

    (* Extend a driver connection to the full interface, rejecting
       overlapping use of the same connection. *)
    module Connection_of_base (D : DRIVER) (C : CONNECTION_BASE) : CONNECTION =
    struct
      module Response = C.Response

      (* Number of operations in progress on this connection (0 or 1). *)
      let use_count = ref 0

      (* Run [f] while holding the connection.  Raises Failure if another
         operation is still in progress. *)
      let use f =
        if !use_count <> 0 then failwith "Concurrent access to DB connection.";
        incr use_count;
        (match f () with
         | exception exn ->
            (* Release the connection when [f] fails synchronously;
               otherwise the counter stays at 1 and every later call on
               this connection fails.  NOTE(review): a future that fails
               asynchronously still leaves the counter set, since System
               offers no catch/finalize combinator. *)
            decr use_count;
            raise exn
         | fut ->
            fut >|= fun r ->
            decr use_count;
            r)

      let call ~f req param = use (fun () -> C.call ~f req param)

      let find q p = call ~f:Response.find q p
    end

    (* Load the driver for [uri] and open a connection with it. *)
    let connect uri : ((module CONNECTION), _) result future =
      Printf.eprintf "DEBUG-connect\n%!";
      (match load_driver uri with
       | Ok driver ->
          let module Driver = (val driver) in
          Driver.connect uri >|=
          (function
           | Ok connection ->
              let module Connection = (val connection) in
              let module Connection = Connection_of_base (Driver) (Connection) in
              Ok (module Connection : CONNECTION)
           | Error err -> Error err)
       | Error err ->
          return (Error err))
  end
end
module Caqti_lwt = struct
  (* Lwt implementation of the System_unix interface used by drivers. *)
  module System = struct
    type 'a future = 'a Lwt.t

    let ( >>= ) p f = Lwt.bind p f
    let ( >|= ) p f = Lwt.map f p
    let return x = Lwt.return x
    let join promises = Lwt.join promises

    (* Single-value mailbox on Lwt_mvar; [store] starts the put with
       Lwt.async and returns without waiting for it. *)
    module Mvar = struct
      type 'a t = 'a Lwt_mvar.t
      let create () = Lwt_mvar.create_empty ()
      let store x v = Lwt.async (fun () -> Lwt_mvar.put v x)
      let fetch v = Lwt_mvar.take v
    end

    (* Logging through Logs_lwt, defaulting to the caqti log source. *)
    module Log = struct
      type 'a log = 'a Logs_lwt.log
      let err ?(src = Caqti_prereq.default_log_src) msgf =
        Logs_lwt.err ~src msgf
      let warn ?(src = Caqti_prereq.default_log_src) msgf =
        Logs_lwt.warn ~src msgf
      let info ?(src = Caqti_prereq.default_log_src) msgf =
        Logs_lwt.info ~src msgf
      let debug ?(src = Caqti_prereq.default_log_src) msgf =
        Logs_lwt.debug ~src msgf
    end

    module Unix = struct
      type file_descr = Lwt_unix.file_descr
      let wrap_fd f fd =
        let lwt_fd = Lwt_unix.of_unix_file_descr fd in
        f lwt_fd
    end

    module Preemptive = Lwt_preemptive
  end

  include Caqti_connect.Make_unix (System)
end
module Caqti_driver_sqlite3 = struct
  open Caqti_driver_lib
  open Printf

  let driver_info =
    Caqti_driver_info.create
      ~uri_scheme:"sqlite3"
      ~dialect_tag:`Sqlite
      ~parameter_style:(`Linear "?")
      ~can_pool:false
      ~can_concur:false
      ~can_transact:true
      ~describe_has_typed_params:false
      ~describe_has_typed_fields:true
      ()

  (* Boolean URI query parameter [name]: true/yes or false/no; None when
     absent.  Raises Invalid_argument on any other value. *)
  let get_uri_bool uri name =
    (match Uri.get_query_param uri name with
     | Some ("true" | "yes") -> Some true
     | Some ("false" | "no") -> Some false
     | Some _ ->
        ksprintf invalid_arg "Boolean expected for URI parameter %s." name
     | None -> None)

  (* Integer URI query parameter [name]; None when absent.  Raises
     Invalid_argument if the value is not an integer. *)
  let get_uri_int uri name =
    (match Uri.get_query_param uri name with
     | Some s ->
        (try Some (int_of_string s) with
         | Failure _ ->
            ksprintf invalid_arg "Integer expected for URI parameter %s." name)
     | None -> None)

  (* Error message carrying a raw Sqlite3 result code. *)
  type Caqti_error.msg += Rc : Sqlite3.Rc.t -> Caqti_error.msg

  let () =
    let pp ppf = function
     | Rc rc -> Format.pp_print_string ppf (Sqlite3.Rc.to_string rc)
     | _ -> assert false in
    Caqti_error.define_msg ~pp [%extension_constructor Rc]

  (* Bind parameter [x] of type [t]; [os] lists the remaining placeholder
     positions.  Unit binds nothing. *)
  let encode_param
      : type a. uri: Uri.t -> Sqlite3.stmt -> a Caqti_type.t -> a ->
                int list list -> (int list list, _) result =
    fun ~uri:_ _stmt t x ->
    (match t, x with
     | Caqti_type.Unit, () -> fun os -> Ok os)

  (* Decode a row value of type [b] from column [i]; returns the next
     column index with the value.  Unit reads no columns. *)
  let decode_row
      : type b. uri: Uri.t -> Sqlite3.stmt -> int -> b Caqti_type.t ->
                (int * b, _) result =
    fun ~uri:_ _stmt i ->
    (function
     | Caqti_type.Unit -> Ok (i, ()))

  module Q = struct
  end

  module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
    open System

    (* Bind/map over futures of results, short-circuiting on Error. *)
    let (>>=?) m mf = m >>= (function Ok x -> mf x | Error _ as r -> return r)
    let (>|=?) m f = m >|= (function Ok x -> f x | Error _ as r -> r)

    module type CONNECTION =
      Caqti_connection_sig.Base with type 'a future := 'a System.future

    module Connection (Db : sig val uri : Uri.t val db : Sqlite3.db end)
      : CONNECTION =
    struct
      open Db

      module Response = struct
        type ('b, 'm) t = {
          stmt: Sqlite3.stmt;
          row_type: 'b Caqti_type.t;
          query: string;
        }

        (* Step the statement once: Ok None when done, Ok (Some row) for a
           row whose fields were all decoded, Error otherwise. *)
        let fetch_row {stmt; row_type; query} =
          (match Sqlite3.step stmt with
           | Sqlite3.Rc.DONE -> Ok None
           | Sqlite3.Rc.ROW ->
              (match decode_row ~uri stmt 0 row_type with
               | Ok (n, y) ->
                  let n' = Sqlite3.data_count stmt in
                  if n = n' then
                    Ok (Some y)
                  else
                    let msg = sprintf "Decoded only %d of %d fields." n n' in
                    let msg = Caqti_error.Msg msg in
                    Error (Caqti_error.response_rejected ~uri ~query msg)
               | Error _ as r -> r)
           | rc ->
              Error (Caqti_error.response_failed ~uri ~query (Rc rc)))

        (* Fetch exactly one row in a worker thread; zero or several rows
           are rejected. *)
        let find resp =
          let retrieve () =
            (match fetch_row resp with
             | Ok None ->
                let msg = Caqti_error.Msg "Received no rows for find." in
                Error (Caqti_error.response_rejected ~uri ~query:resp.query msg)
             | Ok (Some y) ->
                (match fetch_row resp with
                 | Ok None -> Ok y
                 | Ok (Some _) ->
                    let msg = "Received multiple rows for find." in
                    let msg = Caqti_error.Msg msg in
                    let query = resp.query in
                    Error (Caqti_error.response_rejected ~uri ~query msg)
                 | Error _ as r -> r)
             | Error _ as r -> r) in
          Preemptive.detach retrieve ()
      end

      (* Prepared statements of cacheable requests, keyed by request id. *)
      let pcache = Hashtbl.create 19

      let call ~f req param =
        let param_type = Caqti_request.param_type req in
        let row_type = Caqti_request.row_type req in

        (* Prepare [query], reporting Sqlite3 errors as values.  Queries
           with more than one statement are rejected: previously the
           second statement was silently run instead of the first, and the
           first statement was leaked. *)
        let prepare_helper query =
          try
            let stmt = Sqlite3.prepare db query in
            (match Sqlite3.prepare_tail stmt with
             | None -> Ok stmt
             | Some tail_stmt ->
                ignore (Sqlite3.finalize tail_stmt : Sqlite3.Rc.t);
                ignore (Sqlite3.finalize stmt : Sqlite3.Rc.t);
                let msg =
                  Caqti_error.Msg
                    "Multiple SQL statements in one request are not supported." in
                Error (`Request_rejected
                         ({Caqti_error.uri; query; msg} : Caqti_error.query_error)))
          with Sqlite3.Error msg ->
            let msg = Caqti_error.Msg msg in
            Error (Caqti_error.request_failed ~uri ~query msg) in

        let prepare () =
          let templ = Caqti_request.query req driver_info in
          let query = linear_query_string templ in
          let os = linear_param_order templ in
          Preemptive.detach prepare_helper query >|=? fun stmt ->
          Ok (stmt, os, query) in

        (match Caqti_request.query_id req with
         | None -> prepare ()
         | Some id ->
            (try return (Ok (Hashtbl.find pcache id)) with
             | Not_found ->
                prepare () >|=? fun pcache_entry ->
                Hashtbl.add pcache id pcache_entry;
                Ok pcache_entry))
        >>=? fun (stmt, os, query) ->
        (* CHECKME: Does binding involve IO? *)
        (match encode_param ~uri stmt param_type param os with
         | Ok os ->
            assert (os = []);
            return (Ok Response.{stmt; query; row_type})
         | Error _ as r -> return r)
        >>=? fun resp ->
        (* CHECKME: Does finalize or reset involve IO? *)
        (* One-shot statements are finalized; cached ones are reset for
           reuse, or finalized and evicted if the reset fails. *)
        let cleanup () =
          (match Caqti_request.query_id req with
           | None ->
              (match Sqlite3.finalize stmt with
               | Sqlite3.Rc.OK -> return ()
               | _ ->
                  Log.warn (fun p ->
                    p "Ignoring error when finalizing statement."))
           | Some id ->
              (match Sqlite3.reset stmt with
               | Sqlite3.Rc.OK -> return ()
               | _ ->
                  Log.warn (fun p ->
                    p "Dropping cache statement due to error.") >|= fun () ->
                  ignore (Sqlite3.finalize stmt : Sqlite3.Rc.t);
                  Hashtbl.remove pcache id)) in
        (try f resp >>= fun r -> cleanup () >|= fun () -> r
         with exn -> cleanup () >|= fun () -> raise exn (* should not happen *))
    end

    (* Open the database named by the path of [uri].  Supported query
       parameters: write, create (booleans) and busy_timeout (integer). *)
    let connect uri =
      try
        (* Check URI and extract parameters. *)
        assert (Uri.scheme uri = Some "sqlite3");
        (match Uri.userinfo uri, Uri.host uri with
         | None, (None | Some "") -> ()
         | _ -> invalid_arg "Sqlite URI cannot contain user or host components.");
        let mode =
          (match get_uri_bool uri "write", get_uri_bool uri "create" with
           | Some false, Some true -> invalid_arg "Create mode presumes write."
           | (Some false), (Some false | None) -> Some `READONLY
           | (Some true | None), (Some true | None) -> None
           | (Some true | None), (Some false) -> Some `NO_CREATE) in
        let busy_timeout = get_uri_int uri "busy_timeout" in
        (* db_open runs in a worker thread, so an exception it raises would
           reject the future instead of reaching the handler below.  Catch
           it in the worker and return it as a value, as prepare_helper
           does. *)
        let open_db () =
          try
            Ok (Sqlite3.db_open ~mutex:`FULL ?mode
                  (Uri.path uri |> Uri.pct_decode))
          with Sqlite3.Error msg -> Error msg in
        (* Connect, configure, wrap. *)
        Preemptive.detach open_db () >|= (function
         | Error msg ->
            Error (Caqti_error.connect_failed ~uri (Caqti_error.Msg msg))
         | Ok db ->
            (match busy_timeout with
             | None -> ()
             | Some timeout -> Sqlite3.busy_timeout db timeout);
            let module Arg = struct let uri = uri let db = db end in
            let module Db = Connection (Arg) in
            Ok (module Db : CONNECTION))
      with
      | Invalid_argument msg ->
         return (Error (Caqti_error.connect_rejected ~uri (Caqti_error.Msg msg)))
      | Sqlite3.Error msg ->
         return (Error (Caqti_error.connect_failed ~uri (Caqti_error.Msg msg)))
  end

  let () = Caqti_connect.define_unix_driver "sqlite3" (module Connect_functor)
end
(* Connect to sqlite3:test.db, run select 1 through find, print Test and
   then raise Failure on purpose.  A connection error is ignored and the
   result of find is discarded. *)
let test () =
  let uri = Uri.of_string "sqlite3:test.db" in
  match%lwt Caqti_lwt.connect uri with
  | Error _ -> Lwt.return ()
  | Ok (module Db : Caqti_lwt.CONNECTION) ->
    let request =
      Caqti_request.find Caqti_type.unit Caqti_type.unit "select 1" in
    let%lwt _ = Db.find request () in
    let%lwt () = Lwt_io.printf "Test" in
    failwith "Test"
(* Entry point: enable backtraces, then run the test on the Lwt main loop. *)
let () =
  Printexc.record_backtrace true;
  test () |> Lwt_main.run
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment