Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
-module(proxy_protocol).
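%% The code between the BEGIN/END STOLEN CODE markers is copied from Cowboy
%% (apparently from cowboy_protocol, whose parse_request/3 receives the
%% #state{} record defined below); keep it in sync with the Cowboy version in use.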
%%% BEGIN STOLEN CODE
-export([parse_request/3]).
-export([start_link/4]).
-export([init/4]).
%% Per-connection state.
%% NOTE(review): parse_request/3 passes this record unchanged to
%% cowboy_protocol:parse_request/3, and records are plain tuples at runtime,
%% so the field names, order and defaults presumably have to match the
%% `state' record of the Cowboy version in use -- confirm before editing.
-record(state, {
socket :: inet:socket(),
transport :: module(), % module whose recv/3 and close/1 are called on socket
middlewares :: [module()],
compress :: boolean(),
env :: cowboy_middleware:env(),
onrequest :: undefined | cowboy:onrequest_fun(),
onresponse = undefined :: undefined | cowboy:onresponse_fun(),
max_empty_lines :: non_neg_integer(),
req_keepalive = 1 :: non_neg_integer(),
max_keepalive :: non_neg_integer(),
max_request_line_length :: non_neg_integer(),
max_header_name_length :: non_neg_integer(),
max_header_value_length :: non_neg_integer(),
max_headers :: non_neg_integer(),
timeout :: timeout(), % relative timeout in ms (option `timeout', default 5000)
until :: non_neg_integer() | infinity % absolute deadline in ms, see until/1
}).
%% @doc Spawns a linked process that runs init/4 for the accepted Socket of
%% the Ranch listener Ref, and returns that process's pid.
-spec start_link(ranch:ref(), inet:socket(), module(), term()) -> {ok, pid()}.
start_link(Ref, Socket, Transport, Opts) ->
    {ok, spawn_link(?MODULE, init, [Ref, Socket, Transport, Opts])}.
%% Returns the value of the first tuple in Opts whose first element equals Key.
%% Default is returned when no such tuple exists, or when that first match is
%% not a {Key, Value} 2-tuple. (A faster alternative to proplists:get_value/3.)
get_value(Key, Opts, Default) ->
    Found = lists:keyfind(Key, 1, Opts),
    if
        is_tuple(Found), tuple_size(Found) =:= 2 ->
            element(2, Found);
        true ->
            Default
    end.
%% @doc Body of the connection process spawned by start_link/4.
%%
%% Builds the connection state from Opts, using the defaults shown below for
%% missing options. It then acknowledges the connection with
%% ranch:accept_ack/1 and starts waiting for the first bytes with an empty
%% buffer.
-spec init(ranch:ref(), inet:socket(), module(), term()) -> ok.
init(Ref, Socket, Transport, Opts) ->
    Opt = fun(Key, Default) -> get_value(Key, Opts, Default) end,
    Timeout = Opt(timeout, 5000),
    Base = #state{
        socket = Socket,
        transport = Transport,
        middlewares = Opt(middlewares, [cowboy_router, cowboy_handler]),
        compress = Opt(compress, false),
        env = [{listener, Ref} | Opt(env, [])],
        onrequest = Opt(onrequest, undefined),
        onresponse = Opt(onresponse, undefined),
        max_empty_lines = Opt(max_empty_lines, 5),
        max_keepalive = Opt(max_keepalive, 100),
        max_request_line_length = Opt(max_request_line_length, 4096),
        max_header_name_length = Opt(max_header_name_length, 64),
        max_header_value_length = Opt(max_header_value_length, 4096),
        max_headers = Opt(max_headers, 100),
        timeout = Timeout
    },
    ok = ranch:accept_ack(Ref),
    %% The absolute deadline is taken only after the Ranch acknowledgement,
    %% so the request timeout is measured from that point.
    wait_request(<<>>, Base#state{until = until(Timeout)}, 0).
%% Reads the next chunk from the socket, bounded by the connection deadline.
%% The chunk is appended to Buffer and parse_request/3 is re-entered with the
%% combined bytes. Any receive error, including an expired deadline, closes
%% the connection through terminate/1.
-spec wait_request(binary(), #state{}, non_neg_integer()) -> ok.
wait_request(Buffer, State = #state{transport = Transport, socket = Socket,
                                    until = Deadline}, ReqEmpty) ->
    case recv(Socket, Transport, Deadline) of
        {error, _Reason} ->
            terminate(State);
        {ok, Chunk} ->
            Combined = <<Buffer/binary, Chunk/binary>>,
            parse_request(Combined, State, ReqEmpty)
    end.
%% Receives whatever data is available on Socket through Transport:recv/3.
%% Until is an absolute deadline in milliseconds, as produced by until/1, or
%% infinity. When the deadline has already passed, {error, timeout} is
%% returned without touching the socket.
-spec recv(inet:socket(), module(), non_neg_integer() | infinity)
    -> {ok, binary()} | {error, closed | timeout | atom()}.
recv(Socket, Transport, infinity) ->
    Transport:recv(Socket, 0, infinity);
recv(Socket, Transport, Until) ->
    {MegaSecs, Secs, MicroSecs} = os:timestamp(),
    NowMs = (MegaSecs * 1000000 + Secs) * 1000 + MicroSecs div 1000,
    case Until - NowMs of
        Remaining when Remaining < 0 ->
            {error, timeout};
        Remaining ->
            Transport:recv(Socket, 0, Remaining)
    end.
%% Closes the connection's socket via its transport module.
%% The result of close/1 is deliberately ignored; the function always returns ok.
-spec terminate(#state{}) -> ok.
terminate(#state{transport = Transport, socket = Socket}) ->
    _ = Transport:close(Socket),
    ok.
%% Turns a relative timeout in milliseconds into an absolute deadline,
%% expressed in milliseconds on the os:timestamp/0 clock. infinity is
%% returned unchanged.
-spec until(timeout()) -> non_neg_integer() | infinity.
until(Timeout) when Timeout =:= infinity ->
    infinity;
until(Timeout) ->
    {MegaSecs, Secs, MicroSecs} = os:timestamp(),
    NowMs = (MegaSecs * 1000000 + Secs) * 1000 + MicroSecs div 1000,
    NowMs + Timeout.
%%% END STOLEN CODE
%% @doc Entry point for the bytes read from a freshly accepted connection.
%%
%% If Buffer starts with a PROXY protocol v1 header line ("PROXY ...\r\n"),
%% the line is parsed with parse_proxy_protocol/1.
%% - Valid peer information is stored in the process dictionary under the key
%%   `proxy_info', so code running later in this connection process can read
%%   it with get(proxy_info).
%% - The bytes after the header are then handed to
%%   cowboy_protocol:parse_request/3.
%% - An invalid header, or a header without CRLF within the 107-byte v1
%%   maximum, closes the connection through terminate/1.
%% - A header that has not fully arrived yet, or a buffer that is still a
%%   strict prefix of "PROXY ", triggers a read of more data via
%%   wait_request/3.
%%
%% Any other buffer is passed to cowboy_protocol:parse_request/3 untouched.
parse_request(Buffer = <<"PROXY ", Data/binary>>, State, ReqEmpty) ->
    case binary:split(Data, <<"\r\n">>) of
        [Header, Rest] ->
            proxy_header_parsed(parse_proxy_protocol(Header), Rest,
                                State, ReqEmpty);
        [_Incomplete] when byte_size(Buffer) < 107 ->
            %% The header line was split across TCP segments: read more.
            wait_request(Buffer, State, ReqEmpty);
        [_TooLong] ->
            %% A v1 header is at most 107 bytes including the CRLF.
            terminate(State)
    end;
parse_request(Buffer, State, ReqEmpty)
        when byte_size(Buffer) > 0, byte_size(Buffer) < 6,
             binary_part(<<"PROXY ">>, 0, byte_size(Buffer)) =:= Buffer ->
    %% These bytes may still turn into "PROXY "; wait before deciding.
    wait_request(Buffer, State, ReqEmpty);
parse_request(Buffer, State, ReqEmpty) ->
    cowboy_protocol:parse_request(Buffer, State, ReqEmpty).

%% Acts on the result of parse_proxy_protocol/1 for a complete header line.
%% Invalid headers close the connection. Otherwise the remaining bytes are
%% handed to Cowboy.
proxy_header_parsed(not_proxy_protocol, _Rest, State, _ReqEmpty) ->
    terminate(State);
proxy_header_parsed(malformed_proxy_protocol, _Rest, State, _ReqEmpty) ->
    terminate(State);
proxy_header_parsed(unknown_peer, Rest, State, ReqEmpty) ->
    %% "PROXY UNKNOWN": nothing is stored in `proxy_info'.
    proxy_continue_http(Rest, State, ReqEmpty);
proxy_header_parsed(ProxyInfo, Rest, State, ReqEmpty) ->
    %% Process-dictionary hand-off of the real peer to later code in this
    %% connection process.
    put(proxy_info, ProxyInfo),
    proxy_continue_http(Rest, State, ReqEmpty).

%% Passes the bytes that followed the PROXY header to Cowboy. When the header
%% was the only thing received so far, the socket is read first; a receive
%% error closes the connection.
proxy_continue_http(<<>>, State = #state{socket = Socket, transport = Transport,
                                         until = Until}, ReqEmpty) ->
    case recv(Socket, Transport, Until) of
        {ok, NextData} ->
            cowboy_protocol:parse_request(NextData, State, ReqEmpty);
        {error, _Reason} ->
            terminate(State)
    end;
proxy_continue_http(Rest, State, ReqEmpty) ->
    cowboy_protocol:parse_request(Rest, State, ReqEmpty).

%% Parses a PROXY v1 header line. The input is everything after the "PROXY "
%% prefix, without the trailing CRLF.
%%
%% Returns one of:
%%   {ipv4 | ipv6, SourceAddress, DestAddress, SourcePort, DestPort}
%%       for a "TCP4"/"TCP6" line with four parsable fields, whose address
%%       tuples belong to the announced family;
%%   malformed_proxy_protocol
%%       for a "TCP?" line with a wrong field count, an unknown family digit,
%%       unparsable fields, or an address/family mismatch;
%%   unknown_peer
%%       for an "UNKNOWN" line;
%%   not_proxy_protocol
%%       for anything else.
parse_proxy_protocol(<<"TCP", Proto:1/binary, _:1/binary, Info/binary>>) ->
    case string:tokens(binary_to_list(Info), " \r\n") of
        [SourceAddress, DestAddress, SourcePort, DestPort] ->
            build_proxy_info(parse_inet(Proto),
                             parse_ips([SourceAddress, DestAddress], []),
                             parse_ports([SourcePort, DestPort], []));
        _WrongFieldCount ->
            malformed_proxy_protocol
    end;
parse_proxy_protocol(<<"UNKNOWN", _/binary>>) ->
    unknown_peer;
parse_proxy_protocol(_) ->
    not_proxy_protocol.

%% Assembles the proxy info tuple. Any field that failed to parse (an
%% {error, _} result) and any address that does not belong to the announced
%% family turns into malformed_proxy_protocol.
build_proxy_info(ipv4, [Src, Dst], [SrcPort, DstPort])
        when tuple_size(Src) =:= 4, tuple_size(Dst) =:= 4 ->
    {ipv4, Src, Dst, SrcPort, DstPort};
build_proxy_info(ipv6, [Src, Dst], [SrcPort, DstPort])
        when tuple_size(Src) =:= 8, tuple_size(Dst) =:= 8 ->
    {ipv6, Src, Dst, SrcPort, DstPort};
build_proxy_info(_Family, _Addresses, _Ports) ->
    malformed_proxy_protocol.
%% Maps the family digit of a "TCP4"/"TCP6" token to ipv4/ipv6.
%% Any other value yields {error, invalid_inet_version}.
parse_inet(Version) ->
    case Version of
        <<"4">> -> ipv4;
        <<"6">> -> ipv6;
        _Other -> {error, invalid_inet_version}
    end.
%% Converts decimal port strings to integers.
%%
%% Returns Retval followed by the parsed ports, in input order. Returns
%% {error, invalid_port} as soon as one entry is not an integer in the TCP
%% port range 0..65535.
parse_ports(Ports, Retval) ->
    parse_ports_rev(Ports, lists:reverse(Retval)).

%% Accumulates parsed ports in reverse (O(1) prepend instead of ++) and
%% restores input order at the end.
parse_ports_rev([], RevAcc) ->
    lists:reverse(RevAcc);
parse_ports_rev([Port | Ports], RevAcc) ->
    try list_to_integer(Port) of
        IntPort when IntPort >= 0, IntPort =< 65535 ->
            parse_ports_rev(Ports, [IntPort | RevAcc]);
        _OutOfRange ->
            {error, invalid_port}
    catch
        error:badarg ->
            {error, invalid_port}
    end.
%% Parses textual IPv4/IPv6 addresses with inet:parse_address/1.
%%
%% Returns Retval followed by the parsed address tuples, in input order.
%% Returns {error, invalid_address} as soon as one address fails to parse.
parse_ips(Ips, Retval) ->
    parse_ips_rev(Ips, lists:reverse(Retval)).

%% Accumulates parsed addresses in reverse (O(1) prepend instead of ++) and
%% restores input order at the end.
parse_ips_rev([], RevAcc) ->
    lists:reverse(RevAcc);
parse_ips_rev([Ip | Ips], RevAcc) ->
    case inet:parse_address(Ip) of
        {ok, ParsedIp} ->
            parse_ips_rev(Ips, [ParsedIp | RevAcc]);
        _Error ->
            {error, invalid_address}
    end.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment