|
-module(ktls).

%% Public entry points: start a demo server (acceptor) or client process.
-export([start_server/0, start_client/0]).

%% proc_lib entry points. They are exported only because
%% proc_lib:start_link/3 spawns them by module/function/arguments.
-export([init_acceptor/1, init_server/2, init_client/1]).

%% Helpers that configure kernel TLS on an already connected socket.
-export([
    setup_tls/1,
    setup_tx_tls_1_3_aes_256_gcm/2,
    setup_rx_tls_1_3_aes_256_gcm/2
]).
|
|
|
%% Numeric constants mirrored from the Linux kernel headers. This module
%% passes them to socket:setopt_native/3 and places them in sendmsg control
%% messages, so they must match the headers of the running kernel.

%% Socket option / control message level used for kernel TLS.
%% See: kernel/include/linux/socket.h
-define(SOL_TLS, 282).

%% TCP-level option used with the value <<"tls">> in setup_tls/1.
%% See: kernel/include/uapi/linux/tcp.h
-define(TCP_ULP, 31).

%% SOL_TLS option names selecting the transmit / receive direction.
%% See: kernel/include/uapi/linux/tls.h
-define(TLS_TX, 1).
-define(TLS_RX, 2).

%% SOL_TLS control message types carrying a one-byte TLS record type.
%% See: kernel/include/uapi/linux/tls.h
-define(TLS_SET_RECORD_TYPE, 1).
-define(TLS_GET_RECORD_TYPE, 2).
|
|
|
%% Fixed cipher parameters for the two directions of the demo connection.
%% Each map carries the fields required by
%% make_tls_1_3_aes_256_gcm_value/1: an 8-byte iv, a 32-byte key,
%% a 4-byte salt and an 8-byte record sequence number.
%% NOTE(review): these are constant, well-known values -- presumably test
%% vectors only; they must never protect real traffic.

%% Server -> client direction.
-define(SERVER_TX, #{
    iv => <<1, 1, 1, 1, 1, 1, 1, 1>>,
    key => <<
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2
    >>,
    salt => <<3, 3, 3, 3>>,
    rec_seq => <<0, 0, 0, 0, 0, 0, 0, 0>>
}).

%% Client -> server direction.
-define(SERVER_RX, #{
    iv => <<4, 4, 4, 4, 4, 4, 4, 4>>,
    key => <<
        5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5
    >>,
    salt => <<6, 6, 6, 6>>,
    rec_seq => <<0, 0, 0, 0, 0, 0, 0, 0>>
}).

%% The client simply mirrors the server: what one side transmits with,
%% the other side receives with.
-define(CLIENT_TX, ?SERVER_RX).
-define(CLIENT_RX, ?SERVER_TX).
|
|
|
%% Loop state of the accepting process (see init_acceptor/1).
%% Fields without an explicit initialiser default to 'undefined'.
-record(acceptor, {
    %% Name printed in front of log lines.
    name = ktls_acceptor :: atom(),
    %% Process that called start_server/0.
    parent :: undefined | pid(),
    %% The listening socket.
    socket :: undefined | socket:socket(),
    %% Handle of the pending asynchronous accept; 'undefined' when no
    %% accept is currently outstanding.
    accept_select_handle :: undefined | socket:select_handle()
}).
|
|
|
%% Loop state of a connection process (client or server side).
%% Fields without an explicit initialiser default to 'undefined'.
-record(state, {
    %% Name printed in front of log lines (ktls_server or ktls_client).
    name :: undefined | atom(),
    %% Process that started this one.
    parent :: undefined | pid(),
    %% The connected socket.
    socket :: undefined | socket:socket(),
    %% Handles of the pending asynchronous operations; 'undefined' means
    %% the corresponding operation has to be (re)issued.
    recv_select_handle :: undefined | socket:select_handle(),
    recvmsg_select_handle :: undefined | socket:select_handle(),
    send_select_handle :: undefined | socket:select_handle(),
    %% Data queued by {send, Data} messages that has not been sent yet.
    send_buffer = <<>> :: iodata()
}).
|
|
|
%% @doc Start the acceptor process, linked to the caller.
%% Returns what init_acceptor/1 acknowledges through proc_lib:init_ack/2,
%% i.e. `{ok, Pid}', or `{error, Reason}' if the initialisation crashes.
-spec start_server() -> {ok, pid()} | {error, term()}.
start_server() ->
    Parent = self(),
    proc_lib:start_link(?MODULE, init_acceptor, [Parent]).
|
|
|
%% @doc proc_lib entry point of the acceptor process.
%% Opens a TCP listen socket on 127.0.0.1:31388, acknowledges the start to
%% `Parent' and then enters the accept loop; it never returns.
%% Any failing step crashes the process through an assertive match.
-spec init_acceptor(pid()) -> no_return().
init_acceptor(Parent) ->
    ListenAddress = #{family => inet, addr => {127, 0, 0, 1}, port => 31388},
    {ok, ListenSocket} = socket:open(inet, stream, tcp),
    %% For NIF-level tracing, additionally set {otp, debug} to true here.
    lists:foreach(
        fun(Option) ->
            ok = socket:setopt(ListenSocket, {socket, Option}, true)
        end,
        [reuseaddr, reuseport]
    ),
    ok = socket:bind(ListenSocket, ListenAddress),
    ok = socket:listen(ListenSocket),
    %% Exits of linked processes arrive as messages from here on.
    _ = erlang:process_flag(trap_exit, true),
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    loop_acceptor(#acceptor{parent = Parent, socket = ListenSocket}).
|
|
|
%% @private
%% Accept loop of the listening process.
%%
%% While no accept is outstanding (first clause) a non-blocking accept is
%% issued. A connection that is immediately available is handed to a freshly
%% started server process; otherwise the select handle is remembered and the
%% second clause waits for the matching '$socket' select notification.
%%
%% The process traps exits (see init_acceptor/1). An exit signal from the
%% parent therefore has to be honoured explicitly: as expected of a proc_lib
%% special process, the acceptor terminates with the parent's exit reason
%% instead of merely logging the message and running on without a parent.
%% Exit messages from any other linked process are still only logged.
loop_acceptor(S = #acceptor{socket = ListenSocket, accept_select_handle = undefined}) ->
    case socket:accept(ListenSocket, nowait) of
        {ok, AcceptedSocket} ->
            {ok, ServerPid} =
                proc_lib:start_link(?MODULE, init_server, [self(), AcceptedSocket]),
            %% Transfer socket ownership first, then tell the server it may
            %% start using the socket.
            ok = socket:setopt(AcceptedSocket, {otp, controlling_process}, ServerPid),
            ServerPid ! {controlling_process, AcceptedSocket},
            loop_acceptor(S);
        {select, {select_info, accept, AcceptSelectHandle}} ->
            loop_acceptor(S#acceptor{accept_select_handle = AcceptSelectHandle})
    end;
loop_acceptor(
    S = #acceptor{
        parent = Parent,
        socket = ListenSocket,
        accept_select_handle = AcceptSelectHandle
    }
) ->
    receive
        {'$socket', ListenSocket, select, SelectHandle} ->
            %% Only one accept is ever outstanding, so the notification must
            %% belong to it; anything else is a bug and crashes the process.
            AcceptSelectHandle = SelectHandle,
            loop_acceptor(S#acceptor{accept_select_handle = undefined});
        {'EXIT', Parent, Reason} ->
            exit(Reason);
        Msg ->
            io:format("[~s] stray message: ~0p~n", [S#acceptor.name, Msg]),
            loop_acceptor(S)
    end.
|
|
|
%% @doc proc_lib entry point of the server side of one connection.
%% Registers the process as `ktls_server', acknowledges the start, waits
%% until the acceptor announces that socket ownership has been transferred,
%% enables kernel TLS with the server keys and enters the connection loop;
%% it never returns.
-spec init_server(pid(), socket:socket()) -> no_return().
init_server(Parent, Socket) ->
    Name = ktls_server,
    true = erlang:register(Name, self()),
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    %% The acceptor sends this message right after making us the
    %% controlling process of Socket.
    receive
        {controlling_process, Socket} -> ok
    end,
    ok = setup_tls(Socket),
    ok = setup_tx_tls_1_3_aes_256_gcm(Socket, ?SERVER_TX),
    ok = setup_rx_tls_1_3_aes_256_gcm(Socket, ?SERVER_RX),
    loop(#state{name = Name, parent = Parent, socket = Socket}).
|
|
|
%% @doc Start the client process, linked to the caller.
%% Returns what init_client/1 acknowledges through proc_lib:init_ack/2,
%% i.e. `{ok, Pid}', or `{error, Reason}' if the initialisation crashes.
-spec start_client() -> {ok, pid()} | {error, term()}.
start_client() ->
    Parent = self(),
    proc_lib:start_link(?MODULE, init_client, [Parent]).
|
|
|
%% @doc proc_lib entry point of the client process.
%% Registers the process as `ktls_client', connects to 127.0.0.1:31388,
%% enables kernel TLS with the client keys, acknowledges the start to
%% `Parent' and enters the connection loop; it never returns.
-spec init_client(pid()) -> no_return().
init_client(Parent) ->
    Name = ktls_client,
    true = erlang:register(Name, self()),
    ServerAddress = #{family => inet, addr => {127, 0, 0, 1}, port => 31388},
    {ok, Socket} = socket:open(inet, stream, tcp),
    %% For NIF-level tracing, additionally set {otp, debug} to true here.
    ok = socket:connect(Socket, ServerAddress),
    ok = setup_tls(Socket),
    ok = setup_tx_tls_1_3_aes_256_gcm(Socket, ?CLIENT_TX),
    ok = setup_rx_tls_1_3_aes_256_gcm(Socket, ?CLIENT_RX),
    %% The start is acknowledged only after TLS is fully configured.
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    loop(#state{name = Name, parent = Parent, socket = Socket}).
|
|
|
%% @doc Set the native TCP option TCP_ULP of `Socket' to the value
%% `<<"tls">>'. Returns the result of socket:setopt_native/3 unchanged.
-spec setup_tls(socket:socket()) -> ok | {error, term()}.
setup_tls(Socket) ->
    socket:setopt_native(Socket, {tcp, ?TCP_ULP}, <<"tls">>).
|
|
|
%% @doc Install TLS 1.3 / AES-256-GCM parameters for the transmit direction
%% of `Socket' (native option TLS_TX at level SOL_TLS).
%% `Options' must contain the keys `iv', `key', `salt' and `rec_seq' with
%% binaries of 8, 32, 4 and 8 bytes; anything else crashes with
%% `function_clause'. Returns the result of socket:setopt_native/3.
-spec setup_tx_tls_1_3_aes_256_gcm(socket:socket(), map()) -> ok | {error, term()}.
setup_tx_tls_1_3_aes_256_gcm(Socket, Options) ->
    CryptoInfo = make_tls_1_3_aes_256_gcm_value(Options),
    socket:setopt_native(Socket, {?SOL_TLS, ?TLS_TX}, CryptoInfo).
|
|
|
%% @doc Install TLS 1.3 / AES-256-GCM parameters for the receive direction
%% of `Socket' (native option TLS_RX at level SOL_TLS).
%% `Options' must contain the keys `iv', `key', `salt' and `rec_seq' with
%% binaries of 8, 32, 4 and 8 bytes; anything else crashes with
%% `function_clause'. Returns the result of socket:setopt_native/3.
-spec setup_rx_tls_1_3_aes_256_gcm(socket:socket(), map()) -> ok | {error, term()}.
setup_rx_tls_1_3_aes_256_gcm(Socket, Options) ->
    CryptoInfo = make_tls_1_3_aes_256_gcm_value(Options),
    socket:setopt_native(Socket, {?SOL_TLS, ?TLS_RX}, CryptoInfo).
|
|
|
%% @private
%% Build the 56-byte option value handed to the kernel for TLS_TX / TLS_RX.
%%
%% Layout (see kernel/include/uapi/linux/tls.h,
%% struct tls12_crypto_info_aes_gcm_256):
%%   version     16-bit unsigned, host byte order (TLS_1_3_VERSION = 0x0304)
%%   cipher_type 16-bit unsigned, host byte order (TLS_CIPHER_AES_GCM_256 = 52)
%%   iv          8 bytes
%%   key         32 bytes
%%   salt        4 bytes
%%   rec_seq     8 bytes
%%
%% The two header fields are plain integers in the C struct, so they are
%% encoded with native endianness rather than as fixed little-endian bytes.
%% On little-endian hosts the result is byte-for-byte what it was before.
%%
%% Crashes with function_clause if a key is missing or a field has the
%% wrong size.
make_tls_1_3_aes_256_gcm_value(#{
    iv := IV = <<_:8/binary>>,
    key := Key = <<_:32/binary>>,
    salt := Salt = <<_:4/binary>>,
    rec_seq := RecSeq = <<_:8/binary>>
}) ->
    <<
        16#0304:16/native-unsigned-integer,
        52:16/native-unsigned-integer,
        IV/binary,
        Key/binary,
        Salt/binary,
        RecSeq/binary
    >>.
|
|
|
%% @private
%% Main loop of a connection process (client or server side).
%%
%% The clauses are tried in order and act as a small scheduler:
%%   1. no recv outstanding    -> issue a non-blocking socket:recv/3;
%%   2. no recvmsg outstanding -> issue a non-blocking socket:recvmsg/2;
%%   3. no send outstanding and data buffered -> try to send the buffer;
%%   4. otherwise wait for a select notification or a command message.
%% NOTE(review): recv and recvmsg are both kept armed on the same socket;
%% presumably deliberate for this experiment -- confirm.
%%
%% Command messages understood in step 4:
%%   {send, IoData}                      queue data for sending
%%   {set_record_type, <<Type>>, IoData} sendmsg with a TLS_SET_RECORD_TYPE cmsg
%%   {get_record_type, <<Type>>, IoData} sendmsg with a TLS_GET_RECORD_TYPE cmsg
%% Everything else is logged as a stray message.
%%
%% A closed socket raises error({socket_closed, Operation}); any other
%% unexpected result crashes through a failed match.
loop(S = #state{socket = Socket, recv_select_handle = undefined}) ->
    case socket:recv(Socket, 0, nowait) of
        {ok, RecvData} ->
            io:format("[~s] recv: ~0p~n", [S#state.name, RecvData]),
            loop(S);
        {select, {select_info, recv, NewRecvSelectHandle}} ->
            loop(S#state{recv_select_handle = NewRecvSelectHandle});
        {error, closed} ->
            error({socket_closed, recv})
    end;
loop(S = #state{socket = Socket, recvmsg_select_handle = undefined}) ->
    case socket:recvmsg(Socket, nowait) of
        {ok, RecvMsg} ->
            io:format("[~s] recvmsg: ~0p~n", [S#state.name, RecvMsg]),
            loop(S);
        {select, {select_info, recvmsg, NewRecvmsgSelectHandle}} ->
            loop(S#state{recvmsg_select_handle = NewRecvmsgSelectHandle});
        {error, closed} ->
            error({socket_closed, recvmsg})
    end;
loop(
    S = #state{
        socket = Socket,
        send_select_handle = undefined,
        send_buffer = OldSendBuffer
    }
) when OldSendBuffer =/= <<>> ->
    case socket:send(Socket, OldSendBuffer, nowait) of
        ok ->
            loop(S#state{send_buffer = <<>>});
        {select, {select_info, send, NewSendSelectHandle}} ->
            %% Nothing was sent; the whole buffer is retried once the
            %% select notification arrives.
            loop(S#state{send_select_handle = NewSendSelectHandle});
        {select, {{select_info, send, NewSendSelectHandle}, NewSendBuffer}} ->
            %% Partial send; only the unsent rest stays buffered.
            loop(S#state{
                send_select_handle = NewSendSelectHandle,
                send_buffer = NewSendBuffer
            });
        {error, closed} ->
            error({socket_closed, send})
    end;
loop(
    S = #state{
        socket = Socket,
        recv_select_handle = RecvSelectHandle,
        recvmsg_select_handle = RecvmsgSelectHandle,
        send_select_handle = SendSelectHandle,
        send_buffer = OldSendBuffer
    }
) ->
    receive
        {'$socket', Socket, select, SelectHandle} ->
            %% Clearing a handle makes the matching clause above re-issue
            %% the operation on the next iteration.
            case SelectHandle of
                RecvSelectHandle ->
                    loop(S#state{recv_select_handle = undefined});
                RecvmsgSelectHandle ->
                    loop(S#state{recvmsg_select_handle = undefined});
                SendSelectHandle ->
                    loop(S#state{send_select_handle = undefined})
            end;
        {send, MoreData} ->
            loop(S#state{send_buffer = append_send_buffer(OldSendBuffer, MoreData)});
        {get_record_type, <<RecordType>>, MsgData} ->
            ok = sendmsg_with_record_type(S, ?TLS_GET_RECORD_TYPE, RecordType, MsgData),
            loop(S);
        {set_record_type, <<RecordType>>, MsgData} ->
            ok = sendmsg_with_record_type(S, ?TLS_SET_RECORD_TYPE, RecordType, MsgData),
            loop(S);
        Msg ->
            io:format("[~s] stray message: ~0p~n", [S#state.name, Msg]),
            loop(S)
    end.

%% @private
%% Append MoreData to the pending send buffer.
%% An empty result is normalised to <<>>, the value the send clause of
%% loop/1 tests for. Without this the buffer would always be a list after a
%% {send, _} message and an empty send would be attempted.
append_send_buffer(OldSendBuffer, MoreData) ->
    NewSendBuffer = [OldSendBuffer | MoreData],
    case iolist_size(NewSendBuffer) of
        0 -> <<>>;
        _ -> erlang:iolist_to_iovec(NewSendBuffer)
    end.

%% @private
%% Log and send MsgData with one SOL_TLS control message of type CmsgType
%% whose payload is the single byte RecordType. Blocks until the message
%% has been sent (timeout infinity) and crashes unless the result is ok.
sendmsg_with_record_type(#state{name = Name, socket = Socket}, CmsgType, RecordType, MsgData) ->
    MsgSend = #{
        iov => erlang:iolist_to_iovec(MsgData),
        ctrl => [#{level => ?SOL_TLS, type => CmsgType, data => <<RecordType>>}]
    },
    io:format("[~s] sendmsg(~0p)~n", [Name, MsgSend]),
    ok = socket:sendmsg(Socket, MsgSend, infinity).