Skip to content

Instantly share code, notes, and snippets.

@CaryLorrk
Last active May 8, 2018 16:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CaryLorrk/12726181cd3ebd79a1c6ccdedeb8c29f to your computer and use it in GitHub Desktop.
Save CaryLorrk/12726181cd3ebd79a1c6ccdedeb8c29f to your computer and use it in GitHub Desktop.
Linux Socket Performance
#ifndef WOOPS_UTIL_COMM_COMM_H_
#define WOOPS_UTIL_COMM_COMM_H_
#include <vector>
#include <map>
#include <queue>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <sys/epoll.h>
#include "util/typedef.h"
namespace woops
{
using FileDesc = int;     // POSIX file descriptor (socket / epoll handle)
using MsgSize = uint32_t; // length prefix at the front of every wire message

/// Point-to-point messaging between the hosts enumerated by Lib.
///
/// Initialize() opens one TCP connection per remote host (sockfds_ is
/// indexed by Hostid) and then starts two long-running threads: a receiver
/// that reads length-prefixed messages from all peers via epoll and hands
/// each one to dispatch() on a detached thread, and a sender that drains
/// sender_queue_.
///
/// NOTE(review): receiver_thread / sender_thread loop forever and are never
/// joined or detached; destroying a Comm while they are joinable calls
/// std::terminate. Presumably Comm lives for the whole process - confirm.
class Comm
{
public:
    /// Connects to every peer and starts the receiver/sender threads.
    void Initialize();
    /// Delivers an update for table `id` to `server` (handled on a local
    /// thread when `server` is this host).
    void Update(Hostid server, Tableid id, Iteration iteration, const Bytes& bytes);
    /// Requests table `id` from `server`; the answer comes back as a PUSH.
    void Pull(Hostid server, Tableid id, Iteration iteration);
    /// Delivers parameters for table `id` to `client`.
    void Push(Hostid client, Tableid id, Iteration iteration, const Bytes& bytes);
    /// Sends table storage to `host` and blocks until it acknowledges.
    void SyncStorage(Hostid host, Tableid id, const Bytes& bytes);
    /// Fetches the placement from host 0 and blocks until it is applied.
    void SyncPlacement();
    /// Blocks until every host has entered Barrier(); host 0 coordinates.
    void Barrier();

    /// Wire command ids. The numeric values are part of the message format:
    /// append new commands at the end, never reorder existing ones.
    enum class Command: uint32_t {
        UPDATE,
        PULL,
        PUSH,
        SYNC_STORAGE,
        SYNC_STORAGE_RES,
        SYNC_PLACEMENT,
        SYNC_PLACEMENT_RES,
        BARRIER_NOTIFY,
    };

private:
    std::vector<FileDesc> sockfds_;            // indexed by Hostid
    std::map<std::string, Hostid> ip_to_host_; // textual peer IP -> host
    std::map<FileDesc, Hostid> sockfd_to_host_;

    // init
    void init_ip_to_host();
    void server_for_connections_func();
    void accept_for_connections(FileDesc server_sockfd);
    void client_for_connections();

    // receiver
    std::thread receiver_thread;
    void receiver_func();
    void set_host_events(FileDesc epollfd, std::vector<epoll_event>& host_events);
    size_t read_msg_size(FileDesc sockfd);
    void read_msg(FileDesc sockfd, size_t msg_size);
    void dispatch(Hostid host, Bytes msgbytes);

    // sender
    struct SendData {
        Hostid host;
        Bytes msgbytes;
        SendData() = default;
        SendData(Hostid in_host, Bytes&& in_msgbytes):
            host(in_host),
            msgbytes(std::move(in_msgbytes)) {}
    };
    std::thread sender_thread;
    void sender_func();
    void send_message(Hostid host, Bytes&& msgbytes);
    std::mutex sender_mu_;
    std::condition_variable sender_cv_;
    std::queue<SendData> sender_queue_;

    // sync_storage
    void sync_storage_res(Hostid host);
    std::mutex sync_storage_mu_;
    std::condition_variable sync_storage_cv_;

    // sync_placement
    void sync_placement_res(Hostid host, const Bytes& bytes);
    std::mutex sync_placement_mu_;
    std::condition_variable sync_placement_cv_;

    // barrier
    void barrier_notify(Hostid host);
    std::mutex barrier_mu_;
    std::condition_variable barrier_cv_;
    // Must start at 0: Barrier()'s wait predicates read it before any
    // BARRIER_NOTIFY has been counted. It was previously left uninitialized.
    int barrier_cnt_ = 0;

    // handler
    void update_handler(Hostid host, Bytes& bytes);
    void pull_handler(Hostid client, Bytes& bytes);
    void push_handler(Hostid server, Bytes& bytes);
    void sync_storage_handler(Hostid host, Bytes& bytes);
    void sync_storage_res_handler();
    void sync_placement_handler(Hostid host);
    void sync_placement_res_handler(Bytes& bytes);
    void barrier_notify_handler();
};
} /* woops */
#endif /* end of include guard: WOOPS_UTIL_COMM_COMM_H_ */
#include "comm.h"
#include <cstring>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <sys/epoll.h>
#include "lib.h"
#include "util/logging.h"
using namespace std::chrono_literals;
namespace woops
{
// File-local socket helpers (internal linkage via `static`); their
// definitions appear later in this file.
static void *get_in_addr(struct sockaddr *sa);
static FileDesc bind_for_connections();
static void listen_for_connections(FileDesc server_sockfd);
void Comm::Initialize() {
    // One socket slot per host, indexed by Hostid.
    sockfds_.resize(Lib::NumHosts());
    init_ip_to_host();

    // Accept the peers that dial us on a helper thread while this thread
    // dials the higher-numbered hosts; both halves must finish before the
    // I/O loops can start.
    std::thread acceptor(&Comm::server_for_connections_func, this);
    client_for_connections();
    acceptor.join();

    // Long-running I/O loops.
    receiver_thread = std::thread(&Comm::receiver_func, this);
    sender_thread = std::thread(&Comm::sender_func, this);

    // Presumably gives the freshly started threads time to begin running
    // before callers start sending - TODO confirm it is still needed.
    std::this_thread::sleep_for(10ms);
}
/// Resolves every host name from Lib and records each resolved textual
/// address in ip_to_host_, so accepted connections can be attributed to a
/// Hostid by the peer's address.
void Comm::init_ip_to_host() {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    // Both IPv4 and IPv6 results. (AF_UNSPEC belongs in ai_family; it was
    // previously written to ai_flags and only worked because both are 0.)
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    char ip[INET6_ADDRSTRLEN];
    for (Hostid host = 0; host < Lib::NumHosts(); ++host) {
        addrinfo *hostinfo = nullptr;
        int rv = getaddrinfo(Lib::Hostname(host).c_str(),
                             Lib::Port().c_str(), &hints, &hostinfo);
        if (rv != 0) {
            LOG(FATAL) << "getaddrinfo: " << gai_strerror(rv);
        }
        for (addrinfo *p = hostinfo; p != nullptr; p = p->ai_next) {
            // On failure `ip` would still hold the previous address; skip
            // rather than register that stale text again.
            if (inet_ntop(p->ai_family, get_in_addr(p->ai_addr), ip, sizeof(ip)) == nullptr) {
                continue;
            }
            ip_to_host_[ip] = host;
        }
        freeaddrinfo(hostinfo);
    }
}
void Comm::server_for_connections_func() {
FileDesc server_sockfd = bind_for_connections();
listen_for_connections(server_sockfd);
LOG(INFO) << "waiting for connections...";
accept_for_connections(server_sockfd);
close(server_sockfd);
}
/// Accepts one connection from each lower-numbered host (they dial us in
/// client_for_connections) and records the socket under the peer's Hostid.
/// Exits the process if accept() fails or a peer address is not in
/// ip_to_host_.
void Comm::accept_for_connections(FileDesc server_sockfd) {
    char ip[INET6_ADDRSTRLEN];
    for (int cnt = 0; cnt < Lib::ThisHost(); ++cnt) {
        sockaddr_storage their_addr;
        socklen_t sin_size = sizeof(their_addr);
        FileDesc client_sockfd = accept(server_sockfd,
                reinterpret_cast<struct sockaddr *>(&their_addr), &sin_size);
        if (client_sockfd == -1) {
            LOG(ERROR);
            perror("accept");
            exit(1);
        }
        inet_ntop(their_addr.ss_family,
                get_in_addr(reinterpret_cast<struct sockaddr *>(&their_addr)),
                ip, sizeof(ip));
        // Look the peer up with find(): operator[] would default-insert 0 and
        // silently attribute an unknown address to host 0, overwriting
        // sockfds_[0].
        auto it = ip_to_host_.find(ip);
        if (it == ip_to_host_.end()) {
            LOG(ERROR) << "connection from unknown address " << ip;
            close(client_sockfd);
            exit(1);
        }
        Hostid host = it->second;
        sockfds_[host] = client_sockfd;
        sockfd_to_host_[client_sockfd] = host;
        LOG(INFO) << "got connection from " << ip;
    }
}
/// Active side of connection setup: dials every higher-numbered host,
/// retrying once per second until a connection succeeds, and records the
/// socket under that host's id.
void Comm::client_for_connections() {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    // AF_UNSPEC belongs in ai_family (it was written to ai_flags before and
    // only worked because both constants are 0).
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    const int yes = 1;
    for (Hostid host = Lib::ThisHost() + 1; host < Lib::NumHosts(); ++host) {
        std::string hostname = Lib::Hostname(host);
        addrinfo *serverinfo = nullptr;
        int rv = getaddrinfo(hostname.c_str(), Lib::Port().c_str(), &hints, &serverinfo);
        if (rv != 0) {
            LOG(FATAL) << "getaddrinfo: " << gai_strerror(rv);
        }
        FileDesc sockfd = -1;
        // The peer may not be listening yet; keep retrying every address.
        while (true) {
            addrinfo *p;
            for (p = serverinfo; p != nullptr; p = p->ai_next) {
                if ((sockfd = socket(p->ai_family, p->ai_socktype,
                                     p->ai_protocol)) == -1) {
                    LOG(ERROR);
                    perror("socket");
                    continue;
                }
                if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &yes,
                               sizeof(int)) == -1) {
                    LOG(ERROR);
                    perror("setsockopt");
                    exit(1);
                }
                if (connect(sockfd, p->ai_addr, p->ai_addrlen) == -1) {
                    LOG(ERROR);
                    perror("connect");
                    close(sockfd);
                    continue;
                }
                break;
            }
            if (p == nullptr) {
                LOG(WARNING) << "failed to connect to " << hostname;
                std::this_thread::sleep_for(1s);
            } else {
                LOG(INFO) << "connect to " << hostname;
                break;
            }
        }
        sockfds_[host] = sockfd;
        sockfd_to_host_[sockfd] = host;
        freeaddrinfo(serverinfo);
    }
}
void *get_in_addr(struct sockaddr *sa)
{
    // IPv4 and IPv6 keep the raw address in differently named fields;
    // every family other than AF_INET is treated as IPv6.
    if (sa->sa_family != AF_INET) {
        return &reinterpret_cast<struct sockaddr_in6 *>(sa)->sin6_addr;
    }
    return &reinterpret_cast<struct sockaddr_in *>(sa)->sin_addr;
}
/// Creates a TCP socket bound to the wildcard address on Lib::Port().
/// Tries each getaddrinfo result in turn; LOG(FATAL) if none can be bound.
FileDesc bind_for_connections() {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    // AF_UNSPEC belongs in ai_family; it was written to ai_flags and then
    // immediately overwritten by AI_PASSIVE below.
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE; // wildcard address, suitable for bind()
    addrinfo *serverinfo = nullptr;
    int rv = getaddrinfo(NULL, Lib::Port().c_str(), &hints, &serverinfo);
    if (rv != 0) {
        LOG(FATAL) << "getaddrinfo: " << gai_strerror(rv);
    }
    FileDesc sockfd = -1;
    addrinfo *p;
    for (p = serverinfo; p != nullptr; p = p->ai_next) {
        if ((sockfd = socket(p->ai_family, p->ai_socktype,
                             p->ai_protocol)) == -1) {
            LOG(ERROR);
            perror("socket");
            continue;
        }
        // Allow quick restarts while old connections sit in TIME_WAIT.
        int yes = 1;
        if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(int)) == -1) {
            LOG(ERROR);
            perror("setsockopt");
            exit(1);
        }
        if (bind(sockfd, p->ai_addr, p->ai_addrlen) == -1) {
            close(sockfd);
            LOG(ERROR);
            perror("bind");
            continue;
        }
        break;
    }
    if (p == nullptr) {
        LOG(FATAL) << "failed to bind";
    }
    freeaddrinfo(serverinfo); // all done with this structure
    return sockfd;
}
void listen_for_connections(FileDesc server_sockfd) {
    // Backlog of NumHosts(): enough for every peer that may dial in.
    const int rc = listen(server_sockfd, Lib::NumHosts());
    if (rc != -1) {
        return;
    }
    LOG(ERROR);
    perror("listen");
    exit(1);
}
} /* woops */
#ifndef WOOPS_UTIL_COMM_SERIALIZE_H_
#define WOOPS_UTIL_COMM_SERIALIZE_H_
#include "comm.h"
#include "util/typedef.h"
#include "util/logging.h"
namespace woops
{
// Byte strings are appended verbatim, with no length prefix of their own.
// The receiving deserialize() consumes everything that remains, so a Bytes
// value is only decodable as the final field of a message.
MAYBE_UNUSED static Bytes& operator<<(Bytes& msgbytes, const Bytes& val) {
    msgbytes.append(val);
    return msgbytes;
}
/// Appends the raw object representation of `val` (sizeof(T) bytes, native
/// byte order) to `msgbytes`. Meant for trivially copyable fields such as
/// ids, iterations and the Command enum.
///
/// Takes `const T&`: with `T&`, a non-const Bytes lvalue bound better to
/// this template than to the non-template `operator<<(Bytes&, const Bytes&)`,
/// and the string object's internals (pointer/size) would have been
/// serialized instead of its content.
template<typename T>
Bytes& operator<<(Bytes& msgbytes, const T& val) {
    const Byte *first = reinterpret_cast<const Byte *>(&val);
    return msgbytes.append(first, first + sizeof(T));
}
template<typename T>
void append_vals_to_msgbytes(Bytes& msgbytes, T& val) {
msgbytes << val;
}
template<typename T, typename... Ts>
void append_vals_to_msgbytes(Bytes& msgbytes, T& val, Ts&... ts) {
msgbytes << val;
append_vals_to_msgbytes(msgbytes, ts...);
}
/// Builds a complete wire message: [MsgSize total][Command][values...].
/// The length prefix counts the whole message, the prefix itself included.
template<typename... Ts>
Bytes serialize(Comm::Command cmd, Ts&... ts) {
    // Reserve room for the length prefix; it is patched in once known.
    Bytes msgbytes(sizeof(MsgSize), 0);
    append_vals_to_msgbytes(msgbytes, cmd, ts...);
    const MsgSize msg_size = static_cast<MsgSize>(msgbytes.size());
    // MsgSize is narrower than size_t: a wrapped prefix would desynchronize
    // the receiver's framing, so refuse instead of truncating silently.
    if (msg_size != msgbytes.size()) {
        LOG(FATAL) << "serialize: message of " << msgbytes.size()
                   << " bytes does not fit in MsgSize";
    }
    const Byte *size_first = reinterpret_cast<const Byte *>(&msg_size);
    std::copy(size_first, size_first + sizeof(MsgSize), msgbytes.begin());
    return msgbytes;
}
// Terminal overload for byte strings: every remaining byte in
// [first, last) is appended to `val`, which is why a Bytes field has to be
// the last one in a message.
MAYBE_UNUSED static void deserialize(
        Bytes::const_iterator first,
        Bytes::const_iterator last,
        Bytes& val) {
    val.append(first, last);
}
/// Reads one trivially copyable value from the front of [first, last).
///
/// The bytes are copied into `val` instead of being read through a
/// `reinterpret_cast<const T*>`. Fields sit at arbitrary offsets in the
/// buffer, so that load could be misaligned, and it also violates strict
/// aliasing. A message too short for T is rejected rather than over-read.
template<typename T>
void deserialize(
        Bytes::const_iterator first,
        Bytes::const_iterator last,
        T& val) {
    const auto width = static_cast<Bytes::difference_type>(sizeof(T));
    if (std::distance(first, last) < width) {
        LOG(FATAL) << "deserialize: message truncated";
        return;
    }
    std::copy(first, std::next(first, width), reinterpret_cast<unsigned char *>(&val));
}
/// Reads `val` from the front of [first, last), then decodes the remaining
/// values `ts...` from the bytes that follow it.
///
/// Bytes are copied into `val` (no misaligned / strict-aliasing-violating
/// `reinterpret_cast` load), and a truncated message is rejected instead of
/// being read past its end.
template<typename T, typename... Ts>
void deserialize(
        Bytes::const_iterator first,
        Bytes::const_iterator last,
        T& val, Ts&... ts) {
    const auto width = static_cast<Bytes::difference_type>(sizeof(T));
    if (std::distance(first, last) < width) {
        LOG(FATAL) << "deserialize: message truncated";
        return;
    }
    std::copy(first, std::next(first, width), reinterpret_cast<unsigned char *>(&val));
    std::advance(first, width);
    deserialize(first, last, ts...);
}
template<typename... Ts>
void deserialize(const Bytes msgbytes, Ts&... ts) {
auto first = std::next(msgbytes.begin(), sizeof(MsgSize) + sizeof(Comm::Command));
deserialize(first, msgbytes.end(), ts...);
}
/// Returns the Command stored right after the MsgSize prefix of a message
/// produced by serialize().
///
/// Takes the message by const reference; the by-value parameter copied the
/// whole buffer just to read sizeof(Command) bytes. The field is copied out
/// instead of dereferencing a possibly misaligned reinterpret_cast pointer.
MAYBE_UNUSED static Comm::Command deserialize_cmd(const Bytes& msgbytes) {
    Comm::Command cmd{};
    if (msgbytes.size() < sizeof(MsgSize) + sizeof(Comm::Command)) {
        LOG(FATAL) << "deserialize_cmd: message shorter than its header";
        return cmd;
    }
    auto first = std::next(msgbytes.begin(),
                           static_cast<Bytes::difference_type>(sizeof(MsgSize)));
    auto last = std::next(first,
                          static_cast<Bytes::difference_type>(sizeof(Comm::Command)));
    std::copy(first, last, reinterpret_cast<unsigned char *>(&cmd));
    return cmd;
}
} /* woops */
#endif /* end of include guard: WOOPS_UTIL_COMM_SERIALIZE_H_ */
#include "comm.h"
#include <sys/socket.h>
#include "lib.h"
#include "serialize.h"
namespace woops
{
/// Sender thread body: forever pops queued messages and writes each one
/// completely to its destination host's socket. Exits the process on a
/// send error.
void Comm::sender_func() {
    while (true) {
        SendData data;
        {
            std::unique_lock<std::mutex> lock(sender_mu_);
            sender_cv_.wait(lock, [this] {
                return !sender_queue_.empty();
            });
            data = std::move(sender_queue_.front());
            sender_queue_.pop();
        }
        const size_t pkt_size = data.msgbytes.size();
        size_t sent_total = 0;
        // send() may accept only part of the buffer. Each call's result is
        // checked on its own: previously -1 was added to the running total,
        // so an error after a partial send went undetected and corrupted
        // the offset.
        while (sent_total < pkt_size) {
            // MSG_NOSIGNAL: a vanished peer yields EPIPE here (reported
            // below) instead of killing the process with SIGPIPE.
            ssize_t sent = send(sockfds_[data.host], data.msgbytes.data() + sent_total,
                                pkt_size - sent_total, MSG_NOSIGNAL);
            if (sent < 0) {
                LOG(ERROR);
                perror("send");
                exit(1);
            }
            sent_total += static_cast<size_t>(sent);
        }
    }
}
void Comm::send_message(Hostid host, Bytes&& msgbytes) {
    // Enqueue under the lock, then release it before waking the sender
    // thread so the woken thread does not immediately block on sender_mu_.
    std::unique_lock<std::mutex> guard(sender_mu_);
    sender_queue_.emplace(host, std::move(msgbytes));
    guard.unlock();
    sender_cv_.notify_one();
}
} /* woops */
#include "comm.h"
#include <sys/types.h>
#include <sys/socket.h>
#include "lib.h"
#include "serialize.h"
#include "util/logging.h"
namespace woops
{
void Comm::receiver_func() {
FileDesc epollfd = epoll_create1(0);
if (epollfd < 0) {
LOG(ERROR);
perror("epoll_create1");
exit(1);
}
std::vector<epoll_event> host_events(Lib::NumHosts());
set_host_events(epollfd, host_events);
int maxevent = Lib::NumHosts();
std::vector<epoll_event> avail_events(maxevent);
while(true) {
int nfds = epoll_wait(epollfd, avail_events.data(), maxevent, -1);
if (nfds < 0) {
LOG(ERROR);
perror("epoll_wait");
exit(1);
}
for (int n = 0; n < nfds; ++n) {
FileDesc sockfd = avail_events[n].data.fd;
size_t msg_size = read_msg_size(sockfd);
read_msg(sockfd, msg_size);
}
}
}
void Comm::set_host_events(FileDesc epollfd, std::vector<epoll_event>& host_events) {
    // Add each peer's socket to the epoll set (EPOLLIN, level-triggered),
    // using host_events[host] as that peer's event record. This host has no
    // socket to itself and is skipped.
    for (Hostid host = 0; host < Lib::NumHosts(); ++host) {
        if (host != Lib::ThisHost()) {
            const FileDesc peer_fd = sockfds_[host];
            epoll_event& ev = host_events[host];
            ev.events = EPOLLIN;
            ev.data.fd = peer_fd;
            if (epoll_ctl(epollfd, EPOLL_CTL_ADD, peer_fd, &ev) < 0) {
                LOG(ERROR);
                perror("epoll_ctl");
                exit(1);
            }
        }
    }
}
/// Returns the total length (prefix included) of the next message on
/// `sockfd` by peeking at its MsgSize prefix without consuming it. The
/// prefix is left in the socket because read_msg() reads the whole message,
/// prefix included, and the decoders skip it. Exits the process on error
/// or disconnect.
size_t Comm::read_msg_size(FileDesc sockfd) {
    Byte buffer[sizeof(MsgSize)];
    memset(buffer, '\0', sizeof(MsgSize));
    ssize_t numbytes;
    do {
        // MSG_WAITALL asks the kernel to block until the whole prefix can be
        // peeked; without it a partially arrived prefix made this loop spin.
        numbytes = recv(sockfd, buffer, sizeof(MsgSize), MSG_PEEK | MSG_WAITALL);
        if (numbytes < 0) {
            LOG(ERROR);
            perror("recv");
            exit(1);
        } else if (numbytes == 0) {
            LOG(INFO) << Lib::Hostname(sockfd_to_host_[sockfd]) << " disconncted";
            exit(0);
        }
    } while (numbytes < static_cast<ssize_t>(sizeof(MsgSize)));  // was a hard-coded 4
    // Copy out instead of dereferencing a MsgSize* cast of a Byte array
    // (alignment / strict aliasing).
    MsgSize msg_size;
    memcpy(&msg_size, buffer, sizeof(msg_size));
    // The prefix counts itself, so anything smaller is a corrupt frame.
    if (msg_size < sizeof(MsgSize)) {
        LOG(FATAL) << "malformed message size " << msg_size << " from "
                   << Lib::Hostname(sockfd_to_host_[sockfd]);
    }
    return msg_size;
}
/// Reads exactly `msg_size` bytes (the whole message, length prefix
/// included) from `sockfd` and hands them to dispatch() on a detached
/// thread. Exits the process on a receive error or disconnect.
void Comm::read_msg(FileDesc sockfd, size_t msg_size) {
    Hostid host = sockfd_to_host_[sockfd];
    // A frame must at least hold its MsgSize prefix and Command, which
    // dispatch() decodes unconditionally.
    if (msg_size < sizeof(MsgSize) + sizeof(Command)) {
        LOG(FATAL) << "malformed message of " << msg_size << " bytes from "
                   << Lib::Hostname(host);
        return;
    }
    Bytes msgbytes(msg_size, 0);
    size_t received = 0;
    // recv() may return fewer bytes than requested. Each call's result is
    // checked on its own. Previously it was summed first, so EOF (0) after a
    // partial read looped forever and -1 corrupted the write offset.
    while (received < msg_size) {
        ssize_t got = recv(sockfd, &msgbytes[0] + received, msg_size - received, 0);
        if (got < 0) {
            LOG(ERROR);
            perror("recv");
            exit(1);
        } else if (got == 0) {
            LOG(INFO) << Lib::Hostname(host) << " disconncted";
            exit(0);
        }
        received += static_cast<size_t>(got);
    }
    std::thread(&Comm::dispatch, this, host, std::move(msgbytes)).detach();
}
void Comm::dispatch(Hostid host, Bytes msgbytes) {
    // Runs on its own detached thread per received message (see read_msg);
    // routes the message to the handler for its command id.
    switch (deserialize_cmd(msgbytes)) {
    case Command::UPDATE:
        update_handler(host, msgbytes);
        return;
    case Command::PULL:
        pull_handler(host, msgbytes);
        return;
    case Command::PUSH:
        push_handler(host, msgbytes);
        return;
    case Command::SYNC_STORAGE:
        sync_storage_handler(host, msgbytes);
        return;
    case Command::SYNC_STORAGE_RES:
        sync_storage_res_handler();
        return;
    case Command::SYNC_PLACEMENT:
        sync_placement_handler(host);
        return;
    case Command::SYNC_PLACEMENT_RES:
        sync_placement_res_handler(msgbytes);
        return;
    case Command::BARRIER_NOTIFY:
        barrier_notify_handler();
        return;
    default:
        break;
    }
    // Only reached for an id outside the Command enum.
    LOG(FATAL) << "unknown command";
}
} /* woops */
#include "comm.h"
#include "lib.h"
#include "serialize.h"
namespace woops
{
void Comm::Update(Hostid server, Tableid id, Iteration iteration, const Bytes& bytes) {
    if (server != Lib::ThisHost()) {
        send_message(server, serialize(Command::UPDATE, id, iteration, bytes));
        return;
    }
    // Local server: skip the network and apply on a detached thread.
    // std::thread copies `bytes`, so the caller's buffer is free to change.
    std::thread(&Server::Update, Lib::Server(), Lib::ThisHost(), id, iteration, bytes).detach();
}
void Comm::Pull(Hostid server, Tableid id, Iteration iteration) {
    if (server != Lib::ThisHost()) {
        send_message(server, serialize(Command::PULL, id, iteration));
        return;
    }
    // Local server: fetch the parameters and push them back to ourselves on
    // a detached thread, mirroring what pull_handler does for remote pulls.
    auto fetch_and_push = [id, iteration]() {
        // Non-const copy: GetParameter may take the iteration by reference
        // (TODO confirm); the captured value is const inside the lambda.
        Iteration iter = iteration;
        Bytes data = Lib::Server()->GetParameter(Lib::ThisHost(), id, iter);
        Lib::Comm()->Push(Lib::ThisHost(), id, iter, data);
    };
    std::thread(fetch_and_push).detach();
}
void Comm::Push(Hostid client, Tableid id, Iteration iteration, const Bytes& bytes) {
    if (client != Lib::ThisHost()) {
        send_message(client, serialize(Command::PUSH, id, iteration, bytes));
        return;
    }
    // Local client: deliver on a detached thread without the network;
    // std::thread keeps its own copy of `bytes`.
    std::thread(&Client::ServerUpdate, Lib::Client(), Lib::ThisHost(), id, iteration, bytes).detach();
}
void Comm::SyncStorage(Hostid host, Tableid id, const Bytes& bytes) {
    // Ship the table storage to `host`, then block until its
    // SYNC_STORAGE_RES acknowledgement signals sync_storage_cv_.
    // NOTE(review): lost-wakeup race. The lock is taken only after the
    // request is sent, the wait has no predicate, and the acknowledgement
    // handler notifies without holding sync_storage_mu_. A reply that arrives
    // before wait() starts is lost and this call hangs. A "reply received"
    // flag guarded by sync_storage_mu_ (set by the handler under the lock,
    // checked by a predicate wait) is needed.
    send_message(host, serialize(Command::SYNC_STORAGE, id, bytes));
    std::unique_lock<std::mutex> guard(sync_storage_mu_);
    sync_storage_cv_.wait(guard);
}
void Comm::sync_storage_res(Hostid host) {
    // Acknowledge a SYNC_STORAGE request: a bare command with no payload.
    const Command ack = Command::SYNC_STORAGE_RES;
    send_message(host, serialize(ack));
}
void Comm::SyncPlacement() {
send_message(0, serialize(Command::SYNC_PLACEMENT));
std::unique_lock<std::mutex> lock(sync_placement_mu_);
sync_placement_cv_.wait(lock);
}
void Comm::sync_placement_res(Hostid host, const Bytes& bytes) {
    // Reply to a SYNC_PLACEMENT request with the serialized placement.
    const Command reply = Command::SYNC_PLACEMENT_RES;
    send_message(host, serialize(reply, bytes));
}
void Comm::Barrier() {
    std::unique_lock<std::mutex> lock(barrier_mu_);
    if (Lib::ThisHost() != 0) {
        // Worker: announce arrival to host 0, then wait for its release
        // notification and consume it.
        barrier_notify(0);
        barrier_cv_.wait(lock, [this] { return barrier_cnt_ != 0; });
        barrier_cnt_ = 0;
        return;
    }
    // Coordinator (host 0): wait until every other host has arrived, reset
    // the counter for the next round, then release everyone.
    // NOTE(review): both predicates assume barrier_cnt_ starts at 0.
    barrier_cv_.wait(lock, [this] { return barrier_cnt_ >= Lib::NumHosts() - 1; });
    barrier_cnt_ = 0;
    for (Hostid host = 1; host < Lib::NumHosts(); ++host) {
        barrier_notify(host);
    }
}
void Comm::barrier_notify(Hostid host) {
    // BARRIER_NOTIFY carries no payload: arrival (to host 0) or release
    // (from host 0) is signalled by the command alone.
    send_message(host, serialize(Command::BARRIER_NOTIFY));
}
} /* woops */
#include "comm.h"
#include "lib.h"
#include "serialize.h"
namespace woops
{
void Comm::update_handler(Hostid client, Bytes& msgbytes) {
    // UPDATE payload: [Tableid][Iteration][update bytes...]
    Tableid table;
    Iteration iter;
    Bytes payload;
    deserialize(msgbytes, table, iter, payload);
    Lib::Server()->Update(client, table, iter, payload);
}
void Comm::pull_handler(Hostid client, Bytes& msgbytes) {
    // PULL payload: [Tableid][Iteration]. Reply with a PUSH of the current
    // parameters. Kept as two statements so `iter` is read by Push only
    // after GetParameter has run (it may update it by reference).
    Tableid table;
    Iteration iter;
    deserialize(msgbytes, table, iter);
    const Bytes params = Lib::Server()->GetParameter(client, table, iter);
    Push(client, table, iter, params);
}
void Comm::push_handler(Hostid server, Bytes& msgbytes) {
    // PUSH payload: [Tableid][Iteration][parameter bytes...]
    Tableid table;
    Iteration iter;
    Bytes params;
    deserialize(msgbytes, table, iter, params);
    Lib::Client()->ServerUpdate(server, table, iter, params);
}
void Comm::sync_storage_handler(Hostid host, Bytes& msgbytes) {
    // SYNC_STORAGE payload: [Tableid][storage bytes...]. Apply it locally,
    // then acknowledge so the sender's SyncStorage() can return.
    Tableid table;
    Bytes storage;
    deserialize(msgbytes, table, storage);
    Lib::Client()->ServerSyncStorage(table, storage);
    sync_storage_res(host);
}
void Comm::sync_storage_res_handler() {
    // Wake the thread(s) blocked in SyncStorage().
    // NOTE(review): notifies without holding sync_storage_mu_ and sets no
    // state, so if it runs before the waiter reaches wait() the wakeup is
    // lost (see SyncStorage()).
    auto& waiters = sync_storage_cv_;
    waiters.notify_all();
}
void Comm::sync_placement_handler(Hostid host) {
    // Answer a SYNC_PLACEMENT request with this host's serialized placement.
    const auto& placement = Lib::Placement()->Serialize();
    sync_placement_res(host, placement);
}
void Comm::sync_placement_res_handler(Bytes& msgbytes) {
    // SYNC_PLACEMENT_RES payload: [placement bytes...]. Install it, then
    // release the thread(s) blocked in SyncPlacement().
    // NOTE(review): the notify is not made under sync_placement_mu_ and no
    // flag is set, so a reply that beats the waiter to wait() is lost.
    Bytes placement_bytes;
    deserialize(msgbytes, placement_bytes);
    Lib::Placement()->Deserialize(placement_bytes);
    sync_placement_cv_.notify_all();
}
void Comm::barrier_notify_handler() {
    // Record one barrier notification (an arrival on host 0, a release on
    // the others) and wake Barrier() so it re-checks its predicate.
    std::lock_guard<std::mutex> guard(barrier_mu_);
    ++barrier_cnt_;
    barrier_cv_.notify_all();
}
} /* woops */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment