Skip to content

Instantly share code, notes, and snippets.

@potatosalad
Created November 5, 2021 19:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save potatosalad/9268dfba761d4ef2f69b75013da1c56f to your computer and use it in GitHub Desktop.
Save potatosalad/9268dfba761d4ef2f69b75013da1c56f to your computer and use it in GitHub Desktop.
Example of Linux Kernel TLS (KTLS) with the Erlang/OTP 24 `socket` NIF

Compile the Erlang module and the C client with:

erlc ktls.erl
cc -o ktls_client ktls_client.c

Start an Erlang shell for the server or the client (from the directory containing ktls.beam) with:

erl -pa .

Start Erlang server:

1> ktls:start_server().
{ok,<0.85.0>}

Start Erlang client:

1> ktls:start_client().
{ok,<0.86.0>}

Send a message to either of the registered processes, ktls_server or ktls_client:

2> ktls_server ! {send, <<"hello">>}.
{send, <<"hello">>}

On the client node, you should see:

[ktls_client] recvmsg: #{ctrl => [#{data => <<23>>,level => 282,type => 2}],flags => [eor],iov => [<<"hello">>]}

You can also attach a TLS record-type control message to the data you send:

2> ktls_client ! {set_record_type, <<"a">>, <<"foo">>}.
{set_record_type, <<"a">>, <<"foo">>}
3> ktls_client ! {get_record_type, <<"b">>, <<"bar">>}.
{get_record_type, <<"b">>, <<"bar">>}

On the server node, you should see:

[ktls_server] recvmsg: #{ctrl => [#{data => <<"a">>,level => 282,type => 2}],flags => [eor],iov => [<<"foo">>]}
[ktls_server] recvmsg: #{ctrl => [#{data => <<"b">>,level => 282,type => 1}],flags => [eor],iov => [<<"bar">>]}

There is also a C client that can be used to check that the server is working; run it with:

./ktls_client
-module(ktls).
-export([
start_server/0,
init_acceptor/1,
init_server/2,
start_client/0,
init_client/1,
setup_tls/1,
setup_tx_tls_1_3_aes_256_gcm/2,
setup_rx_tls_1_3_aes_256_gcm/2
]).
% Socket level used for kTLS socket options and control messages.
% See: kernel/include/linux/socket.h
-define(SOL_TLS, 282).
% TCP-level option that attaches an upper-layer protocol ("tls") to a socket
% (used by setup_tls/1).
% See: kernel/include/uapi/linux/tcp.h
-define(TCP_ULP, 31).
% SOL_TLS option names (install transmit / receive crypto state) and the
% control-message types used by loop/1.
% See: kernel/include/uapi/linux/tls.h
-define(TLS_TX, 1).
-define(TLS_RX, 2).
-define(TLS_SET_RECORD_TYPE, 1).
-define(TLS_GET_RECORD_TYPE, 2).
% Fixed demo key material for the server's transmit direction. NOT secret:
% hard-coded for demonstration only. Field sizes must match the patterns in
% make_tls_1_3_aes_256_gcm_value/1 (iv 8, key 32, salt 4, rec_seq 8 bytes).
% The C client uses the same values for its client_rx.
-define(SERVER_TX, #{
iv => <<1,1,1,1,1,1,1,1>>,
key => <<2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2>>,
salt => <<3,3,3,3>>,
rec_seq => <<0,0,0,0,0,0,0,0>>
}).
% Fixed demo key material for the server's receive direction (same caveats as
% above). The C client uses the same values for its client_tx.
-define(SERVER_RX, #{
iv => <<4,4,4,4,4,4,4,4>>,
key => <<5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5>>,
salt => <<6,6,6,6>>,
rec_seq => <<0,0,0,0,0,0,0,0>>
}).
% Each side's transmit keys are the other side's receive keys, so records
% encrypted by one peer can be decrypted by the other.
-define(CLIENT_TX, ?SERVER_RX).
-define(CLIENT_RX, ?SERVER_TX).
% State of the listening/acceptor process (see loop_acceptor/1).
-record(acceptor, {
% Name printed as the prefix of log lines.
name = ktls_acceptor :: atom(),
% Process that started the acceptor via proc_lib:start_link/3.
parent = undefined :: undefined | pid(),
% Listening socket.
socket = undefined :: undefined | socket:socket(),
% Handle of the pending nowait accept; undefined when none is pending.
accept_select_handle = undefined :: undefined | socket:select_handle()
}).
% State of a connected ktls_server / ktls_client process (see loop/1).
-record(state, {
% Registered name, also used as the log prefix.
name = undefined :: undefined | atom(),
% Process that started this one via proc_lib:start_link/3.
parent = undefined :: undefined | pid(),
% Connected socket with kTLS enabled.
socket = undefined :: undefined | socket:socket(),
% Pending select handles for nowait recv, recvmsg and send; undefined means
% loop/1 should (re)issue that operation.
recv_select_handle = undefined :: undefined | socket:select_handle(),
recvmsg_select_handle = undefined :: undefined | socket:select_handle(),
send_select_handle = undefined :: undefined | socket:select_handle(),
% Data queued by {send, Data} that socket:send/3 has not yet accepted.
send_buffer = <<>> :: iodata()
}).
%% @doc Start the listening acceptor process, linked to the caller.
%% Returns `{ok, Pid}' once the socket is bound to 127.0.0.1:31388 and
%% listening; each accepted connection is served by init_server/2.
start_server() ->
    Caller = self(),
    proc_lib:start_link(?MODULE, init_acceptor, [Caller]).
%% @private
%% proc_lib entry point for the acceptor. Opens a TCP listening socket on
%% 127.0.0.1:31388 with SO_REUSEADDR and SO_REUSEPORT enabled, traps exits,
%% acknowledges `{ok, self()}' to Parent and enters loop_acceptor/1.
%% Tip: `socket:setopt(ListenSocket, {otp, debug}, true)' turns on socket
%% debug tracing.
init_acceptor(Parent) ->
    BindAddr = #{family => inet, addr => {127, 0, 0, 1}, port => 31388},
    {ok, ListenSocket} = socket:open(inet, stream, tcp),
    lists:foreach(
        fun(Opt) -> ok = socket:setopt(ListenSocket, {socket, Opt}, true) end,
        [reuseaddr, reuseport]
    ),
    ok = socket:bind(ListenSocket, BindAddr),
    ok = socket:listen(ListenSocket),
    _ = erlang:process_flag(trap_exit, true),
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    loop_acceptor(#acceptor{parent = Parent, socket = ListenSocket}).
%% @private
%% Accept loop for the listening socket.
%%
%% With no accept pending, issues a nowait socket:accept/2. Each accepted
%% connection is handed to a new linked init_server/2 process: the acceptor
%% makes that process the socket's controlling process, then notifies it
%% with `{controlling_process, Socket}'. If the server process fails to
%% start, the error is logged and the accepted socket is closed. This
%% happens, for example, on a second connection while `ktls_server' is
%% still registered. The listener keeps running instead of crashing.
%%
%% When no connection is ready, the select handle is stored and the loop
%% waits for the matching `$socket' select notification before accepting
%% again. Any other message is logged as stray.
loop_acceptor(S = #acceptor{socket = ListenSocket, accept_select_handle = undefined}) ->
    case socket:accept(ListenSocket, nowait) of
        {ok, AcceptorSocket} ->
            case proc_lib:start_link(?MODULE, init_server, [self(), AcceptorSocket]) of
                {ok, ServerPid} ->
                    ok = socket:setopt(AcceptorSocket, {otp, controlling_process}, ServerPid),
                    ServerPid ! {controlling_process, AcceptorSocket};
                {error, Reason} ->
                    io:format("[~s] failed to start server: ~0p~n", [S#acceptor.name, Reason]),
                    _ = socket:close(AcceptorSocket)
            end,
            loop_acceptor(S);
        {select, {select_info, accept, AcceptSelectHandle}} ->
            loop_acceptor(S#acceptor{accept_select_handle = AcceptSelectHandle})
    end;
loop_acceptor(S = #acceptor{
    socket = Socket,
    accept_select_handle = AcceptSelectHandle
}) ->
    receive
        {'$socket', Socket, select, AcceptSelectHandle} ->
            % The listening socket is ready; clear the handle to accept again.
            loop_acceptor(S#acceptor{accept_select_handle = undefined});
        Msg ->
            io:format("[~s] stray message: ~0p~n", [S#acceptor.name, Msg]),
            loop_acceptor(S)
    end.
%% @private
%% proc_lib entry point for a per-connection server process.
%% Registers as `ktls_server' and acknowledges Parent. It then waits for the
%% acceptor's `{controlling_process, Socket}' hand-off, because the acceptor
%% only transfers ownership after start_link returns. Finally it enables kTLS
%% with the ?SERVER_TX / ?SERVER_RX keys and enters loop/1.
init_server(Parent, Socket) ->
    Name = ktls_server,
    true = erlang:register(Name, self()),
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    receive
        {controlling_process, Socket} -> ok
    end,
    ok = setup_tls(Socket),
    ok = setup_tx_tls_1_3_aes_256_gcm(Socket, ?SERVER_TX),
    ok = setup_rx_tls_1_3_aes_256_gcm(Socket, ?SERVER_RX),
    loop(#state{name = Name, parent = Parent, socket = Socket}).
%% @doc Start the client process (registered as `ktls_client'), linked to
%% the caller. Returns `{ok, Pid}' once it has connected to 127.0.0.1:31388
%% and installed its kTLS keys.
start_client() ->
    Caller = self(),
    proc_lib:start_link(?MODULE, init_client, [Caller]).
%% @private
%% proc_lib entry point for the client. Registers as `ktls_client' and
%% connects to 127.0.0.1:31388. It then enables kTLS with the ?CLIENT_TX /
%% ?CLIENT_RX keys, acknowledges Parent and enters loop/1.
%% Tip: `socket:setopt(Socket, {otp, debug}, true)' turns on socket debug
%% tracing.
init_client(Parent) ->
    Name = ktls_client,
    ServerAddr = #{family => inet, addr => {127, 0, 0, 1}, port => 31388},
    true = erlang:register(Name, self()),
    {ok, Socket} = socket:open(inet, stream, tcp),
    ok = socket:connect(Socket, ServerAddr),
    ok = setup_tls(Socket),
    ok = setup_tx_tls_1_3_aes_256_gcm(Socket, ?CLIENT_TX),
    ok = setup_rx_tls_1_3_aes_256_gcm(Socket, ?CLIENT_RX),
    ok = proc_lib:init_ack(Parent, {ok, self()}),
    loop(#state{name = Name, parent = Parent, socket = Socket}).
%% @doc Attach the "tls" upper-layer protocol to a connected TCP socket
%% (setsockopt(2) TCP_ULP). This must happen before any TLS_TX/TLS_RX keys
%% are installed. Returns whatever socket:setopt_native/3 returns.
setup_tls(Socket) ->
    socket:setopt_native(Socket, {tcp, ?TCP_ULP}, <<"tls">>).
%% @doc Install TLS 1.3 AES-256-GCM transmit crypto state (SOL_TLS/TLS_TX)
%% on a socket that already has the "tls" ULP attached. Options is the key
%% map accepted by make_tls_1_3_aes_256_gcm_value/1.
setup_tx_tls_1_3_aes_256_gcm(Socket, Options) ->
    socket:setopt_native(Socket, {?SOL_TLS, ?TLS_TX}, make_tls_1_3_aes_256_gcm_value(Options)).
%% @doc Install TLS 1.3 AES-256-GCM receive crypto state (SOL_TLS/TLS_RX)
%% on a socket that already has the "tls" ULP attached. Options is the key
%% map accepted by make_tls_1_3_aes_256_gcm_value/1.
setup_rx_tls_1_3_aes_256_gcm(Socket, Options) ->
    socket:setopt_native(Socket, {?SOL_TLS, ?TLS_RX}, make_tls_1_3_aes_256_gcm_value(Options)).
%% @private
%% Encode a `struct tls12_crypto_info_aes_gcm_256' for use as the value of a
%% SOL_TLS / TLS_TX or TLS_RX setsockopt(2).
%% See: kernel/include/uapi/linux/tls.h
%%
%% Layout:
%% - `struct tls_crypto_info { __u16 version; __u16 cipher_type; }'
%% - iv[8], key[32], salt[4], rec_seq[8] byte arrays.
%% Total: 56 bytes.
%%
%% The two __u16 fields are in host byte order, so they are encoded with
%% `/native' instead of hard-coded little-endian bytes. On a little-endian
%% host the result is byte-identical to <<16#04, 16#03, 16#34, 16#00, ...>>.
%%
%% Raises function_clause if a field is missing or has the wrong size.
make_tls_1_3_aes_256_gcm_value(#{
    iv := IV = <<_:8/binary>>,
    key := Key = <<_:32/binary>>,
    salt := Salt = <<_:4/binary>>,
    rec_seq := RecSeq = <<_:8/binary>>
}) ->
    <<
        16#0304:16/native, % TLS_1_3_VERSION = TLS_VERSION_NUMBER(3, 4)
        52:16/native,      % TLS_CIPHER_AES_GCM_256
        IV/binary,
        Key/binary,
        Salt/binary,
        RecSeq/binary
    >>.
%% @private
%% Event loop of a connected ktls_server / ktls_client process.
%%
%% The first three clauses (re)issue a nowait socket operation whenever its
%% select handle is `undefined':
%%   1. socket:recv/3 with length 0: log each chunk received;
%%   2. socket:recvmsg/2: log each message together with its control data;
%%   3. socket:send/3: flush `send_buffer' when it is non-empty.
%% An operation that cannot complete immediately yields a select handle.
%% The handle is stored, and the operation is not re-issued until the
%% matching `$socket' select notification arrives (last clause).
%%
%% Commands accepted by the last clause:
%%   {send, IoData}                      append IoData to the send buffer;
%%   {set_record_type, <<Type>>, IoData} sendmsg with a SOL_TLS
%%                                       TLS_SET_RECORD_TYPE cmsg;
%%   {get_record_type, <<Type>>, IoData} sendmsg with a SOL_TLS
%%                                       TLS_GET_RECORD_TYPE cmsg.
%% `{error, closed}' from an operation raises `{socket_closed, Op}'. Any
%% other socket error crashes with case_clause or badmatch.
loop(S = #state{socket = Socket, recv_select_handle = undefined}) ->
    case socket:recv(Socket, 0, nowait) of
        {ok, RecvData} ->
            io:format("[~s] recv: ~0p~n", [S#state.name, RecvData]),
            loop(S);
        {select, {select_info, recv, NewRecvSelectHandle}} ->
            loop(S#state{recv_select_handle = NewRecvSelectHandle});
        {error, closed} ->
            error({socket_closed, recv})
    end;
loop(S = #state{socket = Socket, recvmsg_select_handle = undefined}) ->
    case socket:recvmsg(Socket, nowait) of
        {ok, RecvMsg} ->
            io:format("[~s] recvmsg: ~0p~n", [S#state.name, RecvMsg]),
            loop(S);
        {select, {select_info, recvmsg, NewRecvmsgSelectHandle}} ->
            loop(S#state{recvmsg_select_handle = NewRecvmsgSelectHandle});
        {error, closed} ->
            error({socket_closed, recvmsg})
    end;
loop(S = #state{socket = Socket, send_select_handle = undefined, send_buffer = OldSendBuffer})
        when OldSendBuffer =/= <<>> ->
    case socket:send(Socket, OldSendBuffer, nowait) of
        ok ->
            loop(S#state{send_buffer = <<>>});
        {select, {select_info, send, NewSendSelectHandle}} ->
            loop(S#state{send_select_handle = NewSendSelectHandle});
        {select, {{select_info, send, NewSendSelectHandle}, NewSendBuffer}} ->
            % Partial write: keep only the unsent remainder for the next attempt.
            loop(S#state{send_select_handle = NewSendSelectHandle, send_buffer = NewSendBuffer});
        {error, closed} ->
            error({socket_closed, send})
    end;
loop(S = #state{
    socket = Socket,
    recv_select_handle = RecvSelectHandle,
    recvmsg_select_handle = RecvmsgSelectHandle,
    send_select_handle = SendSelectHandle,
    send_buffer = OldSendBuffer
}) ->
    receive
        {'$socket', Socket, select, SelectHandle} ->
            % Clearing the handle makes one of the clauses above re-issue the operation.
            case SelectHandle of
                RecvSelectHandle ->
                    loop(S#state{recv_select_handle = undefined});
                RecvmsgSelectHandle ->
                    loop(S#state{recvmsg_select_handle = undefined});
                SendSelectHandle ->
                    loop(S#state{send_select_handle = undefined})
            end;
        {send, MoreData} ->
            % Already-buffered (possibly partially sent) data stays in front.
            loop(S#state{send_buffer = erlang:iolist_to_iovec([OldSendBuffer | MoreData])});
        {get_record_type, <<RecordType>>, MsgData} ->
            % NOTE(review): sendmsg cmsgs normally carry TLS_SET_RECORD_TYPE; confirm
            % the kernel accepts TLS_GET_RECORD_TYPE here rather than failing (einval).
            ok = sendmsg_record_type(S, ?TLS_GET_RECORD_TYPE, RecordType, MsgData),
            loop(S);
        {set_record_type, <<RecordType>>, MsgData} ->
            ok = sendmsg_record_type(S, ?TLS_SET_RECORD_TYPE, RecordType, MsgData),
            loop(S);
        Msg ->
            io:format("[~s] stray message: ~0p~n", [S#state.name, Msg]),
            loop(S)
    end.

%% @private
%% Send MsgData in one blocking sendmsg (timeout `infinity') carrying a single
%% SOL_TLS control message of type CmsgType whose payload is the byte
%% RecordType. Logs the message first and returns socket:sendmsg/3's result.
sendmsg_record_type(#state{name = Name, socket = Socket}, CmsgType, RecordType, MsgData) ->
    MsgSend = #{
        iov => erlang:iolist_to_iovec(MsgData),
        ctrl => [#{level => ?SOL_TLS, type => CmsgType, data => <<RecordType>>}]
    },
    io:format("[~s] sendmsg(~0p)~n", [Name, MsgSend]),
    socket:sendmsg(Socket, MsgSend, infinity).
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>
#include <errno.h>
#include <linux/socket.h>
#include <linux/tls.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
/*
 * Send `length` bytes from `data` as one kTLS record with TLS record type
 * `record_type`. The type is set by attaching a single SOL_TLS /
 * TLS_SET_RECORD_TYPE control message to sendmsg(2).
 *
 * Returns the result of sendmsg(2), narrowed to int: the number of bytes
 * sent, or -1 with errno set.
 */
static int klts_send_ctrl_message(int sock, unsigned char record_type,
				  void *data, size_t length)
{
	/* A plain char array is not guaranteed to be aligned for struct
	 * cmsghdr, so use the union idiom from cmsg(3). The size is a
	 * constant expression, so this is not a VLA. */
	union {
		char buf[CMSG_SPACE(sizeof(unsigned char))];
		struct cmsghdr align;
	} control;
	struct msghdr msg;
	struct iovec iov; /* Single vector holding the record payload. */
	struct cmsghdr *cmsg;

	memset(&control, 0, sizeof(control));
	memset(&msg, 0, sizeof(msg));

	iov.iov_base = data;
	iov.iov_len = length;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	msg.msg_control = control.buf;
	msg.msg_controllen = sizeof(control.buf);

	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_TLS;
	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
	cmsg->cmsg_len = CMSG_LEN(sizeof(record_type));
	*CMSG_DATA(cmsg) = record_type;
	/* Report only the one control message actually filled in. */
	msg.msg_controllen = cmsg->cmsg_len;

	return (int)sendmsg(sock, &msg, 0);
}
/*
 * Send the NUL-terminated string `text` as ordinary application data.
 * Returns 0 on success; on failure prints a diagnostic and returns -1.
 */
static int ktls_send_text(int sockfd, const char *text)
{
	size_t len = strlen(text);
	ssize_t sent = send(sockfd, text, len, 0);

	if (sent < 0) {
		fprintf(stderr, "ERROR: send(2) of \"%s\" failed with (%d) %s\n", text, errno, strerror(errno));
		return -1;
	}
	if ((size_t)sent != len) {
		fprintf(stderr, "ERROR: send(2) of \"%s\" was short (%zd of %zu bytes)\n", text, sent, len);
		return -1;
	}
	return 0;
}

/*
 * Minimal kTLS test client for the Erlang ktls server.
 *
 * Connects to 127.0.0.1:31388 and attaches the "tls" ULP. It then installs
 * fixed TLS 1.3 AES-256-GCM keys: client_tx mirrors ?CLIENT_TX (= ?SERVER_RX)
 * and client_rx mirrors ?CLIENT_RX (= ?SERVER_TX) in ktls.erl. Finally it
 * sends a few records with pauses in between.
 *
 * Exits 0 on success and 1 on the first failure. The socket is closed on
 * every path once it has been created.
 */
int main(void)
{
	/* Demo key material only; these values are not secret. */
	static const struct tls12_crypto_info_aes_gcm_256 client_tx = {
		.info = {
			.version = TLS_1_3_VERSION,
			.cipher_type = TLS_CIPHER_AES_GCM_256,
		},
		.iv = {4,4,4,4,4,4,4,4},
		.key = {5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5},
		.salt = {6,6,6,6},
		.rec_seq = {0,0,0,0,0,0,0,0},
	};
	static const struct tls12_crypto_info_aes_gcm_256 client_rx = {
		.info = {
			.version = TLS_1_3_VERSION,
			.cipher_type = TLS_CIPHER_AES_GCM_256,
		},
		.iv = {1,1,1,1,1,1,1,1},
		.key = {2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2},
		.salt = {3,3,3,3},
		.rec_seq = {0,0,0,0,0,0,0,0},
	};
	struct sockaddr_in servaddr;
	char ctrl_msg[] = "abc";
	int sockfd;

	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (sockfd == -1) {
		fprintf(stderr, "ERROR: socket(2) failed with (%d) %s\n", errno, strerror(errno));
		return 1;
	}

	(void)memset(&servaddr, 0, sizeof(servaddr));
	servaddr.sin_family = AF_INET;
	servaddr.sin_addr.s_addr = inet_addr("127.0.0.1");
	servaddr.sin_port = htons(31388);
	if (connect(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) != 0) {
		fprintf(stderr, "ERROR: connect(2) failed with (%d) %s\n", errno, strerror(errno));
		goto fail;
	}

	/* The ULP must be attached before any TLS_TX/TLS_RX keys are installed. */
	if (setsockopt(sockfd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")) != 0) {
		fprintf(stderr, "ERROR: setsockopt(2) for {SOL_TCP, TCP_ULP, \"tls\"} failed with (%d) %s\n", errno, strerror(errno));
		goto fail;
	}
	if (setsockopt(sockfd, SOL_TLS, TLS_TX, &client_tx, sizeof(client_tx)) != 0) {
		fprintf(stderr, "ERROR: setsockopt(2) for {SOL_TLS, TLS_TX, &client_tx} failed with (%d) %s\n", errno, strerror(errno));
		goto fail;
	}
	if (setsockopt(sockfd, SOL_TLS, TLS_RX, &client_rx, sizeof(client_rx)) != 0) {
		fprintf(stderr, "ERROR: setsockopt(2) for {SOL_TLS, TLS_RX, &client_rx} failed with (%d) %s\n", errno, strerror(errno));
		goto fail;
	}

	if (ktls_send_text(sockfd, "hello") != 0)
		goto fail;
	sleep(1);
	if (ktls_send_text(sockfd, "world") != 0)
		goto fail;
	sleep(2);
	/* Send a record with TLS record type 'a' instead of application data. */
	if (klts_send_ctrl_message(sockfd, 'a', ctrl_msg, strlen(ctrl_msg)) < 0) {
		fprintf(stderr, "ERROR: sendmsg(2) with TLS_SET_RECORD_TYPE failed with (%d) %s\n", errno, strerror(errno));
		goto fail;
	}
	sleep(3);
	if (ktls_send_text(sockfd, "hello world") != 0)
		goto fail;
	sleep(5);

	close(sockfd);
	return 0;

fail:
	close(sockfd);
	return 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment