Skip to content

Instantly share code, notes, and snippets.

@Jackarain
Created November 25, 2023 08:06
Show Gist options
  • Save Jackarain/49043e156985a806aca76e3ab17e80cc to your computer and use it in GitHub Desktop.
Save Jackarain/49043e156985a806aca76e3ab17e80cc to your computer and use it in GitHub Desktop.
基于 C++20 的 Boost.Asio ssl_stream 实现 (a C++20 coroutine-based ssl_stream implementation for Boost.Asio)
//
// ssl_stream.hpp
// ~~~~~~~~~~~~~~
//
// Copyright (c) 2023 Jack (jack dot wgm at gmail dot com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
#ifndef INCLUDE__2023_11_24__SSL_STREAM_HPP
#define INCLUDE__2023_11_24__SSL_STREAM_HPP
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>
#include <boost/asio/associated_executor.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/dispatch.hpp>
#include <boost/asio/write.hpp>
#include <boost/asio/ssl/context.hpp>
#include <boost/asio/ssl/detail/engine.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/asio/ssl/detail/handshake_op.hpp>
#include <boost/asio/ssl/detail/read_op.hpp>
#include <boost/asio/ssl/detail/write_op.hpp>
#include <boost/asio/ssl/detail/shutdown_op.hpp>
#include "proxy/use_awaitable.hpp"
namespace secure {
namespace net = boost::asio;
/// TLS stream adaptor over `Stream`, driven by Boost.Asio's SSL engine
/// (`net::ssl::detail::engine`).
///
/// Offers the same surface as `net::ssl::stream<Stream>`: handshake,
/// shutdown, read_some / write_some and their async_* forms.
/// - Synchronous operations run the engine state machine with blocking I/O
///   on the next layer.
/// - Asynchronous operations run the same state machine inside a detached
///   C++20 coroutine spawned on this stream's executor.
///
/// NOTE(review): nothing here serialises access to the next layer between
/// concurrent async operations on the same object (unlike
/// `net::ssl::stream`). Any operation, including a write, may read from the
/// next layer and mutate the shared `input_` state when the engine asks for
/// input. Keep at most one async operation outstanding unless that is
/// addressed.
template <typename Stream>
class ssl_stream : public net::ssl::stream_base
{
public:
    using native_handle_type = SSL*;
    using next_layer_type = typename std::remove_reference<Stream>::type;
    using lowest_layer_type = typename next_layer_type::lowest_layer_type;
    using executor_type = typename lowest_layer_type::executor_type;

    ssl_stream(const ssl_stream&) = delete;
    ssl_stream& operator=(const ssl_stream&) = delete;

private:
    // Size in bytes of each staging buffer between the engine and the next
    // layer. 17 KiB is presumably one maximum-size TLS record plus overhead.
    static constexpr std::size_t max_tls_record_size = 17 * 1024;

    using engine = net::ssl::detail::engine;

    // ---------------------------------------------------------------------
    // Engine operations.
    //
    // Each factory returns a callable
    //     engine::want (engine&, error_code&, std::size_t& bytes_transferred)
    // that performs one step on the engine and reports what transport I/O
    // the engine needs next.
    //
    // The callables must be invocable through a const reference, because
    // sync_io takes `const Operation&`. They are therefore NOT `mutable`;
    // none of them modifies its captures.
    // ---------------------------------------------------------------------

    /// One handshake step. Never transfers application data.
    static auto handshake_op(stream_base::handshake_type type)
    {
        return [type](
            engine& eng,
            boost::system::error_code& ec,
            std::size_t& bytes_transferred)
        {
            bytes_transferred = 0;
            return eng.handshake(type, ec);
        };
    }

    /// One shutdown step. Never transfers application data.
    static auto shutdown_op()
    {
        return [](
            engine& eng,
            boost::system::error_code& ec,
            std::size_t& bytes_transferred)
        {
            bytes_transferred = 0;
            return eng.shutdown(ec);
        };
    }

    /// One write step.
    ///
    /// The buffer sequence is linearised into a stack buffer of
    /// `linearisation_storage_size` bytes, so the engine sees a single
    /// contiguous buffer.
    template <typename ConstBufferSequence>
    static auto write_op(const ConstBufferSequence& buffers)
    {
        return [buffers](
            engine& eng,
            boost::system::error_code& ec,
            std::size_t& bytes_transferred)
        {
            using adapter = net::detail::buffer_sequence_adapter<
                net::const_buffer, ConstBufferSequence>;

            unsigned char storage[adapter::linearisation_storage_size];
            net::const_buffer buffer =
                adapter::linearise(buffers, net::buffer(storage));
            return eng.write(buffer, ec, bytes_transferred);
        };
    }

    /// One read step.
    ///
    /// Only the single buffer selected by
    /// `buffer_sequence_adapter::first` is filled per call.
    template <typename MutableBufferSequence>
    static auto read_op(const MutableBufferSequence& buffers)
    {
        return [buffers](
            engine& eng,
            boost::system::error_code& ec,
            std::size_t& bytes_transferred)
        {
            net::mutable_buffer buffer =
                net::detail::buffer_sequence_adapter<net::mutable_buffer,
                    MutableBufferSequence>::first(buffers);
            return eng.read(buffer, ec, bytes_transferred);
        };
    }

    /// Handshake step that feeds `buffers` to the engine as input before
    /// anything is read from the next layer.
    ///
    /// `bytes_transferred` is used as a cursor: it counts how many bytes of
    /// `buffers` the engine has accepted so far. Only once all of them have
    /// been consumed is `want_input_and_retry` returned to the I/O loop.
    template <typename ConstBufferSequence>
    static auto buffered_handshake_op(stream_base::handshake_type type,
        const ConstBufferSequence& buffers)
    {
        return [type, buffers,
            total_buffer_size(net::buffer_size(buffers))](
            engine& eng,
            boost::system::error_code& ec,
            std::size_t& bytes_transferred)
        {
            auto iter = net::buffer_sequence_begin(buffers);
            auto end = net::buffer_sequence_end(buffers);

            // Total size of the buffers before `iter` that have been fully
            // handed to the engine.
            std::size_t accumulated_size = 0;

            for (;;)
            {
                engine::want want = eng.handshake(type, ec);

                // Hand control back unless the engine wants input that we
                // can still supply from `buffers`.
                if (want != engine::want_input_and_retry ||
                    bytes_transferred == total_buffer_size)
                    return want;

                // Find the first buffer that is not yet fully consumed and
                // give the engine as much of its remainder as it takes.
                while (iter != end)
                {
                    net::const_buffer buffer(*iter);

                    if (bytes_transferred >= accumulated_size + buffer.size())
                    {
                        accumulated_size += buffer.size();
                        ++iter;
                        continue;
                    }

                    if (bytes_transferred > accumulated_size)
                        buffer = buffer + (bytes_transferred - accumulated_size);

                    // put_input returns the part it did not accept.
                    bytes_transferred += buffer.size();
                    buffer = eng.put_input(buffer);
                    bytes_transferred -= buffer.size();
                    break;
                }
            }
        };
    }

    // ---------------------------------------------------------------------
    // I/O drivers.
    // ---------------------------------------------------------------------

    /// Drives `op` to completion using blocking I/O on the next layer.
    ///
    /// - When the engine wants input, the staging buffer is refilled from
    ///   the next layer, but only once all previously received bytes have
    ///   been passed to the engine.
    /// - When the engine produces output, that output is written out in
    ///   full.
    ///
    /// Note that `continue` inside the switch jumps to the `while (!ec)`
    /// test, so any error ends the loop.
    ///
    /// @param op  A callable produced by one of the *_op factories above.
    /// @param ec  Receives the result, passed through
    ///            `engine::map_error_code`.
    /// @return    Bytes reported by the engine step that completed the
    ///            operation, or 0 if the loop ended because of an error.
    template <typename Operation>
    std::size_t sync_io(const Operation& op, boost::system::error_code& ec)
    {
        boost::system::error_code io_ec;
        std::size_t bytes_transferred = 0;
        net::mutable_buffer write_buf;

        do switch (op(engine_, ec, bytes_transferred))
        {
        case engine::want_input_and_retry:
            if (input_.size() == 0)
            {
                input_ = net::buffer(input_buffer_,
                    next_layer_.read_some(input_buffer_, io_ec));
                // The transport error is reported only if the engine step
                // itself succeeded.
                if (!ec)
                    ec = io_ec;
            }
            input_ = engine_.put_input(input_);
            continue;

        case engine::want_output_and_retry:
            write_buf = engine_.get_output(output_buffer_);
            net::write(next_layer_, write_buf, io_ec);
            if (!ec)
                ec = io_ec;
            continue;

        case engine::want_output:
            // Flush the engine's final output, then finish.
            write_buf = engine_.get_output(output_buffer_);
            net::write(next_layer_, write_buf, io_ec);
            if (!ec)
                ec = io_ec;
            engine_.map_error_code(ec);
            return bytes_transferred;

        default:
            engine_.map_error_code(ec);
            return bytes_transferred;
        } while (!ec);

        engine_.map_error_code(ec);
        return 0;
    }

    /// Delivers `(ec, n)` to `handler` through the handler's associated
    /// executor, falling back to this stream's executor.
    ///
    /// `dispatch` runs the handler inline when the current thread is already
    /// running in that executor; otherwise it submits the handler to it
    /// (e.g. a strand). `handler` is moved from.
    template <typename Handler>
    void invoke_handler(Handler& handler,
        boost::system::error_code ec, std::size_t n)
    {
        auto handler_ex = net::get_associated_executor(handler, get_executor());
        net::dispatch(handler_ex,
            [h = std::move(handler), ec, n]() mutable
            {
                h(ec, n);
            });
    }

    /// Asynchronous counterpart of sync_io.
    ///
    /// Runs the same state machine in a detached coroutine on this stream's
    /// executor. I/O errors are captured in `io_ec` through
    /// `net_awaitable[io_ec]` rather than thrown. The result is delivered
    /// through invoke_handler.
    ///
    /// The coroutine captures `this`, so the stream must outlive the
    /// operation. `handler` is moved from.
    template <typename Operation, typename Handler>
    void async_io(const Operation& op, Handler& handler)
    {
        net::co_spawn(get_executor(),
            [this, op = op, handler = std::move(handler)]() mutable
                -> net::awaitable<void>
            {
                boost::system::error_code ec;
                boost::system::error_code io_ec;
                std::size_t bytes_transferred = 0;
                net::mutable_buffer write_buf;

                do switch (op(engine_, ec, bytes_transferred))
                {
                case engine::want_input_and_retry:
                    if (input_.size() == 0)
                    {
                        input_ = net::buffer(input_buffer_,
                            co_await next_layer_.async_read_some(
                                input_buffer_, net_awaitable[io_ec]));
                        if (!ec)
                            ec = io_ec;
                    }
                    input_ = engine_.put_input(input_);
                    continue;

                case engine::want_output_and_retry:
                    write_buf = engine_.get_output(output_buffer_);
                    co_await net::async_write(
                        next_layer_, write_buf, net_awaitable[io_ec]);
                    if (!ec)
                        ec = io_ec;
                    continue;

                case engine::want_output:
                    write_buf = engine_.get_output(output_buffer_);
                    co_await net::async_write(
                        next_layer_, write_buf, net_awaitable[io_ec]);
                    if (!ec)
                        ec = io_ec;
                    engine_.map_error_code(ec);
                    invoke_handler(handler, ec, bytes_transferred);
                    co_return;

                default:
                    engine_.map_error_code(ec);
                    invoke_handler(handler, ec, bytes_transferred);
                    co_return;
                } while (!ec);

                engine_.map_error_code(ec);
                invoke_handler(handler, ec, 0);
            }, net::detached);
    }

public:
    /// Constructs the stream over a next layer built from `arg`, using
    /// `ctx` to create the SSL engine.
    ///
    /// NOTE: `arg` is always moved from (std::move, not std::forward), even
    /// when an lvalue is passed. `ctx` is only referenced, so it must
    /// outlive this stream.
    template <typename Arg>
    ssl_stream(Arg&& arg, net::ssl::context& ctx)
        : next_layer_(std::move(arg))
        , context_(&ctx)
        , engine_(ctx.native_handle())
        , input_buffer_space_(max_tls_record_size)
        , input_buffer_(net::buffer(input_buffer_space_))
        , output_buffer_space_(max_tls_record_size)
        , output_buffer_(net::buffer(output_buffer_space_))
    {
    }

    ~ssl_stream() = default;

    /// Move-constructs from `other`.
    ///
    /// The staging vectors are moved, so their storage (and any pending
    /// input in `input_`) transfers unchanged. The buffer views are rebuilt
    /// over the new owner. `other` is left with empty views, so it no
    /// longer aliases this object's storage.
    ssl_stream(ssl_stream&& other)
        : next_layer_(std::move(other.next_layer_))
        , context_(other.context_)
        , engine_(std::move(other.engine_))
        , input_buffer_space_(std::move(other.input_buffer_space_))
        , input_buffer_(net::buffer(input_buffer_space_))
        , output_buffer_space_(std::move(other.output_buffer_space_))
        , output_buffer_(net::buffer(output_buffer_space_))
        , input_(std::exchange(other.input_, net::const_buffer{}))
    {
        other.input_buffer_ = net::mutable_buffer{};
        other.output_buffer_ = net::mutable_buffer{};
    }

    /// Move-assigns from `other`, with the same buffer handling as the move
    /// constructor.
    ///
    /// Only the context pointer is copied; the referenced `ssl::context`
    /// itself is never assigned.
    ssl_stream& operator=(ssl_stream&& other)
    {
        if (this != &other)
        {
            next_layer_ = std::move(other.next_layer_);
            context_ = other.context_;
            engine_ = std::move(other.engine_);
            input_buffer_space_ = std::move(other.input_buffer_space_);
            output_buffer_space_ = std::move(other.output_buffer_space_);
            input_buffer_ = net::buffer(input_buffer_space_);
            output_buffer_ = net::buffer(output_buffer_space_);
            input_ = std::exchange(other.input_, net::const_buffer{});
            other.input_buffer_ = net::mutable_buffer{};
            other.output_buffer_ = net::mutable_buffer{};
        }
        return *this;
    }

    /// Executor of the lowest layer. Async operations run their coroutine
    /// on it.
    executor_type get_executor() noexcept
    {
        return next_layer_.lowest_layer().get_executor();
    }

    /// Underlying OpenSSL `SSL*` owned by the engine.
    native_handle_type native_handle()
    {
        return engine_.native_handle();
    }

    const next_layer_type& next_layer() const
    {
        return next_layer_;
    }

    next_layer_type& next_layer()
    {
        return next_layer_;
    }

    lowest_layer_type& lowest_layer()
    {
        return next_layer_.lowest_layer();
    }

    const lowest_layer_type& lowest_layer() const
    {
        return next_layer_.lowest_layer();
    }

    /// Sets the peer verification mode. Throws boost::system::system_error
    /// on failure.
    void set_verify_mode(net::ssl::verify_mode v)
    {
        boost::system::error_code ec;
        set_verify_mode(v, ec);
        net::detail::throw_error(ec, "set_verify_mode");
    }

    void set_verify_mode(net::ssl::verify_mode v, boost::system::error_code& ec)
    {
        engine_.set_verify_mode(v, ec);
    }

    /// Sets the maximum certificate chain verification depth. Throws
    /// boost::system::system_error on failure.
    void set_verify_depth(int depth)
    {
        boost::system::error_code ec;
        set_verify_depth(depth, ec);
        net::detail::throw_error(ec, "set_verify_depth");
    }

    void set_verify_depth(int depth, boost::system::error_code& ec)
    {
        engine_.set_verify_depth(depth, ec);
    }

    /// Installs a certificate verification callback. Throws
    /// boost::system::system_error on failure.
    template <typename VerifyCallback>
    void set_verify_callback(VerifyCallback callback)
    {
        boost::system::error_code ec;
        this->set_verify_callback(callback, ec);
        net::detail::throw_error(ec, "set_verify_callback");
    }

    /// The heap-allocated wrapper is handed to the engine; presumably the
    /// engine takes ownership of it (as with `net::ssl::stream`).
    template <typename VerifyCallback>
    void set_verify_callback(VerifyCallback callback, boost::system::error_code& ec)
    {
        engine_.set_verify_callback(
            new net::ssl::detail::verify_callback<VerifyCallback>(callback), ec);
    }

    /// Performs a blocking TLS handshake. Throws boost::system::system_error
    /// on failure.
    void handshake(handshake_type type)
    {
        boost::system::error_code ec;
        handshake(type, ec);
        net::detail::throw_error(ec, "handshake");
    }

    void handshake(handshake_type type, boost::system::error_code& ec)
    {
        sync_io(handshake_op(type), ec);
    }

    /// Blocking handshake that first feeds `buffers` to the engine as input.
    template <typename ConstBufferSequence>
    void handshake(handshake_type type, const ConstBufferSequence& buffers)
    {
        boost::system::error_code ec;
        handshake(type, buffers, ec);
        net::detail::throw_error(ec, "handshake");
    }

    template <typename ConstBufferSequence>
    void handshake(handshake_type type,
        const ConstBufferSequence& buffers, boost::system::error_code& ec)
    {
        sync_io(buffered_handshake_op<ConstBufferSequence>(type, buffers), ec);
    }

    /// Asynchronous TLS handshake.
    ///
    /// Completion signature is `void(error_code, std::size_t)`; the size is
    /// always 0 because handshake_op reports no bytes. The token constraint
    /// matches that signature.
    template <
        BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code,
            std::size_t)) HandshakeToken
            = net::default_completion_token_t<executor_type>>
    auto async_handshake(handshake_type type,
        HandshakeToken&& token = net::default_completion_token_t<executor_type>())
    {
        return net::async_initiate<HandshakeToken,
            void (boost::system::error_code, std::size_t)>(
            [this](auto handler, handshake_type t) mutable
            {
                async_io(handshake_op(t), handler);
            }, token, type);
    }

    /// Asynchronous handshake that first feeds `buffers` to the engine.
    ///
    /// Completion signature is `void(error_code, std::size_t)`; the size is
    /// the number of bytes of `buffers` consumed.
    template <typename ConstBufferSequence,
        BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code,
            std::size_t)) BufferedHandshakeToken
            = net::default_completion_token_t<executor_type>>
    auto async_handshake(handshake_type type, const ConstBufferSequence& buffers,
        BufferedHandshakeToken&& token
            = net::default_completion_token_t<executor_type>())
    {
        return net::async_initiate<BufferedHandshakeToken,
            void (boost::system::error_code, std::size_t)>(
            [this](auto handler, handshake_type t, auto bufs) mutable
            {
                async_io(buffered_handshake_op<decltype(bufs)>(t, bufs), handler);
            }, token, type, buffers);
    }

    /// Performs a blocking TLS shutdown. Throws boost::system::system_error
    /// on failure.
    void shutdown()
    {
        boost::system::error_code ec;
        shutdown(ec);
        net::detail::throw_error(ec, "shutdown");
    }

    void shutdown(boost::system::error_code& ec)
    {
        sync_io(shutdown_op(), ec);
    }

    /// Asynchronous TLS shutdown.
    ///
    /// Completion signature is `void(error_code, std::size_t)`; the size is
    /// always 0. The token constraint matches that signature.
    template <
        BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code,
            std::size_t)) ShutdownToken
            = net::default_completion_token_t<executor_type>>
    auto async_shutdown(
        ShutdownToken&& token = net::default_completion_token_t<executor_type>())
    {
        return net::async_initiate<ShutdownToken,
            void (boost::system::error_code, std::size_t)>(
            [this](auto handler) mutable
            {
                async_io(shutdown_op(), handler);
            }, token);
    }

    /// Writes some application data. Throws boost::system::system_error on
    /// failure.
    template <typename ConstBufferSequence>
    std::size_t write_some(const ConstBufferSequence& buffers)
    {
        boost::system::error_code ec;
        std::size_t n = write_some(buffers, ec);
        net::detail::throw_error(ec, "write_some");
        return n;
    }

    template <typename ConstBufferSequence>
    std::size_t write_some(const ConstBufferSequence& buffers,
        boost::system::error_code& ec)
    {
        return sync_io(write_op<ConstBufferSequence>(buffers), ec);
    }

    /// Asynchronously writes some application data.
    ///
    /// Completion signature is `void(error_code, std::size_t)`.
    template <typename ConstBufferSequence,
        BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code,
            std::size_t)) WriteToken = net::default_completion_token_t<executor_type>>
    auto async_write_some(const ConstBufferSequence& buffers,
        WriteToken&& token = net::default_completion_token_t<executor_type>())
    {
        return net::async_initiate<WriteToken,
            void (boost::system::error_code, std::size_t)>(
            [this](auto handler, auto bufs) mutable
            {
                async_io(write_op<decltype(bufs)>(bufs), handler);
            }, token, buffers);
    }

    /// Reads some application data. Throws boost::system::system_error on
    /// failure.
    template <typename MutableBufferSequence>
    std::size_t read_some(const MutableBufferSequence& buffers)
    {
        boost::system::error_code ec;
        std::size_t n = read_some(buffers, ec);
        net::detail::throw_error(ec, "read_some");
        return n;
    }

    template <typename MutableBufferSequence>
    std::size_t read_some(const MutableBufferSequence& buffers,
        boost::system::error_code& ec)
    {
        return sync_io(read_op<MutableBufferSequence>(buffers), ec);
    }

    /// Asynchronously reads some application data.
    ///
    /// Completion signature is `void(error_code, std::size_t)`.
    template <typename MutableBufferSequence,
        BOOST_ASIO_COMPLETION_TOKEN_FOR(void (boost::system::error_code,
            std::size_t)) ReadToken = net::default_completion_token_t<executor_type>>
    auto async_read_some(const MutableBufferSequence& buffers,
        ReadToken&& token = net::default_completion_token_t<executor_type>())
    {
        return net::async_initiate<ReadToken,
            void (boost::system::error_code, std::size_t)>(
            [this](auto handler, auto bufs) mutable
            {
                async_io(read_op<decltype(bufs)>(bufs), handler);
            }, token, buffers);
    }

private:
    // Declaration order matches every constructor's initializer list. Each
    // *_buffer_ view is declared after the vector it points into.
    Stream next_layer_;

    // Non-owning. A pointer (not a reference) so the stream stays
    // move-assignable without assigning the context object itself.
    net::ssl::context* context_;

    engine engine_;

    // Staging area for bytes read from the next layer.
    std::vector<unsigned char> input_buffer_space_;
    net::mutable_buffer input_buffer_;

    // Staging area for bytes the engine wants written to the next layer.
    std::vector<unsigned char> output_buffer_space_;
    net::mutable_buffer output_buffer_;

    // Received bytes (inside input_buffer_) not yet accepted by the engine.
    net::const_buffer input_;
};
}
#endif // INCLUDE__2023_11_24__SSL_STREAM_HPP
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment