Skip to content

Instantly share code, notes, and snippets.

@michaelst
Last active May 24, 2023 17:36
Show Gist options
  • Save michaelst/b49c05f39019fd5a92b6377f97dca3dd to your computer and use it in GitHub Desktop.
Save michaelst/b49c05f39019fd5a92b6377f97dca3dd to your computer and use it in GitHub Desktop.
Postgres Proxy
# Binds a TCP listener on `port` (binary mode, passive, raw packets) and enters
# the accept loop. The assertive match crashes the caller if the port cannot
# be bound.
def accept(port) do
  listen_opts = [:binary, active: false, reuseaddr: true, packet: 0, nodelay: true]
  {:ok, listener} = :gen_tcp.listen(port, listen_opts)

  Logger.info("Accepting connections on port #{port}")

  loop_acceptor(listener)
end
# Accept loop: waits for a client connection on the listening `socket`, starts a
# supervised task that runs `read_client/2` for it, and then recurses to wait
# for the next client.
defp loop_acceptor(socket) do
# Assertive match: any accept error crashes the acceptor process.
{:ok, client_conn} = :gen_tcp.accept(socket)
# Per-connection state handed to the reader task.
state = %{
client_conn: client_conn,
# (remaining initial state fields are elided in this excerpt)
...
}
{:ok, pid} =
Task.Supervisor.start_child(QueryDesk.PostgresProxy.TaskSupervisor, fn ->
read_client(state, nil)
end)
# Transfer ownership of the client socket from the acceptor to the reader task.
:ok = :gen_tcp.controlling_process(client_conn, pid)
# Tail call: keep accepting further clients.
loop_acceptor(socket)
end
# Reads the next chunk of bytes from the client socket, parses it as a protocol
# message and dispatches the parse result to `handle_message/2`.
defp read_client(%{client_conn: client_conn} = state, nil) do
  parsed = parse_msg(read_line(client_conn))
  handle_message(state, parsed)
end
# Blocking read of whatever bytes are currently available (length 0) from a
# passive socket. Any non-`{:ok, _}` result (e.g. a closed connection) fails the
# assertive match and crashes the calling process.

# TLS socket variant.
defp read_line({:sslsocket, _port, _pids} = tls_socket) do
  {:ok, payload} = :ssl.recv(tls_socket, 0)
  payload
end

# Plain TCP socket variant.
defp read_line(tcp_socket) when is_port(tcp_socket) do
  {:ok, payload} = :gen_tcp.recv(tcp_socket, 0)
  payload
end
# Opens the upstream database connection for a client that does not have one
# yet (this clause only matches when `database_conn` is nil) and sends the
# startup message to the database.
#
# NOTE(review): the `if` below has no `else`, so this function returns `nil`
# (not `state`) when reviews are required or access is denied — confirm callers
# handle that.
defp maybe_connect_to_database(
%{user: user, team: team, connect_params: connect_params, database_conn: nil} = state
) do
# Look up the team's database by the name the client supplied in its connect
# params; `get!` presumably raises when no record matches — verify.
database =
QueryDesk.Api.get!(
Database,
[team_id: team.id, name: connect_params["database"]],
load: [:default_credential, :users]
)
# only allow connections if no reviews are required and they are allowed to access
if database.default_credential.reviews_required == 0 and
QueryDesk.Auth.Utils.can_access_database?(user, database) do
{:ok, database_conn, pid} = Utils.open_database_connection(database, state)
# (intermediate steps are elided in this excerpt; presumably `state` is
# rebound here with the new connection — TODO confirm)
...
# Send the StartupMessage built from the stored credentials to the database.
:ok =
Utils.send(
state,
Utils.startup_message(database, connect_params)
)
state
end
end
# Connects to `database`, records the connection and the database in `state`,
# starts a supervised task that runs `QueryDesk.PostgresProxy.read_database/1`
# with that state, and makes the task the controlling process of the socket.
#
# Returns `{:ok, database_conn, reader_pid}`.
defp open_local_database_connection(database, state) do
  {:ok, database_conn} = open_database_connection(database)

  state = Map.merge(state, %{database_conn: database_conn, database: database})

  {:ok, reader} =
    Task.Supervisor.start_child(QueryDesk.PostgresProxy.TaskSupervisor, fn ->
      QueryDesk.PostgresProxy.read_database(state)
    end)

  # The connection is either a TLS socket or a plain TCP port; pick the module
  # whose controlling_process/2 applies.
  transport =
    case database_conn do
      {:sslsocket, _port, _pids} -> :ssl
      port when is_port(port) -> :gen_tcp
    end

  :ok = transport.controlling_process(database_conn, reader)

  {:ok, database_conn, reader}
end
# Opens a passive, binary, raw TCP connection to `database.hostname` on port
# 5432 and, when `database.ssl` is truthy, upgrades it to TLS using the
# PostgreSQL SSLRequest handshake.
#
# Returns `{:ok, port}` for a plain connection, or the result of
# `:ssl.connect/2` (`{:ok, sslsocket}` | `{:error, reason}`) for TLS.
# Crashes (MatchError) if the TCP connect fails, if the SSLRequest cannot be
# sent, or if the server does not answer with `S`.
#
# NOTE(review): the port is hard-coded to 5432.
def open_database_connection(database) do
  {:ok, database_conn} =
    :gen_tcp.connect(to_charlist(database.hostname), 5432,
      mode: :binary,
      active: false,
      packet: :raw
    )

  if database.ssl do
    # send ssl request
    # Assert the send succeeded so a broken connection fails here with the real
    # reason instead of as a confusing recv mismatch on the next line.
    :ok =
      :gen_tcp.send(
        database_conn,
        <<8::integer-size(32), 1234::integer-size(16), 5679::integer-size(16)>>
      )

    # S means ssl is supported and that we can start the connection
    {:ok, <<?S>>} = :gen_tcp.recv(database_conn, 1)

    # NOTE(security): `verify: :verify_none` disables server certificate
    # verification, so the TLS session is not protected against
    # man-in-the-middle attacks — confirm this is intentional.
    # Options whose value is "" (presumably "not configured" from
    # create_ssl_file/2 — verify) are dropped.
    ssl_opts =
      Enum.reject(
        [
          verify: :verify_none,
          cacertfile: create_ssl_file(database, :cacertfile),
          keyfile: create_ssl_file(database, :keyfile),
          certfile: create_ssl_file(database, :certfile)
        ],
        fn {_k, v} -> v == "" end
      )

    :ssl.connect(database_conn, ssl_opts)
  else
    {:ok, database_conn}
  end
end
# After connecting, the database asks for a password. This clause handles the
# AuthenticationMD5Password request (`R`, length 12, auth code 5, 4-byte salt)
# and answers on behalf of the client with the stored default credential:
#
#   "md5" <> hex(md5(hex(md5(password <> username)) <> salt))
defp maybe_send_to_client(
       <<?R, 0, 0, 0, 12, 0, 0, 0, 5, salt::binary-size(4)>>,
       %{database: database} = state
     ) do
  hex_md5 = fn iodata -> Base.encode16(:erlang.md5(iodata), case: :lower) end

  user = database.default_credential.username
  pass = database.default_credential.password

  inner = hex_md5.([pass, user])
  digest = hex_md5.([inner, salt])

  # Message length covers the 4 length bytes, "md5" (3), the digest and the
  # trailing NUL (1): byte_size(digest) + 8.
  size = byte_size(digest) + 8
  password_message = <<?p, size::integer-size(32), "md5", digest::binary, 0>>

  :ok = Utils.send(state, password_message)
end
# Builds the StartupMessage (protocol version 3.0) sent to the real database.
#
# The client's `connect_params` are forwarded, except that "database", "user"
# and "application_name" are always overridden with the stored database name,
# the default credential's username and "QueryDesk Proxy".
#
# Layout: 32-bit total length, protocol bytes 0,3,0,0, then NUL-terminated
# key/value pairs, then a final NUL.
def startup_message(database, connect_params) do
  overrides = %{
    "database" => database.database,
    "user" => database.default_credential.username,
    "application_name" => "QueryDesk Proxy"
  }

  params =
    for {key, value} <- Map.merge(connect_params, overrides), into: <<>> do
      <<key::binary, 0, value::binary, 0>>
    end

  # 4 (length) + 4 (protocol version) + params + 1 (terminating NUL)
  total_size = byte_size(params) + 9

  <<total_size::integer-size(32), 0, 3, 0, 0, params::binary, 0>>
end
# SSL Request
# Matches the 8-byte SSLRequest packet: a 32-bit length followed by the request
# code 1234/5679. This message has no leading tag byte, so the tag slot in the
# result is nil.
def parse_msg(<<len::integer-32, 1234::integer-16, 5679::integer-16>> = bin) do
# The head only matches a binary of exactly 8 bytes, so for a well-formed
# request (len == 8) msg_body is the whole packet and final_rest is empty.
case bin do
<<msg_body::binary-size(len), final_rest::binary>> ->
{:ok, {{:msgSSLRequest, nil}, msg_body}, final_rest}
...
# Most Messages
# Parses a tagged protocol message: one tag byte `c`, then a 32-bit length that
# counts itself plus the payload.
#
# Returns:
#   * `{:ok, {{tag, c}, msg_body}, other_msg}` when the whole message is
#     present; `msg_body` includes the 4 length bytes and `other_msg` is
#     whatever followed the message in the buffer.
#   * `{:continuation, fun}` when more bytes are needed; call `fun` with the
#     next chunk read from the socket.
def parse_msg(<<c::size(8), rest::binary>>) do
  tag = tag_to_msg_type(c)

  case rest do
    <<len::integer-32, _::binary>> ->
      case rest do
        <<msg_body::binary-size(len), other_msg::binary>> ->
          {:ok, {{tag, c}, msg_body}, other_msg}

        _other ->
          {:continuation,
           fn data ->
             handle_continuation(len, {tag, c}, rest, data)
           end}
      end

    _short_header ->
      # Fewer than 4 length bytes have arrived (a socket read can split a
      # message anywhere). The length is not known yet, so wait for more data
      # and re-parse from the tag byte instead of crashing on the length match.
      {:continuation,
       fn data ->
         parse_msg(<<c, rest::binary, data::binary>>)
       end}
  end
end
# Simple Query message: strips the 4-byte length prefix and the trailing NUL
# terminator(s) from the body to recover the SQL text, hands it to
# `handle_query/3`, then moves on to the next message.
defp handle_message(
       state,
       {:ok, {{:msgQuery, _c}, <<_len::unsigned-integer-32, query_data::binary>>}, _rest} = msg
     ) do
  query_data
  |> String.trim_trailing(<<0>>)
  |> handle_query(state, msg)

  next_message(state, msg)
end
# Forwards a client query message to the database connection.
# `c` is the original tag byte and `data` the message body as parsed (length
# prefix included), so `<<c, data::binary>>` rebuilds the message that is sent on.
defp handle_query(query, state, {:ok, {{_msg_type, c}, data}, _rest}) do
# (intermediate steps are elided in this excerpt)
...
# NOTE(review): this logs the full query text at debug level, which may include
# sensitive literals — confirm this is acceptable.
Logger.debug("running query: #{query}")
# Assertive match: a failed send crashes the calling process.
:ok = Utils.send(state, <<c, data::binary>>)
end
# Utils.send/2
# Writes `payload` to the database connection held in the state, choosing the
# transport from the shape of the socket. Returns the transport's own result.

# TLS connection.
def send(%{database_conn: {:sslsocket, _port, _pids} = tls_conn}, payload),
  do: :ssl.send(tls_conn, payload)

# Plain TCP connection.
def send(%{database_conn: tcp_conn}, payload) when is_port(tcp_conn),
  do: :gen_tcp.send(tcp_conn, payload)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment