Skip to content

Instantly share code, notes, and snippets.

@seriyps
Created October 1, 2020 23:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seriyps/637bf11e34bf4fc7535de6e46eda4f33 to your computer and use it in GitHub Desktop.
Save seriyps/637bf11e34bf4fc7535de6e46eda4f33 to your computer and use it in GitHub Desktop.
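%% Version of epgsql:equery/3 that uses named prepared statements and caches them in the
%% process dictionary of the process running the command callbacks (the epgsql connection
%% process), so every connection keeps its own statement cache.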
%%
%% Algorithm pseudocode is:
%% <pre>
%% stmt = cache_get(name)
%% if not stmt:
%%     stmt = parse_and_describe(name, sql)  # network roundtrip
%%     cache_put(name, stmt)
%% return bind_and_execute(stmt, params)  # network roundtrip
%% </pre>
%%
%% https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
%%
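%% # 1st roundtrip (skipped when the statement is already cached)
%% > Parse(name)
%% < ParseComplete
%% > Describe(name)
%% > Flush
%% < ParameterDescription
%% < RowDescription | NoData
%%
%% # 2nd roundtrip
%% > Bind(unnamed portal, name)
%% < BindComplete
%% > Execute
%% < {DataRow*
%% < CommandComplete} | EmptyQuery
%% > Sync
%% < ReadyForQuery
%%
%% No Close is sent. The unnamed portal is released by the server on its own
%% (at the end of the transaction, or when the next unnamed portal is bound).
%% The named statement is deliberately kept alive so the cached entry stays valid.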
-module(epgsql_cmd_cached_equery).
-behaviour(epgsql_command).
-export([init/1, execute/2, handle_message/4]).
-export([run/3, run/4]).
-export_type([response/0]).
%% Result returned by run/3 and run/4:
%% - {ok, Count, Cols, Rows}: the completion tag carried a row count and the
%%   statement has result columns.
%% - {ok, Count}: the completion tag carried a row count and the statement has
%%   no result columns.
%% - {ok, Cols, Rows}: the completion tag carried no row count; an empty query
%%   string yields {ok, [], []}.
%% - {error, Reason}: the server replied with an ErrorResponse.
-type response() :: {ok, Count :: non_neg_integer(), Cols :: [epgsql:column()], Rows :: [tuple()]}
| {ok, Count :: non_neg_integer()}
| {ok, Cols :: [epgsql:column()], Rows :: [tuple()]}
| {error, epgsql:query_error()}.
-include("epgsql.hrl").
-include("protocol.hrl").
%% Per-command state threaded through init/1, execute/2 and handle_message/4.
-record(cquery,
{
%% Data from client (init/1):
%% name: prepared statement name; also the key of the statement cache.
name :: iodata(),
%% sql: query text; only sent (in Parse) when the statement is not cached.
sql :: iodata(),
%% params: bind parameters; zipped with the statement's types in execute/2.
params :: [any()],
%% Data either from cache or from `parse'
%% stmt: undefined until taken from the cache (execute/2) or built from
%% ParameterDescription and completed by RowDescription / NoData.
stmt :: #statement{} | undefined,
%% decoder: row decoder built from the statement columns; used for DataRow.
decoder :: undefined | epgsql_wire:row_decoder()
}).
%% @doc Runs `SQL' with `Params' through a cached named prepared statement.
%% The statement name is derived from a hash of `SQL'.
%%
%% The statement cache is keyed by name only. If two different queries hashed
%% to the same name, the second one would silently bind and execute the first
%% one's cached statement. `erlang:phash2/1' only spans 2^27 values, so the
%% full 32-bit range of `erlang:phash2/2' is used to make such collisions far
%% less likely. They are still possible; use run/4 with an explicit, unique
%% name where that matters.
-spec run(epgsql:connection(), epgsql:sql_query(), [epgsql:bind_param()]) -> response().
run(C, SQL, Params) ->
    Name = integer_to_binary(erlang:phash2(SQL, 1 bsl 32)),
    run(C, Name, SQL, Params).
%% @doc Runs `SQL' with `Params' through the named prepared statement `Name'.
%% The statement is parsed and described first only when `Name' is not yet in
%% the statement cache. The command is submitted with
%% epgsql_sock:sync_command/3, and its response() is returned.
-spec run(epgsql:connection(), Name :: iodata(), epgsql:sql_query(), [epgsql:bind_param()]) ->
          response().
run(C, Name, SQL, Params) ->
    Command = {Name, SQL, Params},
    epgsql_sock:sync_command(C, ?MODULE, Command).
%% @doc epgsql_command callback. Builds the initial `#cquery{}' state from the
%% `{Name, SQL, Params}' tuple submitted by run/4. `stmt' and `decoder' are
%% left unset here; execute/2 and handle_message/4 fill them in later.
init({Name, SQL, Params}) ->
    #cquery{params = Params,
            sql = SQL,
            name = Name}.
%% @doc epgsql_command callback. Sends the protocol messages for the current stage.
%%
%% First call (`stmt = undefined'):
%% - Cache hit: the cached `#statement{}' is used, the notifications the
%%   Parse/Describe stage would have produced are replayed, and
%%   Bind/Execute/Sync is sent right away.
%% - Cache miss: Parse + Describe + Flush is sent. handle_message/4 builds the
%%   statement from the replies and requeues the command, which brings us to
%%   the second clause.
%%
%% Second call (`stmt' built from the server replies): the statement is
%% cached, then Bind/Execute/Sync is sent.
execute(Sock, #cquery{stmt = undefined, name = Name, sql = Sql} = St) ->
    case cache_get(Name) of
        not_found ->
            %% see epgsql_cmd_parse
            Codec = epgsql_sock:get_codec(Sock),
            ColumnEncoding = epgsql_wire:encode_types([], Codec),
            epgsql_sock:send_multi(
              Sock,
              [
               {?PARSE, [Name, 0, Sql, 0, ColumnEncoding]},
               {?DESCRIBE, [?PREPARED_STATEMENT, Name, 0]},
               %% Flush rather than Sync: it makes the server deliver the
               %% Parse/Describe replies now, while the only Sync (and hence
               %% the only ReadyForQuery) belongs to the second stage.
               {?FLUSH, []}
              ]),
            {ok, Sock, St};
        #statement{types = TypeNames, columns = Columns} = Stmt ->
            %% Mirror the notifications emitted by handle_message/4 for
            %% ParameterDescription and RowDescription / NoData.
            Sock2 = epgsql_sock:notify(Sock, {types, TypeNames}),
            Sock3 = case Columns of
                        [] -> epgsql_sock:notify(Sock2, no_data);
                        _ -> epgsql_sock:notify(Sock2, {columns, Columns})
                    end,
            %% Already cached under `Name', so skip straight to Bind/Execute.
            bind_execute(Sock3, St#cquery{stmt = Stmt})
    end;
execute(Sock, #cquery{stmt = #statement{name = StatementName} = Stmt,
                      name = InStatementName} = St) ->
    string:equal(StatementName, InStatementName) orelse
        error({wrong_statement, InStatementName, StatementName}),
    cache_put(StatementName, Stmt),
    bind_execute(Sock, St).

%% Sends Bind (unnamed portal) + Execute + Sync for the statement in
%% `St#cquery.stmt'; see epgsql_cmd_prepared_query.
bind_execute(Sock, #cquery{stmt = #statement{name = StatementName,
                                             columns = Columns,
                                             types = Types},
                           params = Params} = St) ->
    TypedParams = lists:zip(Types, Params),
    Codec = epgsql_sock:get_codec(Sock),
    Bin1 = epgsql_wire:encode_parameters(TypedParams, Codec),
    Bin2 = epgsql_wire:encode_formats(Columns),
    epgsql_sock:send_multi(
      Sock,
      [
       {?BIND, ["", 0, StatementName, 0, Bin1, Bin2]},
       %% Max rows 0 = fetch all rows.
       {?EXECUTE, ["", 0, <<0:?int32>>]},
       {?SYNC, []}
      ]),
    {ok, Sock, St}.
%% handle_message/4 is the epgsql_command callback that receives each backend
%% message for this command.
%% - Stage 1 (Parse/Describe/Flush, cache miss only) builds the `#statement{}'
%%   and requeues the command so execute/2 sends stage 2.
%% - Stage 2 (Bind/Execute/Sync) collects rows and the result, and finishes on
%%   ReadyForQuery.

%% Parse stage
handle_message(?PARSE_COMPLETE, <<>>, Sock, _State) ->
    {noaction, Sock};
handle_message(?PARAMETER_DESCRIPTION, Bin, Sock, #cquery{name = Name, stmt = undefined} = St) ->
    Codec = epgsql_sock:get_codec(Sock),
    TypeInfos = epgsql_wire:decode_parameters(Bin, Codec),
    OidInfos = [epgsql_binary:typeinfo_to_oid_info(Type, Codec) || Type <- TypeInfos],
    TypeNames = [epgsql_binary:typeinfo_to_name_array(Type, Codec) || Type <- TypeInfos],
    Sock2 = epgsql_sock:notify(Sock, {types, TypeNames}),
    Stmt = #statement{name = Name,
                      parameter_info = OidInfos,
                      types = TypeNames},
    {noaction, Sock2, St#cquery{stmt = Stmt}};
%% RowDescription / NoData is the last reply to Describe. The statement is now
%% complete, so the command is requeued and execute/2 runs again for stage 2.
handle_message(?ROW_DESCRIPTION, <<Count:?int16, Bin/binary>>, Sock, #cquery{stmt = Stmt0} = St) ->
    Codec = epgsql_sock:get_codec(Sock),
    Columns = epgsql_wire:decode_columns(Count, Bin, Codec),
    Columns2 = [Col#column{format = epgsql_wire:format(Col, Codec)} || Col <- Columns],
    Decoder = epgsql_wire:build_decoder(Columns2, Codec),
    Sock2 = epgsql_sock:notify(Sock, {columns, Columns2}),
    Stmt = Stmt0#statement{columns = Columns2},
    {requeue, Sock2, St#cquery{decoder = Decoder,
                               stmt = Stmt}};
handle_message(?NO_DATA, <<>>, Sock, #cquery{stmt = Stmt} = St) ->
    Sock2 = epgsql_sock:notify(Sock, no_data),
    {requeue, Sock2, St#cquery{stmt = Stmt#statement{columns = []}}};
%% Bind + Execute stage
handle_message(?BIND_COMPLETE, <<>>, Sock, #cquery{stmt = #statement{columns = Columns}} = St) ->
    Codec = epgsql_sock:get_codec(Sock),
    Decoder = epgsql_wire:build_decoder(Columns, Codec),
    {noaction, Sock, St#cquery{decoder = Decoder}};
handle_message(?DATA_ROW, <<_Count:?int16, Bin/binary>>,
               Sock, #cquery{decoder = Decoder} = St) ->
    Row = epgsql_wire:decode_data(Bin, Decoder),
    {add_row, Row, Sock, St};
handle_message(?COMMAND_COMPLETE, Bin, Sock,
               #cquery{stmt = #statement{columns = Cols}} = St) ->
    Complete = epgsql_wire:decode_complete(Bin),
    Rows = epgsql_sock:get_rows(Sock),
    Result = case Complete of
                 {_, Count} when Cols == [] ->
                     {ok, Count};
                 {_, Count} ->
                     {ok, Count, Cols, Rows};
                 _ ->
                     {ok, Cols, Rows}
             end,
    {add_result, Result, {complete, Complete}, Sock, St};
handle_message(?EMPTY_QUERY, <<>>, Sock, St) ->
    {add_result, {ok, [], []}, {complete, empty}, Sock, St};
%% Reply to our Sync: the command is done; return the single stored result
%% (a success tuple or an error).
handle_message(?READY_FOR_QUERY, _Status, Sock, _State) ->
    case epgsql_sock:get_results(Sock) of
        [Result] ->
            {finish, Result, done, Sock};
        [] ->
            {finish, done, done, Sock}
    end;
%% Errors need stage-specific handling.
%%
%% Bind/Execute stage (the statement's columns are known, i.e. a list): Sync
%% has already been sent, so the server still answers with ReadyForQuery after
%% the ErrorResponse. Store the error as the result and let the READY_FOR_QUERY
%% clause finish the command. Finishing here would leave that ReadyForQuery to
%% be consumed by whichever command is queued next.
handle_message(?ERROR, Error, Sock, #cquery{stmt = #statement{columns = Cols}} = St)
  when is_list(Cols) ->
    Result = {error, Error},
    {add_result, Result, Result, Sock, St};
%% Parse/Describe stage: only Flush was sent. The server now discards every
%% message until it sees a Sync and sends no ReadyForQuery. Report the error
%% and flag the connection as needing a sync (as epgsql_cmd_parse does).
handle_message(?ERROR, Error, _Sock, _St) ->
    {sync_required, {error, Error}};
handle_message(_, _, _, _) ->
    unknown.
%%
%% Internal
%%
%% Returns the `#statement{}' stored under `Name' by cache_put/2 in this
%% process' dictionary. Returns `not_found' when there is no cache map yet or
%% the name is absent.
cache_get(Name) ->
    case get(stmt_cache) of
        Cache when is_map(Cache) -> maps:get(Name, Cache, not_found);
        _NoCache -> not_found
    end.
%% Stores `Stmt' under `Name' in this process' statement cache, creating the
%% cache map on first use. Returns whatever put/2 returns: the previous cache
%% map, or `undefined' if there was none.
cache_put(Name, Stmt) ->
    Existing = case get(stmt_cache) of
                   undefined -> #{};
                   Map -> Map
               end,
    put(stmt_cache, maps:put(Name, Stmt, Existing)).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment