This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%% sql_utils: helpers on top of the `pgsql' client and the `pooler' pool.
%% Keeps a registry of named SQL statements, prepares them lazily per
%% connection, provides transaction/dirty execution contexts, and applies the
%% schema update scripts shipped in massmify's priv/sql/updates directory.
-module(sql_utils).
-behaviour(gen_server).

%% Server lifecycle and connections
-export([start_link/1, connect/4]).
%% Statement registry and execution
-export([prepare/1, prepare_core_statements/0, query/3]).
%% Execution contexts
-export([transaction/1, dirty/1]).
%% Schema updates
-export([installed_schema_updates/0, available_schema_updates/0, update_schema/0]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).

-export_type([statement_group/0, statement_declaration/0]).

%% A named group of statements; GroupName becomes the first element of each
%% statement's {GroupName, StatementName} registry key.
-type statement_group() ::
        {GroupName :: atom(), Statements :: [statement_declaration()]}.

%% One statement: its name inside the group, the parameter types passed to
%% pgsql:parse/4, and its SQL text.
-type statement_declaration() ::
        {StatementName :: atom(), ParamTypes :: [atom()], Sql :: iolist()}.
%% @doc Starts the locally registered server that owns the statement ETS
%% tables. The options argument is currently ignored.
start_link(_Opts) ->
    ServerName = {local, ?MODULE},
    gen_server:start_link(ServerName, ?MODULE, [], []).
%% @doc Opens a connection with pgsql:connect/4 and asks this module's server
%% to monitor it, so that the connection's entries in the
%% connections_statements table are removed when its process goes down.
%% Crashes with badmatch if the connect call does not return {ok, Pid}.
connect(Host, Username, Password, Opts) ->
    {ok, Conn} = Connected = pgsql:connect(Host, Username, Password, Opts),
    gen_server:call(?MODULE, {monitor, Conn}),
    Connected.
%% @doc Registers or refreshes statement declarations in the
%% prepared_statements ETS table (created by init/1).
%%
%% Each statement is stored as {{GroupName, StatementName}, Types, SqlBinary,
%% Version}:
%%   - a statement seen for the first time is stored with version 0;
%%   - a statement whose types or SQL changed is stored with its version
%%     incremented, so that ensure_prepared/2 re-parses it on each connection;
%%   - an unchanged statement is left as is.
%% Returns ok.
-spec prepare(StatementGroups :: list(statement_group())) -> ok.
prepare(StatementGroups) ->
    Declared =
        lists:flatmap(
          fun({GroupName, GroupStatements}) ->
                  lists:map(
                    fun({StatementName, Types, Sql}) ->
                            {{GroupName, StatementName}, Types, iolist_to_binary(Sql)}
                    end, GroupStatements)
          end, StatementGroups),
    ToStore = lists:foldl(fun collect_changed/2, [], Declared),
    ets:insert(prepared_statements, ToStore),
    ok.

%% Adds the registry entry for one declaration to Acc when it is new or
%% differs from the stored one. Every branch keeps the entries collected so
%% far.
collect_changed({StatementFQN, Types, Sql}, Acc) ->
    case ets:lookup(prepared_statements, StatementFQN) of
        %% New statement
        [] ->
            [{StatementFQN, Types, Sql, 0} | Acc];
        %% Unchanged (Types and Sql are already bound, so this only matches
        %% an identical declaration)
        [{StatementFQN, Types, Sql, _Version}] ->
            Acc;
        %% Changed statement: bump the version so connections re-parse it
        [{StatementFQN, _OldTypes, _OldSql, Version}] ->
            [{StatementFQN, Types, Sql, Version + 1} | Acc]
    end.
%% @doc Runs the statement registered as {Module, StatementName} with Params
%% on the connection of the enclosing transaction/1 or dirty/1 context.
%% The statement is prepared on that connection first if needed. Returns
%% whatever pgsql:execute/2 returns. Raises error(no_context) outside a
%% context, and badmatch if binding fails.
query(Module, StatementName, Params) ->
    Conn = get_connection(),
    StatementFQN = {Module, StatementName},
    Handle = ensure_prepared(Conn, StatementFQN),
    ok = pgsql:bind(Conn, Handle, Params),
    pgsql:execute(Conn, Handle).
%% Returns a statement handle for StatementFQN that is valid on Conn at the
%% registry's current version:
%%   - a handle cached at the current version is reused;
%%   - a handle cached at an older version is closed, then re-parsed;
%%   - a statement never parsed on Conn is parsed now.
%% Raises error({undef_statement, StatementFQN}) for unknown statements.
ensure_prepared(Conn, StatementFQN) ->
    {StatementFQN, Types, Sql, Version} = registered_statement(StatementFQN),
    case ets:lookup(connections_statements, {Conn, StatementFQN}) of
        [{_, Handle, Version}] ->
            Handle;
        [{_, StaleHandle, _StaleVersion}] ->
            ok = pgsql:close(Conn, StaleHandle),
            parse(Conn, StatementFQN, Sql, Types, Version);
        [] ->
            parse(Conn, StatementFQN, Sql, Types, Version)
    end.

%% Fetches the registry entry {FQN, Types, Sql, Version} for StatementFQN.
registered_statement(StatementFQN) ->
    case ets:lookup(prepared_statements, StatementFQN) of
        [Entry] -> Entry;
        [] -> error({undef_statement, StatementFQN})
    end.
%% Parses Statement on Conn under the server-side name <<"Module_Name">>.
%% Caches the resulting handle in connections_statements, tagged with
%% LatestVersion, and returns the handle. Crashes with badmatch if the parse
%% fails.
parse(Conn, {Module, Name} = StatementFQN, Statement, Types, LatestVersion) ->
    ServerName = <<(atom_to_binary(Module, latin1))/binary, "_",
                   (atom_to_binary(Name, latin1))/binary>>,
    {ok, Handle} = pgsql:parse(Conn, ServerName, Statement, Types),
    CacheEntry = {{Conn, StatementFQN}, Handle, LatestVersion},
    ets:insert(connections_statements, CacheEntry),
    Handle.
%% @doc Returns the connection stored in the process dictionary (sql_ctx) by
%% an enclosing transaction/1 or dirty/1. Raises error(no_context) when
%% called outside such a context.
get_connection() ->
    context_connection(get(sql_ctx)).

context_connection(undefined) -> error(no_context);
context_connection(Conn) -> Conn.
%% @doc Runs Fun inside BEGIN/COMMIT on a connection taken from the `pgsql'
%% pooler pool. While Fun runs, the connection is available through
%% get_connection/0 (and therefore query/3).
%%
%% Returns {atomic, Result} when Fun returns Result and COMMIT succeeds.
%% If BEGIN, Fun or COMMIT raises, it does three things before raising
%% error({aborted, {Class, Reason}}):
%%   - logs an error report with the stacktrace;
%%   - resyncs the connection;
%%   - rolls back the transaction.
%% Also raises:
%%   - error(nested_context) when called inside another transaction/1 or
%%     dirty/1;
%%   - error(no_connection_available) when the pool has no free member.
transaction(Fun) ->
    with_connection(fun(Conn) -> run_transaction(Conn, Fun) end).

%% @doc Runs Fun on a pooled connection without issuing BEGIN/COMMIT and
%% returns Fun's result. Exceptions from Fun propagate unchanged. Nesting and
%% pool-exhaustion errors are the same as for transaction/1.
dirty(Fun) ->
    with_connection(fun(_Conn) -> Fun() end).

%% Takes a pool member and stores it as this process's sql_ctx while
%% Body(Conn) runs. Always returns it to the pool and clears the context
%% afterwards, whether Body returns or raises.
with_connection(Body) ->
    case get(sql_ctx) of
        undefined ->
            Conn = take_connection(),
            put(sql_ctx, Conn),
            try
                Body(Conn)
            after
                pooler:return_member(pgsql, Conn),
                erase(sql_ctx)
            end;
        _Conn ->
            error(nested_context)
    end.

%% pooler:take_member/1 returns the atom error_no_members when the pool is
%% exhausted. Fail here with a clear reason instead of storing that atom as
%% the "connection" and crashing later inside an unrelated pgsql call.
take_connection() ->
    case pooler:take_member(pgsql) of
        error_no_members -> error(no_connection_available);
        Conn -> Conn
    end.

%% BEGIN, Fun(), COMMIT on Conn; on any exception, log it, clean up the
%% connection and raise {aborted, {Class, Reason}}.
run_transaction(Conn, Fun) ->
    try
        begin_transaction(Conn),
        Result = Fun(),
        commit_transaction(Conn),
        {atomic, Result}
    catch
        Type:Error:Stacktrace ->
            %% The stacktrace is bound by the catch pattern, so the cleanup
            %% below (which has its own try/catch) cannot replace it. Logging
            %% first keeps the report even if cleanup itself crashes.
            error_logger:error_report([
                "Aborted transaction",
                {error, {Type, Error}},
                {stack_trace, Stacktrace}
            ]),
            %% Resync before ROLLBACK; presumably this clears a pending
            %% extended-query error state on the connection. Verify against
            %% the pgsql client docs.
            pgsql:sync(Conn),
            rollback_transaction(Conn),
            error({aborted, {Type, Error}})
    end.
%% @doc Lists the *.sql update scripts in massmify's priv/sql/updates
%% directory, as binary file names relative to that directory.
available_schema_updates() ->
    [list_to_binary(Filename)
     || Filename <- filelib:wildcard("*.sql", updates_dir())].

%% @doc Returns the file names recorded in the schema_updates table.
%% update_schema/0 treats these as already installed.
installed_schema_updates() ->
    Rows = dirty(
             fun() ->
                     {ok, Result} = query(?MODULE, installed_schema_updates, []),
                     Result
             end),
    [Script || {Script} <- Rows].

%% @doc Installs, in file-name order, every available update that is not yet
%% recorded as installed. Each script runs in its own transaction.
%% Returns ok.
update_schema() ->
    Installed = installed_schema_updates(),
    Pending = lists:sort(available_schema_updates() -- Installed),
    lists:foreach(fun install_update/1, Pending).

%% Runs one update script (binary file name relative to the updates
%% directory) inside a transaction. Returns ok, or raises if reading or
%% executing the script fails.
%% NOTE(review): nothing here inserts UpdateFile into schema_updates. Presumably
%% each script records itself; verify, otherwise the next update_schema/0 would
%% re-apply it.
install_update(UpdateFile) ->
    lager:info("Installing ~p ~n", [UpdateFile]),
    ScriptPath = filename:join(updates_dir(), binary_to_list(UpdateFile)),
    {ok, Script} = file:read_file(ScriptPath),
    %% The fun must return exactly ok for the {atomic, ok} match to succeed.
    {atomic, ok} = transaction(
                     fun() ->
                             Conn = get_connection(),
                             check_script_result(pgsql:squery(Conn, Script))
                     end),
    ok.

%% Accepts every successful simple-query reply and returns ok. Raises on
%% {error, _}, which aborts the enclosing transaction.
%% squery replies with one tuple per statement: {ok, Count},
%% {ok, Columns, Rows}, {ok, Count, Columns, Rows}, ... For a script with
%% several statements it replies with a list of such tuples.
check_script_result(Results) when is_list(Results) ->
    lists:foreach(fun check_script_result/1, Results);
check_script_result({error, Reason}) ->
    error({update_failed, Reason});
check_script_result(Result) when is_tuple(Result), element(1, Result) =:= ok ->
    ok.

%% Absolute path of massmify's priv/sql/updates directory.
%% code:priv_dir/1 takes the application name as an atom and returns
%% {error, bad_name} when the application is unknown.
updates_dir() ->
    case code:priv_dir(massmify) of
        {error, bad_name} -> error({bad_application, massmify});
        PrivDir -> filename:join([PrivDir, "sql", "updates"])
    end.
%% BEGIN / COMMIT over the simple query protocol. The server's reply must be
%% the empty result {ok, [], []}; anything else crashes with badmatch.
begin_transaction(Conn) -> expect_empty_result(Conn, "BEGIN").

commit_transaction(Conn) -> expect_empty_result(Conn, "COMMIT").

expect_empty_result(Conn, Command) ->
    {ok, [], []} = pgsql:squery(Conn, Command).
%% Best-effort ROLLBACK on Conn, used on transaction/1's abort path.
%% Returns ok when ROLLBACK succeeds. If ROLLBACK returns {error, _} or
%% raises, the failure is logged, the connection is resynced with
%% pgsql:sync/1, and that call's result is returned.
%% Note: a reply matching neither clause raises try_clause, because the `of'
%% clauses are not covered by the catch.
%% TODO: drop the connection from the pool when ROLLBACK fails.
rollback_transaction(Conn) ->
    try pgsql:squery(Conn, "ROLLBACK") of
        {ok, [], []} ->
            ok;
        {error, Err} ->
            %% lager takes its format arguments as a list.
            lager:error("Can't rollback transaction (~p)", [Err]),
            pgsql:sync(Conn)
    catch
        Type:Error ->
            lager:error("Can't rollback transaction (~p, ~p)", [Type, Error]),
            pgsql:sync(Conn)
    end.
%% @doc Registers the statements this module itself needs (currently only the
%% query that lists installed schema updates) under the sql_utils group.
%% Returns ok.
prepare_core_statements() ->
    ListInstalled = {installed_schema_updates, [],
                     "SELECT filename FROM schema_updates"},
    CoreGroup = {sql_utils, [ListInstalled]},
    prepare([CoreGroup]).
%% gen_server callback. Creates two public, named, read-concurrent ETS tables
%% owned by this server, then registers the core statements:
%%   - prepared_statements: the statement registry;
%%   - connections_statements: per-connection handle cache.
init([]) ->
    TableOpts = [named_table, public, {read_concurrency, true}],
    lists:foreach(fun(Table) -> ets:new(Table, TableOpts) end,
                  [prepared_statements, connections_statements]),
    prepare_core_statements(),
    {ok, []}.
%% gen_server callback: nothing to release explicitly. The ETS tables are
%% owned by this process and are deleted when it exits.
terminate(_Reason, _State = []) ->
    ok.
%% gen_server callback: {monitor, Conn} starts monitoring a connection
%% process, so handle_info/2 can purge its cached statement handles when it
%% goes down.
handle_call({monitor, Conn}, _From, State = []) ->
    _MonitorRef = erlang:monitor(process, Conn),
    {reply, ok, State, hibernate}.
%% gen_server callback: no casts are part of the protocol, so any cast stops
%% the server with reason `unexpected'.
handle_cast(_Request, State = []) ->
    {stop, unexpected, State}.
%% gen_server callback.
%% 'DOWN' for a monitored connection: remove all of its cached statement
%% handles from connections_statements.
%% Any other message is logged and dropped. Crashing on it would destroy
%% both ETS tables (owned by this process), including every statement
%% registered through prepare/1.
handle_info({'DOWN', _MonitorRef, process, Conn, _Reason}, []) ->
    ets:match_delete(connections_statements, {{Conn, '_'}, '_', '_'}),
    {noreply, [], hibernate};
handle_info(Unexpected, []) ->
    error_logger:warning_msg("~p: ignoring unexpected message ~p~n",
                             [?MODULE, Unexpected]),
    {noreply, [], hibernate}.
%% gen_server callback: the state carries no data, so every upgrade simply
%% yields the empty state again.
code_change(_OldVsn, _State, _Extra) ->
    NewState = [],
    {ok, NewState}.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment