Skip to content

Instantly share code, notes, and snippets.

@sehe

sehe/.gitignore Secret

Forked from jammerxd/Connection.h
Last active September 28, 2021 13:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sehe/a32a59096279d5fef99c9824a6da0168 to your computer and use it in GitHub Desktop.
Save sehe/a32a59096279d5fef99c9824a6da0168 to your computer and use it in GitHub Desktop.
callgrind*
log
// Client-Test.cpp : This file contains the 'main' function. Program execution begins and ends there.
//
#include "prerequisites.h"
class MyClient : public Client<Message<MessageTypes>, Executor> {
  public:
    using Message = ::Message<MessageTypes>;

    /// The send timer is bound to the base-class strand (`_strand`), which is
    /// also the executor of the connection and therefore of OnMessage /
    /// OnDisconnect. Timer callbacks are thus serialized with those handlers;
    /// they all touch prng_, num_msgs_ and timer_.
    explicit MyClient(Executor ex) : MyClient::base_type(ex), timer_(_strand) {}

    void OnDisconnect(ConnPtr const&) override
    {
        std::cout << "Disconnected." << std::endl;
        timer_.cancel();
        isexiting_ = true;
    }

    void OnMessage(MsgPtr const& msg, ConnPtr const&) override
    {
        switch (msg->message_header.id) {
        case MessageTypes::ServerAccept: {
            auto id = msg->get<int>();
            SetId(id);
            std::cout << "[ SERVER ] SENT ID: " << id << std::endl;
            StartSendMessages();
            break;
        }
        case MessageTypes::SendText: {
            std::cout << "[ SERVER ] SENT MESSAGE " << msg->size()
                      << " bytes long." << std::endl;
            break;
        }
        case MessageTypes::MessageAll: {
            auto const now = Clock::now();
            // lap time since the previous broadcast
            auto const time_taken = now - std::exchange(start_, now);
            // A body without any fragment has no text: report length 0
            // instead of calling front() on an empty vector (UB).
            auto const fragments = msg->TextFragments();
            auto const length =
                fragments.empty() ? std::size_t{0} : fragments.front().length();
            // The very first broadcast has no previous lap to compare with.
            if (!isfirst_.exchange(false) && time_taken > 55ms) {
                std::cout << "MESSAGE WAS DELAYED (" << (time_taken / 1ms)
                          << "ms, length: " << length << ")" << std::endl;
            }
            break;
        }
        default: throw std::runtime_error("Message type not implemented");
        }
    }

    void OnConnect() override
    {
        std::cout << "[ DEBUG ] Thread Id: " << std::this_thread::get_id() << std::endl;
    }

    // Must match the base-class signature exactly to override it; the former
    // one-argument version merely hid the base overload and was never called.
    void OnMessageSent(MsgPtr const&, ConnPtr const&) override
    {
        std::cout << "Message sent" << std::endl;
    }

    ~MyClient()
    {
        isexiting_ = true;
        // Disconnect(); // happens in baseclass destructor
    }

  private:
    /// Starts the timed send loop once, after the server accepted us.
    void StartSendMessages()
    {
        if (num_msgs_ || !IsConnected())
            return; // already running
        std::cout << "[ SERVER ] StartSendMessages" << std::endl;
        /* generate secret number between 1 and 10: */
        num_msgs_ = Dist{1, 10}(prng_);
        post(_strand, [this] { TimedSendLoop(); });
    }

    /// Sends one randomly sized SendText message, then re-arms the timer for
    /// the next one while num_msgs_ > 0.
    void TimedSendLoop()
    {
        std::cout << "TimedSendLoop #" << num_msgs_ + 1 << std::endl;
        {
            auto msg_size = Dist{409'600, 921'600}(prng_);
            Message msg;
            msg.message_header.id = MessageTypes::SendText;
            auto payload = msg.Alloc(msg_size);
            // 26 letters: keeps the fill character within 'a'..'z'
            std::fill(begin(payload), end(payload), 'a' + num_msgs_ % 26);
            Send(std::move(msg));
        }
        auto delay = 1ms * Dist{1'000, 10'000}(prng_);
        std::cout << "Sleeping for " << delay / 1.0s << std::endl;
        timer_.expires_from_now(delay);
        timer_.async_wait([this](error_code ec) {
            // Test ec first: destroying the timer cancels the wait, and the
            // aborted completion may run after *this is gone, so no member
            // may be touched in that case.
            if (ec == boost::asio::error::operation_aborted || isexiting_) {
                std::cout << "TimedSendLoop: " << ec.message() << std::endl;
                return;
            }
            // Never let the counter underflow: 0 means "idle" for
            // StartSendMessages().
            if (num_msgs_ > 0) {
                --num_msgs_;
                TimedSendLoop();
            }
        });
    }

    Clock::time_point start_;
    std::atomic_bool isfirst_{true};
    std::atomic_bool isexiting_{false};

    // SendMessages state
    using Dist = std::uniform_int_distribution<>;
    std::mt19937 prng_{std::random_device{}()};
    Timer timer_;
    int num_msgs_ = 0; // messages still to send after the current one; 0 when idle
};
int main()
{
boost::asio::thread_pool io;
std::string host = "localhost";
uint16_t port = 40'000;
{
std::deque<std::unique_ptr<MyClient> > clients;
std::generate_n( //
back_inserter(clients), 200, [&] {
auto c = std::make_unique<MyClient>(io.get_executor());
std::cout << "Connect" << std::endl;
c->Connect(host, port);
return c;
});
std::this_thread::sleep_for(10s);
} // destructors call Disconnect()
std::cout << "DONE" << std::endl;
io.join();
}
#pragma once
#include "prerequisites.h"
template <typename Message, typename Executor>
class Client
{
protected:
using base_type = Client<Message, Executor>;
using Strand = boost::asio::strand<Executor>;
using conn_t = Connection<Message, Strand>;
using MsgPtr = typename conn_t::MsgPtr;
using ConnPtr = std::shared_ptr<conn_t>;
protected:
virtual void OnConnect() {}
virtual void OnDisconnect(ConnPtr const&) {}
virtual void OnMessage(MsgPtr const&, ConnPtr const&) {}
virtual void OnMessageSent(MsgPtr const&, ConnPtr const&) {}
public:
bool Connect(const std::string& host, const uint16_t port)
{
tcp::resolver::results_type endpoints;
try {
// Resolve hostname/ip-address into tangiable physical address
tcp::resolver resolver(_strand);
// TODO FIXME why not async? Could this be kept out of the Client so
// that there is no blocking resolve at the start of each
// connection?
endpoints = resolver.resolve(host, std::to_string(port));
} catch (std::exception& e) {
std::cerr << "Client Exception: " << e.what() << std::endl;
return false;
}
post(_strand, [this, endpoints] {
// Create connection
using boost::placeholders::_1;
using boost::placeholders::_2;
_connection = conn_t::create( //
_strand, 0, //
boost::bind(&Client::DoOnMessage, this, _1, _2),
boost::bind(&Client::DoOnMessageSent, this, _1, _2),
boost::bind(&Client::DoOnDisconnected, this, _1));
async_connect( //
_connection->socket(), endpoints,
[this](std::error_code ec, tcp::endpoint) {
if (!ec) {
_connection->accepted();
OnConnect();
}
});
});
return true;
}
void Disconnect()
{
post(_strand, [c = _connection] { //
if (c) {
c->Disconnect(true, true, true);
}
});
}
virtual ~Client() // NOTE the virtual again
{
Disconnect();
}
Client(Executor executor) : _strand(make_strand(executor)) {}
// not safe outside strand
bool IsConnected() { return _connection && _connection->socket().is_open(); }
void Send(Message msg)
{
post(_strand, [c = _connection, msg = std::move(msg)]() mutable { //
if (c) {
c->Send(std::make_shared<Message>(std::move(msg)));
}
});
}
bool IsSending() { return _connection && _connection->IsSending(); }
void DoOnDisconnected(ConnPtr const& client) { OnDisconnect(client); }
void DoOnMessage(MsgPtr const& m, ConnPtr const& c) { OnMessage(m, c); }
void DoOnMessageSent(MsgPtr const& m, ConnPtr const& c) { OnMessageSent(m, c); }
void SetId(int id) { _connection->SetId(id); }
protected:
Strand _strand;
ConnPtr _connection;
};
# cmake_minimum_required() must be the first command: it establishes the
# policy settings that every later command (including project()) runs under.
cmake_minimum_required(VERSION 3.5)
cmake_policy(SET CMP0048 NEW)
# The compilers must be selected before project() enables C and C++;
# assigning CMAKE_<LANG>_COMPILER after project() does not switch the
# compiler that project() already detected.
SET(CMAKE_C_COMPILER gcc-10)
SET(CMAKE_CXX_COMPILER g++-10)
project(stackoverflow)
set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE)
SET(BOOST_DIR /home/sehe/custom/boost_1_76_0)
set(CMAKE_INSTALL_RPATH "${BOOST_DIR}/stage/lib")
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
set(CMAKE_SKIP_RPATH FALSE)
SET(CMAKE_EXE_LINKER_FLAGS "-Wl,--disable-new-dtags")
LINK_LIBRARIES(boost_system)
LINK_LIBRARIES(boost_thread)
LINK_DIRECTORIES("${BOOST_DIR}/stage/lib")
INCLUDE_DIRECTORIES(SYSTEM ${BOOST_DIR})
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Wno-unknown-pragmas ") # -Wconversion
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++2a ")
#SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3 -pthread -march=native")
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -fno-omit-frame-pointer -pthread -march=native")
ADD_DEFINITIONS(-DBOOST_USE_ASAN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address,undefined")
#SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto")
ADD_EXECUTABLE(server Server-Test.cpp)
ADD_EXECUTABLE(client Client-Test.cpp)
#pragma once
#include "prerequisites.h"
template <typename Message, typename Executor>
class Connection
    : public std::enable_shared_from_this<Connection<Message, Executor>> //
{
    using std::enable_shared_from_this<Connection>::shared_from_this;
    using std::enable_shared_from_this<Connection>::weak_from_this;
    using socket_t = boost::asio::basic_stream_socket<tcp, Executor>;

  public:
    using ConnPtr = std::shared_ptr<Connection>;
    using MsgPtr  = std::shared_ptr<Message const>;
    using ClientMessageCallbackType     = std::function<void(MsgPtr const&, ConnPtr const&)>;
    using ClientMessageSentCallbackType = std::function<void(MsgPtr const&, ConnPtr const&)>;
    using ClientDisconnectCallbackType  = std::function<void(ConnPtr const&)>;
    enum class owner { server, client };

    /// Factory: every pending async operation keeps the connection alive via
    /// `self = shared_from_this()`, so it must be owned by a shared_ptr.
    static ConnPtr create(Executor executor, int id,
                          ClientMessageCallbackType     message_handle,
                          ClientMessageSentCallbackType messagesent_handle,
                          ClientDisconnectCallbackType  disconnect_handle)
    {
        // should be make_shared, but not possible due to private constructor
        return ConnPtr(new Connection(executor, id, std::move(message_handle),
                                      std::move(messagesent_handle),
                                      std::move(disconnect_handle)));
    }

    socket_t& socket() { return socket_; }
    int  GetId() { return connectionId; }
    void SetId(int id) { connectionId = id; }

    /// Starts servicing a connected socket: enables TCP_NODELAY and begins the
    /// header/body read loop. A failure to set the option disconnects instead
    /// of throwing out of the calling completion handler.
    void accepted()
    {
#ifdef VERBOSE_SERVER_DEBUG
        std::cout << "[ Client " << GetId() << " ] "
                  << "Connected." << std::endl;
#endif
        error_code ec;
        socket_.set_option(tcp::no_delay(true), ec);
        if (!Report("Set NoDelay", ec))
            return;
        ReadHeader();
    }

    /// Drops queued output and (once) cancels/shuts down/closes the socket,
    /// then notifies the owner through disconnect_handler. Always marks the
    /// connection invalid.
    void Disconnect(bool cancel, bool shutdown, bool close)
    {
        qMessagesOut.clear();
        if (socket_.is_open() && !alreadyDisconnected.exchange(true)) {
            error_code ec;
            if (cancel) {
                socket_.cancel(ec);
            }
            if (shutdown) {
                socket_.shutdown(socket_t::shutdown_both, ec);
            }
            // Close even when cancel/shutdown reported an error (e.g. the peer
            // already reset): alreadyDisconnected is set now, so no later call
            // would close the descriptor before the Connection is destroyed.
            if (close) {
                socket_.close(ec);
            }
#ifdef VERBOSE_SERVER_DEBUG
            std::cout << "[ Client " << GetId() << " ] "
                      << "Disconnected." << std::endl;
#endif
            // weak_from_this().lock() is empty while the destructor runs
            // (where shared_from_this() would throw bad_weak_ptr and, escaping
            // a destructor, terminate); the owner is then simply not notified.
            if (auto self = weak_from_this().lock(); self && disconnect_handler) {
                disconnect_handler(self);
            }
        }
        invalidState = true;
    }

    bool IsInvalid() const
    {
        return !socket_.is_open() || invalidState;
    }

    /// Queues `msg` for writing. Must be called on the connection's executor.
    void Send(MsgPtr msg)
    {
        if (IsInvalid())
            return;
        // The message being written stays at the front of the queue until its
        // write completes, so a non-empty queue means a write is in flight.
        // Only start writing when idle: two concurrent async_write operations
        // on one socket would interleave their bytes.
        bool const idle = qMessagesOut.empty();
        qMessagesOut.push_back(std::move(msg));
        if (idle) {
            WriteMessage();
        }
    }

    /// Number of queued messages, including one currently being written.
    /// Not safe outside the connection's executor.
    size_t GetSendBacklog()
    {
        if (!IsInvalid()) {
            return qMessagesOut.size();
        }
        return 0;
    }

    /// True while a write is in flight or queued. Not safe outside the
    /// connection's executor.
    bool IsSending() const { return !qMessagesOut.empty(); }

    ~Connection()
    {
        Disconnect(true, true, true);
    }

  private:
    Connection(Executor executor, int id,
               ClientMessageCallbackType     message_handle,
               ClientMessageSentCallbackType messagesent_handle,
               ClientDisconnectCallbackType  disconnect_handle)
        : disconnect_handler(std::move(disconnect_handle))
        , message_handler(std::move(message_handle))
        , messagesent_handler(std::move(messagesent_handle))
        , socket_(executor)
        , connectionId(id)
    { }

    ClientDisconnectCallbackType  disconnect_handler;
    ClientMessageCallbackType     message_handler;
    ClientMessageSentCallbackType messagesent_handler;

    /// Returns true on success. On error marks the connection invalid and
    /// disconnects. With VERBOSE_SERVER_DEBUG, also logs caption and extras.
    bool Report([[maybe_unused]] std::string_view caption, error_code ec,
                [[maybe_unused]] auto&&... what)
    {
        if (ec) {
#ifdef VERBOSE_SERVER_DEBUG
            std::cout << "[ Client " << GetId() << " ] " << caption
                      << " Failed " << ec.value() << " - " << ec.message()
                      << std::endl;
#endif
            invalidState = true;
            Disconnect(true, true, true);
            return false;
        }
#ifdef VERBOSE_SERVER_DEBUG
        {
            [[maybe_unused]] auto print_arg = [](auto&& v) {
                std::cout << " " << v;
                return std::cout.good();
            };
            std::cout << "[ Client " << GetId() << " ] " << caption << " Success";
            std::cout << " " << ec.value() << " - " << ec.message();
            if ((true && ... && print_arg(what)))
                std::cout << std::endl;
        }
#endif
        return true;
    }

    /// Reads the fixed-size raw header, then the body it announces.
    void ReadHeader()
    {
        async_read( //
            socket_,
            boost::asio::buffer(&tempInMsg.message_header,
                                sizeof(tempInMsg.message_header)),
            [this, self = shared_from_this()](error_code ec, std::size_t) {
                if (Report("Read Header", ec)) {
                    // NOTE(review): size comes from the peer and is not
                    // bounded; consider a maximum message size.
                    tempInMsg.body.resize(tempInMsg.message_header.size);
                    ReadBody();
                }
            });
    }

    void ReadBody()
    {
        async_read( //
            socket_, boost::asio::buffer(tempInMsg.body),
            [this, self = shared_from_this()](error_code ec, std::size_t) {
                // Log the size, not body.data(): the body is not
                // NUL-terminated and would be streamed as a C string.
                if (Report("Read Body", ec, tempInMsg.body.size(), "bytes")) {
                    CommitIncoming();
                    // Go Back to waiting for header
                    ReadHeader();
                }
            });
    }

    /// Writes the message at the front of the queue. The message is popped
    /// only when its write completes, which keeps at most one async_write in
    /// flight (see Send()).
    void WriteMessage()
    {
        if (qMessagesOut.empty())
            return;

        MsgPtr message = qMessagesOut.front();
        std::vector bufs = {
            boost::asio::buffer(&message->message_header,
                                sizeof(message->message_header)),
            boost::asio::buffer(message->body),
        };
        async_write( //
            socket_, bufs,
            // `message` keeps the buffers alive even if Disconnect() clears
            // the queue meanwhile.
            [this, self = shared_from_this(),
             message](error_code ec, std::size_t length) {
                if (!Report("WriteMessage", ec, "Wrote", length, "bytes"))
                    return;
                // Disconnect() may have emptied the queue during the write.
                if (!qMessagesOut.empty()) {
                    qMessagesOut.pop_front();
                }
                // Start the next write before notifying, so a Send() from
                // inside the callback sees the correct in-flight state.
                WriteMessage();
                if (messagesent_handler) {
                    messagesent_handler(message, self);
                }
            });
    }

    /// Hands the completed incoming message to the owner and resets the
    /// receive buffer.
    void CommitIncoming()
    {
        if (message_handler) {
            message_handler(
                std::make_shared<Message>(std::move(tempInMsg)),
                shared_from_this());
        }
        tempInMsg.body.clear();
        tempInMsg.message_header.size = 0;
    }

    socket_t socket_;
    int connectionId;
    Message tempInMsg;
    std::deque<MsgPtr> qMessagesOut; // front() is the message being written
    std::atomic_bool invalidState{false};
    std::atomic_bool alreadyDisconnected{false};
};
#pragma once
#include "prerequisites.h"
#include <span>
template <typename MsgId> struct Message {
[[nodiscard]] size_t size() const { return body.size(); }
MessageHeader<MsgId> message_header{};
std::vector<unsigned char> body;
using View = std::span<unsigned char const>;
using Text = std::string_view;
std::span<unsigned char> Alloc(size_t n) {
auto const offset = body.size();
body.resize(body.size() + n + sizeof(n));
std::span target = body;
target = target.subspan(offset);
memcpy(target.data(), &n, sizeof(n));
message_header.size = static_cast<uint32_t>(size());
return target.subspan(sizeof(n));
}
auto ByteFragments() const {
std::vector<View> fragments;
View remain = body;
for (size_t n = 0; remain.size() > sizeof(n);) {
memcpy(&n, remain.data(), sizeof(n));
remain = remain.subspan(sizeof(n));
if (remain.size() < n)
throw std::runtime_error("Invalid text body");
fragments.emplace_back(remain.subspan(0u, n));
remain = remain.subspan(n);
}
return fragments;
}
auto TextFragments() const {
// Can be optimized, but would duplicate code
auto raw = ByteFragments();
std::vector<std::string_view> fragments(raw.size());
for (size_t i = 0; i < raw.size(); ++i)
fragments[i] = convert(raw[i]);
return fragments;
}
// Override for iostream - friendly description of message
friend std::ostream& operator<<(std::ostream& os, const Message<MsgId>& msg)
{
os << "ID:" << int(msg.message_header.id)
<< " Size:" << msg.message_header.size;
return os;
}
template <typename T> void put(T const& object)
{
static_assert( //
std::is_trivial<T>::value //
&& std::is_standard_layout<T>::value, //
"T is not trivial");
body.resize(sizeof(T));
std::memcpy(body.data(), &object, sizeof(T));
message_header.size = (uint32_t)size();
}
template <typename T> T get() const
{
static_assert( //
std::is_trivial<T>::value //
&& std::is_standard_layout<T>::value, //
"T is not trivial");
if (sizeof(T) != body.size()) // TODO SEHE
throw std::runtime_error("Unexpected message body");
T object;
std::memcpy(&object, body.data(), sizeof(T));
return object;
}
private:
static constexpr Text convert(View from) {
return Text(reinterpret_cast<char const*>(from.data()), from.size());
}
};
#pragma once
#include "prerequisites.h"
// Fixed-size header preceding every message body. Connection reads and
// writes it as raw bytes (sizeof(MessageHeader)), so its in-memory layout
// (padding, byte order) is the wire format; both peers must agree on it.
template <typename MsgId>
struct MessageHeader {
    MsgId    id{};    // message type tag
    uint32_t size{0}; // body length in bytes, maintained by Message::put/Alloc
};
#pragma once
//#define VERBOSE_SERVER_DEBUG
#include <chrono>
#include <deque>
#include <iomanip>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <sstream>
#include <thread>
#include <utility>
#include <vector>
#include <boost/asio.hpp>
#ifdef _WIN32
// Windows stuff.
#define _CRT_SECURE_NO_WARNINGS
#define NOMINMAX
#include <ShlObj.h>
#include <Shlwapi.h>
#endif
#include <boost/bind/bind.hpp>
#include <boost/signals2.hpp>
// Function pointer called CallbackType that takes a float
// and returns an int
// typedef int (*CallbackType)(float);
template <typename T, typename Executor> class Connection;
template <typename T> struct Message;
#include "MessageHeader.h"
#include "Message.h"
using boost::asio::ip::tcp;
using boost::system::error_code;
using namespace std::chrono_literals;
// Monotonic clock for timers and lap measurements. steady_clock never jumps;
// high_resolution_clock may alias the adjustable system_clock (it does in
// libstdc++), which would distort intervals whenever the wall clock changes.
using Clock = std::chrono::steady_clock;
// Executor type of the boost::asio::thread_pool both test programs run on;
// MyClient/MyServer take it and derive their strands from it.
using Executor = boost::asio::thread_pool::executor_type;
// Waitable timer driven by the Clock alias declared just above.
using Timer = boost::asio::basic_waitable_timer<Clock>;
// Message type tags carried in MessageHeader::id. The header travels over the
// wire as raw bytes, so every value is spelled out: adding or reordering an
// entry must never silently renumber the ones after it.
enum class MessageTypes : uint32_t {
    ServerAccept   = 0,
    ServerDeny     = 1,
    ServerPing     = 2,
    MessageAll     = 3,
    SendText       = 4,
    ServerMessage  = 5,
    ServerMessage1 = 6,
    ServerMessage2 = 7,
    ServerMessage3 = 8,
    ServerMessage4 = 9,
    ServerMessage5 = 10,
    ServerMessage6 = 11,
    ServerMessage7 = 12,
    ServerMessage8 = 13,
    ServerMessage9 = 14,
};
#include "Connection.h"
#include "Server.h"
#include "Client.h"
// Server-Test.cpp : This file contains the 'main' function. Program execution
// begins and ends there.
#include "prerequisites.h"
#include <utility>
class MyServer : public Server<Message<MessageTypes>, Executor> {
  public:
    using Message = ::Message<MessageTypes>;

    MyServer(Executor executor, tcp::endpoint ep)
        : MyServer::base_type(executor, std::move(ep))
    {
    }

    void OnClientDisconnect(ConnPtr const& remote) override
    {
        std::cout << "[ Client " << remote->GetId() << " ] Disconnected"
                  << std::endl;
    }

    /// Greets each new client with a ServerAccept message carrying its id and
    /// accepts the connection.
    bool OnClientConnect(ConnPtr const& remote) override
    {
        std::cout << "[ Client " << remote->GetId() << " ] Connected" << std::endl;
        auto msg = std::make_shared<Message>();
        msg->message_header.id = MessageTypes::ServerAccept;
        msg->put(remote->GetId());
        remote->Send(std::move(msg));
        return true;
    }

    /// Echoes SendText messages back to their sender. The body comes from the
    /// network: TextFragments() throws on a malformed length prefix and may
    /// return no fragment at all. Neither case may escape this completion
    /// handler into the thread pool, so both are reported and the message is
    /// dropped.
    void OnMessage(MsgPtr const& msg, ConnPtr const& remote) override
    {
        std::cout << "[ Client " << remote->GetId() << " ] ";
        if (msg->message_header.id != MessageTypes::SendText)
            return;

        std::size_t length = 0;
        try {
            auto const fragments = msg->TextFragments();
            if (fragments.empty()) {
                std::cout << "Received empty message" << std::endl;
                return;
            }
            length = fragments.front().length();
        } catch (std::exception const& e) {
            std::cout << "Received malformed message: " << e.what() << std::endl;
            return;
        }

        std::cout << "Received Message (length:" << length << ")" << std::endl;
        remote->Send(msg); // fire it back to the client
    }

    void OnMessageSent(MsgPtr const&,
                       [[maybe_unused]] ConnPtr const& remote) override
    {
        // std::cout << "[ Client " << remote->GetId() << " ] ";
        // std::cout << " Sent Message" << std::endl;
    }
};
// Process-wide state shared by main() and the timedBcast() timer handler.

// The server; allocated with `new` in main() and deleted at the end of main().
// timedBcast() does not broadcast while it is null.
MyServer* srv;
// Sequence number appended to each broadcast payload; timedBcast() resets it
// to 0 once it reaches 1000.
int messageCount = 0;
// Most recently printed broadcast duration, so timedBcast() does not print
// the same timing twice in a row.
Clock::duration previous_time = 0s;
// Set by main() on shutdown so timedBcast() stops re-arming the timer.
std::atomic_bool stop = false;
// Thread pool running the server (main passes its executor) and the timer.
boost::asio::thread_pool context;
//Clock::duration highest_time = 0s;
//int total_thread_count = 0;
//size_t max_thread_count = 1;
// Periodic broadcast timer. Constructed from `context`, which is defined
// earlier in this file and is therefore initialized first.
Timer timer(context, 1s);
// Timer handler: broadcasts one MessageAll message of random size to every
// client, prints how long that took, and re-arms itself so that broadcasts
// start roughly every 99ms (never waiting less than 50ms in between).
// Does nothing once the wait failed, `stop` is set, or there is no server.
void timedBcast(error_code e)
{
    Clock::time_point const started = Clock::now();

    if (e || stop || srv == nullptr)
        return;

    if (messageCount >= 1000)
        messageCount = 0;

    {
        auto msg = std::make_shared<Message<MessageTypes>>();
        msg->message_header.id = MessageTypes::MessageAll;
        auto payload = msg->Alloc(rand() % 102400 + 81920);

        // Payload is all 'a' with " <sequence-number>" at its very end.
        std::fill(payload.begin(), payload.end(), 'a');
        std::string const tag = " " + std::to_string(messageCount++);
        std::copy(tag.begin(), tag.end(), payload.end() - tag.length());

        srv->BroadcastMessage(std::move(msg));
    }

    auto const finished = Clock::now();
    auto const elapsed = finished - started;

    if (elapsed != previous_time && elapsed > 2us) {
        std::cout << "Broadcast took " << elapsed / 1.0us << "μs | "
                  << elapsed / 1us << "μs" << std::endl;
        previous_time = elapsed;
    }

    auto delay = 99ms - elapsed;
    if (delay <= 50ms)
        delay = 50ms;

    // Reschedule the timer
    timer.expires_from_now(delay);
    timer.async_wait(timedBcast);
    std::cout << "BROADCAST" << std::endl;
}
int main()
{
    messageCount = 1;
    srand(time(nullptr));

    tcp::endpoint const listen_on{{}, 40000};
    srv = new MyServer(context.get_executor(), listen_on);
    std::cout << "Hello World!\n";

    // First broadcast one second from now; timedBcast re-arms itself.
    timer.expires_from_now(1s);
    timer.async_wait(timedBcast);

    std::string line;
    std::getline(std::cin, line); // first line of input: shut down

    srv->interrupt();
    // t1.interrupt();
    stop = true;
    timer.cancel();
    context.stop();
    context.join();

    std::getline(std::cin, line); // second line of input: release the server
    delete srv;
    return 0;
}
#pragma once
#include "prerequisites.h"
#include <future>
template <typename Message, typename Executor> class Server {
protected:
using base_type = Server<Message, Executor>;
using Strand = boost::asio::strand<Executor>;
using acceptor_t = boost::asio::basic_socket_acceptor<tcp, Strand>;
using conn_t = Connection<Message, Strand>;
using MsgPtr = typename conn_t::MsgPtr;
using ConnPtr = std::shared_ptr<conn_t>;
using WeakConnPtr = std::weak_ptr<conn_t>;
public:
Server(Executor executor, tcp::endpoint endpoint)
: executor_(executor)
{
acceptor_.open(endpoint.protocol());
acceptor_.set_option(tcp::acceptor::reuse_address(true));
acceptor_.set_option(tcp::acceptor::do_not_route(true));
acceptor_.set_option(tcp::acceptor::keep_alive(false));
acceptor_.set_option(tcp::acceptor::enable_connection_aborted(false));
acceptor_.set_option(tcp::acceptor::linger(false, 3));
acceptor_.bind(endpoint);
acceptor_.listen();
start_accept();
}
void interrupt()
{
shutdownBegan = true;
post(strand_, [this] {
acceptor_.cancel();
acceptor_.close();
for (const auto& [id, handle] : connections) {
if (auto conn = handle.lock())
conn->Disconnect(true, true, true);
}
connections.clear();
shutdownCompleted = true;
});
}
std::future<size_t> CalculateAverageBacklog()
{
std::packaged_task<size_t()> task([this]() -> size_t {
size_t total = 0;
size_t count = 0;
for (const auto& [id, handle] : connections) {
if (auto conn = handle.lock()) {
if (!conn->IsInvalid()) {
size_t backlog = conn->GetSendBacklog();
total += backlog;
count++;
}
}
}
if (count) {
auto average = 1.0 * total / count;
return average;
}
return 0;
});
post(strand_,task);
return task.get_future();
}
void BroadcastMessage(MsgPtr msg)
{
post(strand_, [this, msg = std::move(msg)] {
for (const auto& kvp : connections) {
post(executor_, [this, handle = kvp.second, msg] {
if (auto conn = handle.lock()) {
if (!conn->IsInvalid()) {
conn->Send(std::move(msg));
}
}
});
}
});
}
virtual ~Server() = default; // important for `delete` on derived classes!
protected:
// This server class should override thse functions to implement
// customised functionality
// Called when a client connects, you can veto the connection by returning
// false
virtual bool OnClientConnect(ConnPtr const& /*remote*/) { return false; }
// Called when a client appears to have disconnected
virtual void OnClientDisconnect(ConnPtr const& /*remote*/) { }
// Called when a message arrives
virtual void OnMessage(MsgPtr const& /*message*/, ConnPtr const&) { }
virtual void OnMessageSent(MsgPtr const& /*message*/, ConnPtr const&) { }
private:
void client_disconnected(ConnPtr const& connection)
{
OnClientDisconnect(connection);
}
void client_message(MsgPtr const& message, ConnPtr const& conn)
{
OnMessage(message, conn);
}
void message_sent(MsgPtr const& message, ConnPtr const& conn)
{
OnMessageSent(message, conn);
}
void start_accept()
{
using boost::placeholders::_1;
using boost::placeholders::_2;
if (!shutdownBegan && acceptor_.is_open()) {
auto new_connection = conn_t::create(
make_strand(executor_),
connectionIds++,
boost::bind(&Server::client_message, this, _1, _2),
boost::bind(&Server::message_sent, this, _1, _2),
boost::bind(&Server::client_disconnected, this, _1));
acceptor_.async_accept(
new_connection->socket(),
boost::bind(&Server::handle_accept, this, new_connection,
boost::asio::placeholders::error));
}
}
void handle_accept(ConnPtr new_connection, error_code error)
{
if (!error && !shutdownBegan) {
start_accept();
using boost::placeholders::_1;
if (OnClientConnect(new_connection)) {
new_connection->accepted();
addConnection(std::move(new_connection));
} else {
#ifdef VERBOSE_SERVER_DEBUG
std::cout << "[ Client " << new_connection->GetId()
<< " ] Connection Denied." << std::endl;
#endif
}
} else if (!shutdownBegan) {
start_accept();
#ifdef VERBOSE_SERVER_DEBUG
std::cout << "[ SERVER ] New connection error: " << error.message()
<< std::endl;
#endif
}
}
void addConnection(ConnPtr connection)
{
if (!shutdownBegan) {
post(strand_, [this, conn = std::move(connection)]() mutable {
connections.emplace(conn->GetId(), std::move(conn));
// garbage collect connections
// c++20, otherwise clumsy iterator loop
std::erase_if(connections,
[](auto& kvp) { return kvp.second.expired(); });
});
}
}
Executor executor_;
Strand strand_ = make_strand(executor_);
acceptor_t acceptor_{strand_};
std::map<int, WeakConnPtr> connections;
std::atomic_bool shutdownBegan{false};
std::atomic_bool shutdownCompleted{false};
int connectionIds{10'000};
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment